fix: Update backend normalization to match search normalization logic (#7603)

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
Michael Genson
2026-05-10 21:23:57 -05:00
committed by GitHub
parent 551a92a031
commit d340fdd9df
5 changed files with 120 additions and 16 deletions

View File

@@ -0,0 +1,76 @@
"""more aggresive normalization
Revision ID: c7427796f7b6
Revises: 4395a04f7784
Create Date: 2026-05-10 18:44:53.159775
"""
from sqlalchemy import orm, text
from alembic import op
from mealie.db.models._model_base import SqlAlchemyBase
# revision identifiers, used by Alembic.
revision = "c7427796f7b6"
down_revision: str | None = "4395a04f7784"
branch_labels: str | tuple[str, ...] | None = None
depends_on: str | tuple[str, ...] | None = None
def _update_table(session: orm.Session, table: str, columns: list[str], source_columns: list[str]) -> None:
"""Re-normalize all rows in `table`, reading raw values from `source_columns` and writing to `columns`."""
rows = session.execute(text(f"SELECT id, {', '.join(source_columns)} FROM {table}")).fetchall()
for row in rows:
id_ = row[0]
updates = {}
for col, src in zip(columns, source_columns, strict=True):
val = row[source_columns.index(src) + 1]
updates[col] = SqlAlchemyBase.normalize(val) if val is not None else None
set_clause = ", ".join(f"{col} = :{col}" for col in columns)
session.execute(text(f"UPDATE {table} SET {set_clause} WHERE id = :id"), {**updates, "id": id_})
session.commit()
def update_normalization() -> None:
bind = op.get_bind()
session = orm.Session(bind=bind)
# recipes: name_normalized, description_normalized
_update_table(session, "recipes", ["name_normalized", "description_normalized"], ["name", "description"])
# recipe ingredients: note_normalized, original_text_normalized
_update_table(
session,
"recipes_ingredients",
["note_normalized", "original_text_normalized"],
["note", "original_text"],
)
# ingredient units: name, plural_name, abbreviation, plural_abbreviation
_update_table(
session,
"ingredient_units",
["name_normalized", "plural_name_normalized", "abbreviation_normalized", "plural_abbreviation_normalized"],
["name", "plural_name", "abbreviation", "plural_abbreviation"],
)
# ingredient foods: name, plural_name
_update_table(session, "ingredient_foods", ["name_normalized", "plural_name_normalized"], ["name", "plural_name"])
# unit aliases
_update_table(session, "ingredient_units_aliases", ["name_normalized"], ["name"])
# food aliases
_update_table(session, "ingredient_foods_aliases", ["name_normalized"], ["name"])
def upgrade():
# no table changes, this is a data migration
update_normalization()
def downgrade():
pass

View File

@@ -1,3 +1,4 @@
import string
from datetime import datetime from datetime import datetime
from sqlalchemy import Integer from sqlalchemy import Integer
@@ -6,6 +7,12 @@ from text_unidecode import unidecode
from ._model_utils.datetime import NaiveDateTime, get_utc_now from ._model_utils.datetime import NaiveDateTime, get_utc_now
# Punctuation characters replaced with spaces during text normalization.
# Mirrors SearchFilter in query_search.py: string.punctuation minus apostrophe and
# double-quote, which are reserved for quoted literal searches.
NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
class SqlAlchemyBase(DeclarativeBase): class SqlAlchemyBase(DeclarativeBase):
id: Mapped[int] = mapped_column(Integer, primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
@@ -20,7 +27,7 @@ class SqlAlchemyBase(DeclarativeBase):
def normalize(cls, val: str) -> str: def normalize(cls, val: str) -> str:
# We cap the length to 255 to prevent indexes from being too long; see: # We cap the length to 255 to prevent indexes from being too long; see:
# https://www.postgresql.org/docs/current/btree.html # https://www.postgresql.org/docs/current/btree.html
return unidecode(val).lower().strip()[:255] return unidecode(val).translate(_NORMALIZE_PUNCTUATION_TABLE).lower().strip()[:255]
class BaseMixins: class BaseMixins:

View File

@@ -4,7 +4,7 @@ from sqlalchemy import Select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from text_unidecode import unidecode from text_unidecode import unidecode
from ...db.models._model_base import SqlAlchemyBase from ...db.models._model_base import NORMALIZE_PUNCTUATION, SqlAlchemyBase
from .._mealie import MealieModel, SearchType from .._mealie import MealieModel, SearchType
@@ -16,7 +16,7 @@ class SearchFilter:
3. remove special characters from each non-literal search string 3. remove special characters from each non-literal search string
""" """
punctuation = r"!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed punctuation = NORMALIZE_PUNCTUATION
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""") quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
remove_quotes_regex = re.compile(r"""['"](.*)['"]""") remove_quotes_regex = re.compile(r"""['"](.*)['"]""")

View File

@@ -3,10 +3,12 @@ from datetime import UTC, datetime
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.response.pagination import OrderDirection, PaginationQuery from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.schema.response.query_search import SearchFilter
from mealie.schema.user.user import GroupBase from mealie.schema.user.user import GroupBase
from tests.utils.factories import random_int, random_string from tests.utils.factories import random_int, random_string
@@ -137,3 +139,35 @@ def test_random_order_search(
pagination.pagination_seed = str(datetime.now(UTC)) pagination.pagination_seed = str(datetime.now(UTC))
random_ordered.append(repo.page_all(pagination, search="unit").items) random_ordered.append(repo.page_all(pagination, search="unit").items)
assert not all(i == random_ordered[0] for i in random_ordered) assert not all(i == random_ordered[0] for i in random_ordered)
@pytest.mark.parametrize(
"name, expected",
[
("Gluten-Free Bread", "gluten free bread"),
("Mac & Cheese", "mac cheese"),
("Chicken/Rice Bowl", "chicken rice bowl"),
("Rátàtôuile", "ratatouile"),
("Mom's Pasta", "mom's pasta"),
],
)
def test_normalize_strips_punctuation(name: str, expected: str):
assert SqlAlchemyBase.normalize(name) == expected
@pytest.mark.parametrize(
"name",
[
"Gluten-Free Bread",
"Mac & Cheese",
"Chicken/Rice Bowl",
"Rátàtôuile",
"Mom's Pasta",
],
)
def test_search_normalize_symmetric_with_store_normalize(name: str):
"""SearchFilter._normalize_search and SqlAlchemyBase.normalize must produce the same
output for the same input, otherwise stored values and search queries won't match."""
stored = SqlAlchemyBase.normalize(name)
searched = SearchFilter._normalize_search(name, normalize_characters=True)
assert stored == searched, f"Normalization mismatch for {name!r}: stored={stored!r}, searched={searched!r}"

View File

@@ -1,6 +1,4 @@
import filecmp
import statistics import statistics
from pathlib import Path
from typing import Any from typing import Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -32,17 +30,6 @@ def dict_sorter(d: dict) -> Any:
return next((d[key] for key in possible_keys if d.get(key)), 1) return next((d[key] for key in possible_keys if d.get(key)), 1)
# For Future Use
def match_file_tree(path_a: Path, path_b: Path):
if path_a.is_dir() and path_b.is_dir():
for a_file in path_a.iterdir():
b_file = path_b.joinpath(a_file.name)
assert b_file.exists()
match_file_tree(a_file, b_file)
else:
assert filecmp.cmp(path_a, path_b)
def test_database_backup(): def test_database_backup():
backup_v2 = BackupV2() backup_v2 = BackupV2()
path_to_backup = backup_v2.backup() path_to_backup = backup_v2.backup()