feat: Improved Ingredient Matching (#2535)

* added normalization to foods and units

* changed search to reference new normalized fields

* fix tests

* added parsed food matching to backend

* prevent pagination from ordering when searching

* added extra fuzzy matching to sqlite ing matching

* added tests

* only apply search ordering when order_by is null

* enabled post-search fuzzy matching for postgres

* fixed postgres fuzzy search test

* idk why this is failing

* 🤦

* simplified frontend ing matching
and restored automatic unit creation

* tightened food fuzzy threshold

* change to rapidfuzz

* sped up fuzzy matching with process

* fixed units not matching by abbreviation

* fast return for exact matches

* replace db searching with pure fuzz

* added fuzzy normalization

* tightened unit fuzzy matching thresh

* cleaned up comments/var names

* ran matching logic through the dryer

* oops

* simplified order by application logic
This commit is contained in:
Michael Genson
2023-09-15 12:19:34 -05:00
committed by GitHub
parent 084ad4228b
commit 2dfbe9f08d
17 changed files with 738 additions and 97 deletions

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from sqlalchemy import DateTime, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from text_unidecode import unidecode
class SqlAlchemyBase(DeclarativeBase):
@@ -9,6 +10,10 @@ class SqlAlchemyBase(DeclarativeBase):
created_at: Mapped[datetime | None] = mapped_column(DateTime, default=datetime.now, index=True)
update_at: Mapped[datetime | None] = mapped_column(DateTime, default=datetime.now, onupdate=datetime.now)
@classmethod
def normalize(cls, val: str) -> str:
return unidecode(val).lower().strip()
class BaseMixins:
"""

View File

@@ -4,7 +4,6 @@ import sqlalchemy as sa
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, event, orm
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.session import Session
from text_unidecode import unidecode
from mealie.db.models._model_base import BaseMixins, SqlAlchemyBase
from mealie.db.models.labels import MultiPurposeLabel
@@ -34,9 +33,56 @@ class IngredientUnitModel(SqlAlchemyBase, BaseMixins):
"RecipeIngredientModel", back_populates="unit"
)
# Automatically updated by sqlalchemy event, do not write to this manually
name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True)
abbreviation_normalized: Mapped[str | None] = mapped_column(String, index=True)
@auto_init()
def __init__(self, **_) -> None:
pass
def __init__(self, session: Session, name: str | None = None, abbreviation: str | None = None, **_) -> None:
if name is not None:
self.name_normalized = self.normalize(name)
if abbreviation is not None:
self.abbreviation = self.normalize(abbreviation)
tableargs = [
sa.Index(
"ix_ingredient_units_name_normalized",
"name_normalized",
unique=False,
),
sa.Index(
"ix_ingredient_units_abbreviation_normalized",
"abbreviation_normalized",
unique=False,
),
]
if session.get_bind().name == "postgresql":
tableargs.extend(
[
sa.Index(
"ix_ingredient_units_name_normalized_gin",
"name_normalized",
unique=False,
postgresql_using="gin",
postgresql_ops={
"name_normalized": "gin_trgm_ops",
},
),
sa.Index(
"ix_ingredient_units_abbreviation_normalized_gin",
"abbreviation_normalized",
unique=False,
postgresql_using="gin",
postgresql_ops={
"abbreviation_normalized": "gin_trgm_ops",
},
),
]
)
self.__table_args__ = tuple(tableargs)
class IngredientFoodModel(SqlAlchemyBase, BaseMixins):
@@ -57,10 +103,39 @@ class IngredientFoodModel(SqlAlchemyBase, BaseMixins):
label_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("multi_purpose_labels.id"), index=True)
label: Mapped[MultiPurposeLabel | None] = orm.relationship(MultiPurposeLabel, uselist=False, back_populates="foods")
# Automatically updated by sqlalchemy event, do not write to this manually
name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True)
@api_extras
@auto_init()
def __init__(self, **_) -> None:
pass
def __init__(self, session: Session, name: str | None = None, **_) -> None:
if name is not None:
self.name_normalized = self.normalize(name)
tableargs = [
sa.Index(
"ix_ingredient_foods_name_normalized",
"name_normalized",
unique=False,
),
]
if session.get_bind().name == "postgresql":
tableargs.extend(
[
sa.Index(
"ix_ingredient_foods_name_normalized_gin",
"name_normalized",
unique=False,
postgresql_using="gin",
postgresql_ops={
"name_normalized": "gin_trgm_ops",
},
)
]
)
self.__table_args__ = tuple(tableargs)
class RecipeIngredientModel(SqlAlchemyBase, BaseMixins):
@@ -92,10 +167,10 @@ class RecipeIngredientModel(SqlAlchemyBase, BaseMixins):
def __init__(self, session: Session, note: str | None = None, orginal_text: str | None = None, **_) -> None:
# SQLAlchemy events do not seem to register things that are set during auto_init
if note is not None:
self.note_normalized = unidecode(note).lower().strip()
self.note_normalized = self.normalize(note)
if orginal_text is not None:
self.orginal_text = unidecode(orginal_text).lower().strip()
self.orginal_text = self.normalize(orginal_text)
tableargs = [ # base set of indices
sa.Index(
@@ -136,17 +211,41 @@ class RecipeIngredientModel(SqlAlchemyBase, BaseMixins):
self.__table_args__ = tuple(tableargs)
@event.listens_for(RecipeIngredientModel.note, "set")
def receive_note(target: RecipeIngredientModel, value: str, oldvalue, initiator):
@event.listens_for(IngredientUnitModel.name, "set")
def receive_unit_name(target: IngredientUnitModel, value: str | None, oldvalue, initiator):
if value is not None:
target.note_normalized = unidecode(value).lower().strip()
target.name_normalized = IngredientUnitModel.normalize(value)
else:
target.name_normalized = None
@event.listens_for(IngredientUnitModel.abbreviation, "set")
def receive_unit_abbreviation(target: IngredientUnitModel, value: str | None, oldvalue, initiator):
if value is not None:
target.abbreviation_normalized = IngredientUnitModel.normalize(value)
else:
target.abbreviation_normalized = None
@event.listens_for(IngredientFoodModel.name, "set")
def receive_food_name(target: IngredientFoodModel, value: str | None, oldvalue, initiator):
if value is not None:
target.name_normalized = IngredientFoodModel.normalize(value)
else:
target.name_normalized = None
@event.listens_for(RecipeIngredientModel.note, "set")
def receive_ingredient_note(target: RecipeIngredientModel, value: str | None, oldvalue, initiator):
if value is not None:
target.note_normalized = RecipeIngredientModel.normalize(value)
else:
target.note_normalized = None
@event.listens_for(RecipeIngredientModel.original_text, "set")
def receive_original_text(target: RecipeIngredientModel, value: str, oldvalue, initiator):
def receive_ingredient_original_text(target: RecipeIngredientModel, value: str | None, oldvalue, initiator):
if value is not None:
target.original_text_normalized = unidecode(value).lower().strip()
target.original_text_normalized = RecipeIngredientModel.normalize(value)
else:
target.original_text_normalized = None

View File

@@ -6,7 +6,6 @@ import sqlalchemy.orm as orm
from sqlalchemy import event
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Mapped, mapped_column, validates
from text_unidecode import unidecode
from mealie.db.models._model_utils.guid import GUID
@@ -189,10 +188,10 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
# SQLAlchemy events do not seem to register things that are set during auto_init
if name is not None:
self.name_normalized = unidecode(name).lower().strip()
self.name_normalized = self.normalize(name)
if description is not None:
self.description_normalized = unidecode(description).lower().strip()
self.description_normalized = self.normalize(description)
tableargs = [ # base set of indices
sa.UniqueConstraint("slug", "group_id", name="recipe_slug_group_id_key"),
@@ -237,12 +236,12 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
@event.listens_for(RecipeModel.name, "set")
def receive_name(target: RecipeModel, value: str, oldvalue, initiator):
target.name_normalized = unidecode(value).lower().strip()
target.name_normalized = RecipeModel.normalize(value)
@event.listens_for(RecipeModel.description, "set")
def receive_description(target: RecipeModel, value: str, oldvalue, initiator):
if value is not None:
target.description_normalized = unidecode(value).lower().strip()
target.description_normalized = RecipeModel.normalize(value)
else:
target.description_normalized = None

View File

@@ -312,6 +312,10 @@ class RepositoryGeneric(Generic[Schema, Model]):
if search:
q = self.add_search_to_query(q, eff_schema, search)
if not pagination_result.order_by and not search:
# default ordering if not searching
pagination_result.order_by = "created_at"
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
# Apply options late, so they do not get used for counting
@@ -371,16 +375,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
if pagination.page < 1:
pagination.page = 1
if pagination.order_by:
query = self.add_order_by_to_query(query, pagination)
query = self.add_order_by_to_query(query, pagination)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
if not pagination.order_by:
return query
if pagination.order_by == "random":
elif pagination.order_by == "random":
# randomize outside of database, since not all db's can set random seeds
# this solution is db-independent & stable to paging
temp_query = query.with_only_columns(self.model.id)

View File

@@ -203,6 +203,10 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
if search:
q = self.add_search_to_query(q, self.schema, search)
if not pagination_result.order_by and not search:
# default ordering if not searching
pagination_result.order_by = "created_at"
q, count, total_pages = self.add_pagination_to_query(q, pagination_result)
try:

View File

@@ -12,10 +12,10 @@ router = APIRouter(prefix="/parser")
class IngredientParserController(BaseUserController):
@router.post("/ingredients", response_model=list[ParsedIngredient])
def parse_ingredients(self, ingredients: IngredientsRequest):
parser = get_parser(ingredients.parser)
parser = get_parser(ingredients.parser, self.group_id, self.session)
return parser.parse(ingredients.ingredients)
@router.post("/ingredient", response_model=ParsedIngredient)
def parse_ingredient(self, ingredient: IngredientRequest):
parser = get_parser(ingredient.parser)
parser = get_parser(ingredient.parser, self.group_id, self.session)
return parser.parse([ingredient.ingredient])[0]

View File

@@ -51,7 +51,8 @@ class IngredientFood(CreateIngredientFood):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "description"]
_searchable_properties: ClassVar[list[str]] = ["name_normalized"]
_normalize_search: ClassVar[bool] = True
class Config:
orm_mode = True
@@ -81,7 +82,8 @@ class IngredientUnit(CreateIngredientUnit):
created_at: datetime.datetime | None
update_at: datetime.datetime | None
_searchable_properties: ClassVar[list[str]] = ["name", "abbreviation", "description"]
_searchable_properties: ClassVar[list[str]] = ["name_normalized", "abbreviation_normalized"]
_normalize_search: ClassVar[bool] = True
class Config:
orm_mode = True

View File

@@ -34,7 +34,7 @@ class RecipeSearchQuery(MealieModel):
class PaginationQuery(MealieModel):
page: int = 1
per_page: int = 50
order_by: str = "created_at"
order_by: str | None = None
order_by_null_position: OrderByNullPosition | None = None
order_direction: OrderDirection = OrderDirection.desc
query_filter: str | None = None

View File

@@ -1,20 +1,32 @@
from abc import ABC, abstractmethod
from fractions import Fraction
from typing import TypeVar
from pydantic import UUID4, BaseModel
from rapidfuzz import fuzz, process
from sqlalchemy.orm import Session
from mealie.core.root_logger import get_logger
from mealie.db.models.recipe.ingredient import IngredientFoodModel, IngredientUnitModel
from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe import RecipeIngredient
from mealie.schema.recipe.recipe_ingredient import (
MAX_INGREDIENT_DENOMINATOR,
CreateIngredientFood,
CreateIngredientUnit,
IngredientConfidence,
IngredientFood,
IngredientUnit,
ParsedIngredient,
RegisteredParser,
)
from mealie.schema.response.pagination import PaginationQuery
from . import brute, crfpp
logger = get_logger(__name__)
T = TypeVar("T", bound=BaseModel)
class ABCIngredientParser(ABC):
@@ -22,6 +34,53 @@ class ABCIngredientParser(ABC):
Abstract class for ingredient parsers.
"""
def __init__(self, group_id: UUID4, session: Session) -> None:
self.group_id = group_id
self.session = session
self._foods_by_name: dict[str, IngredientFood] | None = None
self._units_by_name: dict[str, IngredientUnit] | None = None
@property
def _repos(self) -> AllRepositories:
return get_repositories(self.session)
@property
def foods_by_normalized_name(self) -> dict[str, IngredientFood]:
if self._foods_by_name is None:
foods_repo = self._repos.ingredient_foods.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1)
all_foods = foods_repo.page_all(query).items
self._foods_by_name = {IngredientFoodModel.normalize(food.name): food for food in all_foods if food.name}
return self._foods_by_name
@property
def units_by_normalized_name_or_abbreviation(self) -> dict[str, IngredientUnit]:
if self._units_by_name is None:
units_repo = self._repos.ingredient_units.by_group(self.group_id)
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
self._units_by_name = {
IngredientUnitModel.normalize(unit.name): unit for unit in all_units if unit.name
} | {IngredientUnitModel.normalize(unit.abbreviation): unit for unit in all_units if unit.abbreviation}
return self._units_by_name
@property
def food_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database food search"""
return 85
@property
def unit_fuzzy_match_threshold(self) -> int:
"""Minimum threshold to fuzzy match against a database unit search"""
return 70
@abstractmethod
def parse_one(self, ingredient_string: str) -> ParsedIngredient:
...
@@ -30,19 +89,64 @@ class ABCIngredientParser(ABC):
def parse(self, ingredients: list[str]) -> list[ParsedIngredient]:
...
@classmethod
def find_match(cls, match_value: str, *, store_map: dict[str, T], fuzzy_match_threshold: int = 0) -> T | None:
# check for literal matches
if match_value in store_map:
return store_map[match_value]
# fuzzy match against food store
fuzz_result = process.extractOne(match_value, store_map.keys(), scorer=fuzz.ratio)
if fuzz_result is None:
return None
choice, score, _ = fuzz_result
if score < fuzzy_match_threshold:
return None
else:
return store_map[choice]
def find_food_match(self, food: IngredientFood | CreateIngredientFood) -> IngredientFood | None:
if isinstance(food, IngredientFood):
return food
match_value = IngredientFoodModel.normalize(food.name)
return self.find_match(
match_value,
store_map=self.foods_by_normalized_name,
fuzzy_match_threshold=self.food_fuzzy_match_threshold,
)
def find_unit_match(self, unit: IngredientUnit | CreateIngredientUnit) -> IngredientUnit | None:
if isinstance(unit, IngredientUnit):
return unit
match_value = IngredientUnitModel.normalize(unit.name)
return self.find_match(
match_value,
store_map=self.units_by_normalized_name_or_abbreviation,
fuzzy_match_threshold=self.unit_fuzzy_match_threshold,
)
def find_ingredient_match(self, ingredient: ParsedIngredient) -> ParsedIngredient:
if ingredient.ingredient.food and (food_match := self.find_food_match(ingredient.ingredient.food)):
ingredient.ingredient.food = food_match
if ingredient.ingredient.unit and (unit_match := self.find_unit_match(ingredient.ingredient.unit)):
ingredient.ingredient.unit = unit_match
return ingredient
class BruteForceParser(ABCIngredientParser):
"""
Brute force ingredient parser.
"""
def __init__(self) -> None:
pass
def parse_one(self, ingredient: str) -> ParsedIngredient:
bfi = brute.parse(ingredient)
return ParsedIngredient(
parsed_ingredient = ParsedIngredient(
input=ingredient,
ingredient=RecipeIngredient(
unit=CreateIngredientUnit(name=bfi.unit),
@@ -53,6 +157,8 @@ class BruteForceParser(ABCIngredientParser):
),
)
return self.find_ingredient_match(parsed_ingredient)
def parse(self, ingredients: list[str]) -> list[ParsedIngredient]:
return [self.parse_one(ingredient) for ingredient in ingredients]
@@ -62,9 +168,6 @@ class NLPParser(ABCIngredientParser):
Class for CRFPP ingredient parsers.
"""
def __init__(self) -> None:
pass
def _crf_to_ingredient(self, crf_model: crfpp.CRFIngredient) -> ParsedIngredient:
ingredient = None
@@ -87,7 +190,7 @@ class NLPParser(ABCIngredientParser):
note=crf_model.input,
)
return ParsedIngredient(
parsed_ingredient = ParsedIngredient(
input=crf_model.input,
ingredient=ingredient,
confidence=IngredientConfidence(
@@ -97,6 +200,8 @@ class NLPParser(ABCIngredientParser):
),
)
return self.find_ingredient_match(parsed_ingredient)
def parse(self, ingredients: list[str]) -> list[ParsedIngredient]:
crf_models = crfpp.convert_list_to_crf_model(ingredients)
return [self._crf_to_ingredient(crf_model) for crf_model in crf_models]
@@ -112,9 +217,9 @@ __registrar = {
}
def get_parser(parser: RegisteredParser) -> ABCIngredientParser:
def get_parser(parser: RegisteredParser, group_id: UUID4, session: Session) -> ABCIngredientParser:
"""
get_parser returns an ingrdeint parser based on the string enum value
passed in.
"""
return __registrar.get(parser, NLPParser)()
return __registrar.get(parser, NLPParser)(group_id, session)