mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-05-11 12:33:32 -04:00
fix: Update backend normalization to match search normalization logic (#7603)
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -0,0 +1,76 @@
|
|||||||
|
"""more aggresive normalization
|
||||||
|
|
||||||
|
Revision ID: c7427796f7b6
|
||||||
|
Revises: 4395a04f7784
|
||||||
|
Create Date: 2026-05-10 18:44:53.159775
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import orm, text
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from mealie.db.models._model_base import SqlAlchemyBase
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "c7427796f7b6"
|
||||||
|
down_revision: str | None = "4395a04f7784"
|
||||||
|
branch_labels: str | tuple[str, ...] | None = None
|
||||||
|
depends_on: str | tuple[str, ...] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _update_table(session: orm.Session, table: str, columns: list[str], source_columns: list[str]) -> None:
|
||||||
|
"""Re-normalize all rows in `table`, reading raw values from `source_columns` and writing to `columns`."""
|
||||||
|
rows = session.execute(text(f"SELECT id, {', '.join(source_columns)} FROM {table}")).fetchall()
|
||||||
|
for row in rows:
|
||||||
|
id_ = row[0]
|
||||||
|
updates = {}
|
||||||
|
for col, src in zip(columns, source_columns, strict=True):
|
||||||
|
val = row[source_columns.index(src) + 1]
|
||||||
|
updates[col] = SqlAlchemyBase.normalize(val) if val is not None else None
|
||||||
|
|
||||||
|
set_clause = ", ".join(f"{col} = :{col}" for col in columns)
|
||||||
|
session.execute(text(f"UPDATE {table} SET {set_clause} WHERE id = :id"), {**updates, "id": id_})
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def update_normalization() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
session = orm.Session(bind=bind)
|
||||||
|
|
||||||
|
# recipes: name_normalized, description_normalized
|
||||||
|
_update_table(session, "recipes", ["name_normalized", "description_normalized"], ["name", "description"])
|
||||||
|
|
||||||
|
# recipe ingredients: note_normalized, original_text_normalized
|
||||||
|
_update_table(
|
||||||
|
session,
|
||||||
|
"recipes_ingredients",
|
||||||
|
["note_normalized", "original_text_normalized"],
|
||||||
|
["note", "original_text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# ingredient units: name, plural_name, abbreviation, plural_abbreviation
|
||||||
|
_update_table(
|
||||||
|
session,
|
||||||
|
"ingredient_units",
|
||||||
|
["name_normalized", "plural_name_normalized", "abbreviation_normalized", "plural_abbreviation_normalized"],
|
||||||
|
["name", "plural_name", "abbreviation", "plural_abbreviation"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# ingredient foods: name, plural_name
|
||||||
|
_update_table(session, "ingredient_foods", ["name_normalized", "plural_name_normalized"], ["name", "plural_name"])
|
||||||
|
|
||||||
|
# unit aliases
|
||||||
|
_update_table(session, "ingredient_units_aliases", ["name_normalized"], ["name"])
|
||||||
|
|
||||||
|
# food aliases
|
||||||
|
_update_table(session, "ingredient_foods_aliases", ["name_normalized"], ["name"])
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# no table changes, this is a data migration
|
||||||
|
update_normalization()
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
pass
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import string
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import Integer
|
from sqlalchemy import Integer
|
||||||
@@ -6,6 +7,12 @@ from text_unidecode import unidecode
|
|||||||
|
|
||||||
from ._model_utils.datetime import NaiveDateTime, get_utc_now
|
from ._model_utils.datetime import NaiveDateTime, get_utc_now
|
||||||
|
|
||||||
|
# Punctuation characters replaced with spaces during text normalization.
|
||||||
|
# Mirrors SearchFilter in query_search.py: string.punctuation minus apostrophe and
|
||||||
|
# double-quote, which are reserved for quoted literal searches.
|
||||||
|
NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
|
||||||
|
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
|
||||||
|
|
||||||
|
|
||||||
class SqlAlchemyBase(DeclarativeBase):
|
class SqlAlchemyBase(DeclarativeBase):
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
@@ -20,7 +27,7 @@ class SqlAlchemyBase(DeclarativeBase):
|
|||||||
def normalize(cls, val: str) -> str:
|
def normalize(cls, val: str) -> str:
|
||||||
# We cap the length to 255 to prevent indexes from being too long; see:
|
# We cap the length to 255 to prevent indexes from being too long; see:
|
||||||
# https://www.postgresql.org/docs/current/btree.html
|
# https://www.postgresql.org/docs/current/btree.html
|
||||||
return unidecode(val).lower().strip()[:255]
|
return unidecode(val).translate(_NORMALIZE_PUNCTUATION_TABLE).lower().strip()[:255]
|
||||||
|
|
||||||
|
|
||||||
class BaseMixins:
|
class BaseMixins:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from sqlalchemy import Select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from text_unidecode import unidecode
|
from text_unidecode import unidecode
|
||||||
|
|
||||||
from ...db.models._model_base import SqlAlchemyBase
|
from ...db.models._model_base import NORMALIZE_PUNCTUATION, SqlAlchemyBase
|
||||||
from .._mealie import MealieModel, SearchType
|
from .._mealie import MealieModel, SearchType
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ class SearchFilter:
|
|||||||
3. remove special characters from each non-literal search string
|
3. remove special characters from each non-literal search string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
punctuation = r"!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed
|
punctuation = NORMALIZE_PUNCTUATION
|
||||||
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
|
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
|
||||||
remove_quotes_regex = re.compile(r"""['"](.*)['"]""")
|
remove_quotes_regex = re.compile(r"""['"](.*)['"]""")
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ from datetime import UTC, datetime
|
|||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from mealie.db.models._model_base import SqlAlchemyBase
|
||||||
from mealie.repos.all_repositories import get_repositories
|
from mealie.repos.all_repositories import get_repositories
|
||||||
from mealie.repos.repository_factory import AllRepositories
|
from mealie.repos.repository_factory import AllRepositories
|
||||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||||
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||||
|
from mealie.schema.response.query_search import SearchFilter
|
||||||
from mealie.schema.user.user import GroupBase
|
from mealie.schema.user.user import GroupBase
|
||||||
from tests.utils.factories import random_int, random_string
|
from tests.utils.factories import random_int, random_string
|
||||||
|
|
||||||
@@ -137,3 +139,35 @@ def test_random_order_search(
|
|||||||
pagination.pagination_seed = str(datetime.now(UTC))
|
pagination.pagination_seed = str(datetime.now(UTC))
|
||||||
random_ordered.append(repo.page_all(pagination, search="unit").items)
|
random_ordered.append(repo.page_all(pagination, search="unit").items)
|
||||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name, expected",
|
||||||
|
[
|
||||||
|
("Gluten-Free Bread", "gluten free bread"),
|
||||||
|
("Mac & Cheese", "mac cheese"),
|
||||||
|
("Chicken/Rice Bowl", "chicken rice bowl"),
|
||||||
|
("Rátàtôuile", "ratatouile"),
|
||||||
|
("Mom's Pasta", "mom's pasta"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_normalize_strips_punctuation(name: str, expected: str):
|
||||||
|
assert SqlAlchemyBase.normalize(name) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name",
|
||||||
|
[
|
||||||
|
"Gluten-Free Bread",
|
||||||
|
"Mac & Cheese",
|
||||||
|
"Chicken/Rice Bowl",
|
||||||
|
"Rátàtôuile",
|
||||||
|
"Mom's Pasta",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_search_normalize_symmetric_with_store_normalize(name: str):
|
||||||
|
"""SearchFilter._normalize_search and SqlAlchemyBase.normalize must produce the same
|
||||||
|
output for the same input, otherwise stored values and search queries won't match."""
|
||||||
|
stored = SqlAlchemyBase.normalize(name)
|
||||||
|
searched = SearchFilter._normalize_search(name, normalize_characters=True)
|
||||||
|
assert stored == searched, f"Normalization mismatch for {name!r}: stored={stored!r}, searched={searched!r}"
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import filecmp
|
|
||||||
import statistics
|
import statistics
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -32,17 +30,6 @@ def dict_sorter(d: dict) -> Any:
|
|||||||
return next((d[key] for key in possible_keys if d.get(key)), 1)
|
return next((d[key] for key in possible_keys if d.get(key)), 1)
|
||||||
|
|
||||||
|
|
||||||
# For Future Use
|
|
||||||
def match_file_tree(path_a: Path, path_b: Path):
|
|
||||||
if path_a.is_dir() and path_b.is_dir():
|
|
||||||
for a_file in path_a.iterdir():
|
|
||||||
b_file = path_b.joinpath(a_file.name)
|
|
||||||
assert b_file.exists()
|
|
||||||
match_file_tree(a_file, b_file)
|
|
||||||
else:
|
|
||||||
assert filecmp.cmp(path_a, path_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_database_backup():
|
def test_database_backup():
|
||||||
backup_v2 = BackupV2()
|
backup_v2 = BackupV2()
|
||||||
path_to_backup = backup_v2.backup()
|
path_to_backup = backup_v2.backup()
|
||||||
|
|||||||
Reference in New Issue
Block a user