mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-05-11 12:33:32 -04:00
fix: Update backend normalization to match search normalization logic (#7603)
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""more aggresive normalization
|
||||
|
||||
Revision ID: c7427796f7b6
|
||||
Revises: 4395a04f7784
|
||||
Create Date: 2026-05-10 18:44:53.159775
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import orm, text
|
||||
|
||||
from alembic import op
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c7427796f7b6"
|
||||
down_revision: str | None = "4395a04f7784"
|
||||
branch_labels: str | tuple[str, ...] | None = None
|
||||
depends_on: str | tuple[str, ...] | None = None
|
||||
|
||||
|
||||
def _update_table(session: orm.Session, table: str, columns: list[str], source_columns: list[str]) -> None:
    """Re-normalize all rows in `table`, reading raw values from `source_columns`
    and writing to the positionally-paired entry in `columns`.

    `columns` and `source_columns` must have the same length; `zip(strict=True)`
    enforces this at runtime.
    """
    rows = session.execute(text(f"SELECT id, {', '.join(source_columns)} FROM {table}")).fetchall()

    # The UPDATE statement is identical for every row, so build it once.
    set_clause = ", ".join(f"{col} = :{col}" for col in columns)
    update_stmt = text(f"UPDATE {table} SET {set_clause} WHERE id = :id")

    for row in rows:
        id_ = row[0]
        updates = {}
        # row[0] is the id; raw values follow in source_columns order. Pair them
        # positionally rather than via source_columns.index(src), which was an
        # O(n) scan per column and would mis-pair repeated column names.
        for offset, (col, _src) in enumerate(zip(columns, source_columns, strict=True), start=1):
            val = row[offset]
            updates[col] = SqlAlchemyBase.normalize(val) if val is not None else None

        session.execute(update_stmt, {**updates, "id": id_})
    session.commit()
|
||||
|
||||
|
||||
def update_normalization() -> None:
    """Recompute every stored *_normalized column with the current normalize() logic."""
    session = orm.Session(bind=op.get_bind())

    # One (table, normalized columns, raw source columns) entry per table that
    # stores denormalized search text; processed in declaration order.
    targets: list[tuple[str, list[str], list[str]]] = [
        ("recipes", ["name_normalized", "description_normalized"], ["name", "description"]),
        (
            "recipes_ingredients",
            ["note_normalized", "original_text_normalized"],
            ["note", "original_text"],
        ),
        (
            "ingredient_units",
            ["name_normalized", "plural_name_normalized", "abbreviation_normalized", "plural_abbreviation_normalized"],
            ["name", "plural_name", "abbreviation", "plural_abbreviation"],
        ),
        ("ingredient_foods", ["name_normalized", "plural_name_normalized"], ["name", "plural_name"]),
        ("ingredient_units_aliases", ["name_normalized"], ["name"]),
        ("ingredient_foods_aliases", ["name_normalized"], ["name"]),
    ]

    for table, columns, source_columns in targets:
        _update_table(session, table, columns, source_columns)
|
||||
|
||||
|
||||
def upgrade():
    """Apply the migration. Pure data migration — no schema changes."""
    update_normalization()
|
||||
|
||||
|
||||
def downgrade():
    """Intentional no-op (original migration left downgrade empty)."""
|
||||
@@ -1,3 +1,4 @@
|
||||
import string
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Integer
|
||||
@@ -6,6 +7,12 @@ from text_unidecode import unidecode
|
||||
|
||||
from ._model_utils.datetime import NaiveDateTime, get_utc_now
|
||||
|
||||
# Punctuation characters replaced with spaces during text normalization.
|
||||
# Mirrors SearchFilter in query_search.py: string.punctuation minus apostrophe and
|
||||
# double-quote, which are reserved for quoted literal searches.
|
||||
NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
|
||||
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
|
||||
|
||||
|
||||
class SqlAlchemyBase(DeclarativeBase):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
@@ -20,7 +27,7 @@ class SqlAlchemyBase(DeclarativeBase):
|
||||
def normalize(cls, val: str) -> str:
    """Normalize `val` for storage in *_normalized columns: transliterate to
    ASCII, map punctuation (except ' and ") to spaces, lowercase, strip, and
    cap the length.

    NOTE(review): the diff residue contained two return statements; the first
    (pre-punctuation-stripping) made the updated one unreachable. Only the
    updated return is kept so stored values match search normalization.
    """
    # We cap the length to 255 to prevent indexes from being too long; see:
    # https://www.postgresql.org/docs/current/btree.html
    return unidecode(val).translate(_NORMALIZE_PUNCTUATION_TABLE).lower().strip()[:255]
|
||||
|
||||
|
||||
class BaseMixins:
|
||||
|
||||
@@ -4,7 +4,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy.orm import Session
|
||||
from text_unidecode import unidecode
|
||||
|
||||
from ...db.models._model_base import SqlAlchemyBase
|
||||
from ...db.models._model_base import NORMALIZE_PUNCTUATION, SqlAlchemyBase
|
||||
from .._mealie import MealieModel, SearchType
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class SearchFilter:
|
||||
3. remove special characters from each non-literal search string
|
||||
"""
|
||||
|
||||
punctuation = r"!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed
|
||||
punctuation = NORMALIZE_PUNCTUATION
|
||||
quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""")
|
||||
remove_quotes_regex = re.compile(r"""['"](.*)['"]""")
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@ from datetime import UTC, datetime
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from mealie.db.models._model_base import SqlAlchemyBase
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||
from mealie.schema.response.query_search import SearchFilter
|
||||
from mealie.schema.user.user import GroupBase
|
||||
from tests.utils.factories import random_int, random_string
|
||||
|
||||
@@ -137,3 +139,35 @@ def test_random_order_search(
|
||||
pagination.pagination_seed = str(datetime.now(UTC))
|
||||
random_ordered.append(repo.page_all(pagination, search="unit").items)
|
||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "name, expected",
    [
        ("Gluten-Free Bread", "gluten free bread"),
        ("Mac & Cheese", "mac cheese"),
        ("Chicken/Rice Bowl", "chicken rice bowl"),
        ("Rátàtôuile", "ratatouile"),
        ("Mom's Pasta", "mom's pasta"),
    ],
)
def test_normalize_strips_punctuation(name: str, expected: str):
    """Punctuation collapses to spaces, accents transliterate, apostrophes survive.

    NOTE(review): str.translate maps each punctuation char to a single space and
    does not collapse runs — confirm normalize() also collapses whitespace, or
    "Mac & Cheese" would normalize to "mac   cheese".
    """
    normalized = SqlAlchemyBase.normalize(name)
    assert normalized == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "name",
    [
        "Gluten-Free Bread",
        "Mac & Cheese",
        "Chicken/Rice Bowl",
        "Rátàtôuile",
        "Mom's Pasta",
    ],
)
def test_search_normalize_symmetric_with_store_normalize(name: str):
    """Stored values and search queries must normalize identically; otherwise the
    database-side *_normalized columns would never match normalized search input."""
    via_store = SqlAlchemyBase.normalize(name)
    via_search = SearchFilter._normalize_search(name, normalize_characters=True)
    failure_msg = f"Normalization mismatch for {name!r}: stored={via_store!r}, searched={via_search!r}"
    assert via_store == via_search, failure_msg
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import filecmp
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -32,17 +30,6 @@ def dict_sorter(d: dict) -> Any:
|
||||
return next((d[key] for key in possible_keys if d.get(key)), 1)
|
||||
|
||||
|
||||
# For Future Use
|
||||
def match_file_tree(path_a: Path, path_b: Path):
    """Recursively assert that every entry under `path_a` exists under `path_b`
    with matching content (comparison is one-directional: extras in `path_b`
    are not detected)."""
    if not (path_a.is_dir() and path_b.is_dir()):
        # At least one side is not a directory: compare as plain files.
        assert filecmp.cmp(path_a, path_b)
        return

    for child in path_a.iterdir():
        counterpart = path_b / child.name
        assert counterpart.exists()
        match_file_tree(child, counterpart)
|
||||
|
||||
|
||||
def test_database_backup():
|
||||
backup_v2 = BackupV2()
|
||||
path_to_backup = backup_v2.backup()
|
||||
|
||||
Reference in New Issue
Block a user