diff --git a/frontend/components/global/AutoForm.vue b/frontend/components/global/AutoForm.vue index 473131712..5bdbbbe3b 100644 --- a/frontend/components/global/AutoForm.vue +++ b/frontend/components/global/AutoForm.vue @@ -88,6 +88,25 @@ validate-on="input" /> + + + (() => [ { cols: 8, label: i18n.t("general.name"), @@ -262,6 +278,59 @@ const formItems: AutoFormItems = [ varName: "description", type: fieldTypes.TEXT, }, + { + section: i18n.t("data-pages.units.standardization"), + sectionDetails: i18n.t("data-pages.units.standardization-description"), + cols: 2, + varName: "standardQuantity", + type: fieldTypes.NUMBER, + numberInputConfig: { + min: 0, + max: undefined, + precision: undefined, + controlVariant: "hidden", + }, + }, + { + cols: 10, + varName: "standardUnit", + type: fieldTypes.SELECT, + selectReturnValue: "value", + options: [ + { + text: i18n.t("data-pages.units.standard-unit-labels.fluid-ounce"), + value: "fluid_ounce", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.cup"), + value: "cup", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.ounce"), + value: "ounce", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.pound"), + value: "pound", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.milliliter"), + value: "milliliter", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.liter"), + value: "liter", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.gram"), + value: "gram", + }, + { + text: i18n.t("data-pages.units.standard-unit-labels.kilogram"), + value: "kilogram", + }, + ] as StandardizedUnitTypeOption[], + }, { section: i18n.t("general.settings"), cols: 4, @@ -275,7 +344,7 @@ const formItems: AutoFormItems = [ varName: "fraction", type: fieldTypes.BOOLEAN, }, -]; +]); // ============================================================ // Create diff --git a/frontend/types/auto-forms.ts b/frontend/types/auto-forms.ts index f52165688..b232f66b6 100644 --- a/frontend/types/auto-forms.ts +++ b/frontend/types/auto-forms.ts @@ -1,6 +1,15 @@ import type { VForm as VuetifyForm } from "vuetify/components/VForm"; -type FormFieldType = "text" | "textarea" | "list" | "select" | "object" | "boolean" | "color" | "password"; +type FormFieldType + = | "text" + | "textarea" + | "number" + | "list" + | "select" + | "object" + | "boolean" + | "color" + | "password"; export type FormValidationRule = (value: any) => boolean | string; @@ -9,6 +18,13 @@ export interface FormSelectOption { value?: string; } +export interface FormFieldNumberInputConfig { + min?: number; + max?: number; + precision?: number; + controlVariant?: "split" | "default" | "hidden" | "stacked"; +} + export interface FormField { section?: string; sectionDetails?: string; @@ -20,6 +36,7 @@ export interface FormField { rules?: FormValidationRule[]; disableUpdate?: boolean; disableCreate?: boolean; + numberInputConfig?: FormFieldNumberInputConfig; options?: FormSelectOption[]; selectReturnValue?: "text" | "value"; } diff --git a/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py b/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py new file mode 100644 index 000000000..c0898392a --- /dev/null +++ b/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py @@ -0,0 +1,106 @@ +"""add unit standardization fields + +Revision ID: a39c7f1826e3 +Revises: 1d9a002d7234 +Create Date: 2026-02-21 17:59:01.161812 + +""" + +import sqlalchemy as sa +from sqlalchemy import orm +from alembic import op +from mealie.repos.repository_units import RepositoryUnit +from mealie.core.root_logger import get_logger +from mealie.db.models._model_utils.guid import GUID +from mealie.repos.seed.seeders import IngredientUnitsSeeder +from mealie.lang.locale_config import LOCALE_CONFIG + +# revision identifiers, used by Alembic. +revision = "a39c7f1826e3" +down_revision: str | None = "1d9a002d7234" +branch_labels: str | tuple[str, ...] | None = None +depends_on: str | tuple[str, ...] | None = None + +logger = get_logger() + + +class SqlAlchemyBase(orm.DeclarativeBase): ... + + +class IngredientUnitModel(SqlAlchemyBase): + __tablename__ = "ingredient_units" + + id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate) + name: orm.Mapped[str | None] = orm.mapped_column(sa.String) + plural_name: orm.Mapped[str | None] = orm.mapped_column(sa.String) + abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String) + plural_abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String) + standard_quantity: orm.Mapped[float | None] = orm.mapped_column(sa.Float) + standard_unit: orm.Mapped[str | None] = orm.mapped_column(sa.String) + + +def populate_standards() -> None: + bind = op.get_bind() + + session = orm.Session(bind) + + # We aren't using most of the functionality of this class, so we pass dummy args + repo = RepositoryUnit(None, None, None, None, group_id=None) # type: ignore + + stmt = sa.select(IngredientUnitModel) + units = session.execute(stmt).scalars().all() + if not units: + return + + # Manually build repo._standardized_unit_map with all locales + repo._standardized_unit_map = {} + for locale in LOCALE_CONFIG: + locale_file = IngredientUnitsSeeder.get_file(locale) + for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items(): + for prop in ["name", "plural_name", "abbreviation"]: + val = unit.get(prop) + if val and isinstance(val, str): + repo._standardized_unit_map[val.strip().lower()] = unit_key + + for unit in units: + unit_data = { + "name": unit.name, + "plural_name": unit.plural_name, + "abbreviation": unit.abbreviation, + "plural_abbreviation": unit.plural_abbreviation, + } + + standardized_data = repo._add_standardized_unit(unit_data) + std_q = standardized_data.get("standard_quantity") + std_u = standardized_data.get("standard_unit") + if std_q and std_u: + logger.info(f"Found unit '{unit.name}', which is standardized as '{std_q} * {std_u}'") + unit.standard_quantity = std_q + unit.standard_unit = std_u + + session.commit() + session.close() + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("ingredient_units", schema=None) as batch_op: + batch_op.add_column(sa.Column("standard_quantity", sa.Float(), nullable=True)) + batch_op.add_column(sa.Column("standard_unit", sa.String(), nullable=True)) + + # ### end Alembic commands ### + + # Populate standardized units for existing records + try: + populate_standards() + except Exception: + logger.exception("Failed to populate unit standards, skipping...") + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("ingredient_units", schema=None) as batch_op: + batch_op.drop_column("standard_unit") + batch_op.drop_column("standard_quantity") + + # ### end Alembic commands ### diff --git a/mealie/db/models/recipe/ingredient.py b/mealie/db/models/recipe/ingredient.py index f65830953..82e6f89af 100644 --- a/mealie/db/models/recipe/ingredient.py +++ b/mealie/db/models/recipe/ingredient.py @@ -52,6 +52,10 @@ class IngredientUnitModel(SqlAlchemyBase, BaseMixins): cascade="all, delete, delete-orphan", ) + # Standardization + standard_quantity: Mapped[float | None] = mapped_column(Float) + standard_unit: Mapped[str | None] = mapped_column(String) + # Automatically updated by sqlalchemy event, do not write to this manually name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True) plural_name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True) diff --git a/mealie/lang/locale_config.py b/mealie/lang/locale_config.py index 9ebef497f..1c7d9dbca 100644 --- a/mealie/lang/locale_config.py +++ b/mealie/lang/locale_config.py @@ -15,52 +15,63 @@ class LocalePluralFoodHandling(StrEnum): @dataclass class LocaleConfig: + key: str name: str dir: LocaleTextDirection = LocaleTextDirection.LTR plural_food_handling: LocalePluralFoodHandling = LocalePluralFoodHandling.ALWAYS LOCALE_CONFIG: dict[str, LocaleConfig] = { - "af-ZA": LocaleConfig(name="Afrikaans (Afrikaans)"), - "ar-SA": LocaleConfig(name="العربية (Arabic)", dir=LocaleTextDirection.RTL), - "bg-BG": LocaleConfig(name="Български (Bulgarian)"), - "ca-ES": LocaleConfig(name="Català (Catalan)"), - "cs-CZ": LocaleConfig(name="Čeština (Czech)"), - "da-DK": LocaleConfig(name="Dansk (Danish)"), - "de-DE": LocaleConfig(name="Deutsch (German)"), - "el-GR": LocaleConfig(name="Ελληνικά (Greek)"), - "en-GB": LocaleConfig(name="British English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT), - "en-US": LocaleConfig(name="American English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT), - "es-ES": LocaleConfig(name="Español (Spanish)"), - "et-EE": LocaleConfig(name="Eesti (Estonian)"), - "fi-FI": LocaleConfig(name="Suomi (Finnish)"), - "fr-BE": LocaleConfig(name="Belge (Belgian)"), - "fr-CA": LocaleConfig(name="Français canadien (Canadian French)"), - "fr-FR": LocaleConfig(name="Français (French)"), - "gl-ES": LocaleConfig(name="Galego (Galician)"), - "he-IL": LocaleConfig(name="עברית (Hebrew)", dir=LocaleTextDirection.RTL), - "hr-HR": LocaleConfig(name="Hrvatski (Croatian)"), - "hu-HU": LocaleConfig(name="Magyar (Hungarian)"), - "is-IS": LocaleConfig(name="Íslenska (Icelandic)"), - "it-IT": LocaleConfig(name="Italiano (Italian)"), - "ja-JP": LocaleConfig(name="日本語 (Japanese)", plural_food_handling=LocalePluralFoodHandling.NEVER), - "ko-KR": LocaleConfig(name="한국어 (Korean)", plural_food_handling=LocalePluralFoodHandling.NEVER), - "lt-LT": LocaleConfig(name="Lietuvių (Lithuanian)"), - "lv-LV": LocaleConfig(name="Latviešu (Latvian)"), - "nl-NL": LocaleConfig(name="Nederlands (Dutch)"), - "no-NO": LocaleConfig(name="Norsk (Norwegian)"), - "pl-PL": LocaleConfig(name="Polski (Polish)"), - "pt-BR": LocaleConfig(name="Português do Brasil (Brazilian Portuguese)"), - "pt-PT": LocaleConfig(name="Português (Portuguese)"), - "ro-RO": LocaleConfig(name="Română (Romanian)"), - "ru-RU": LocaleConfig(name="Pусский (Russian)"), - "sk-SK": LocaleConfig(name="Slovenčina (Slovak)"), - "sl-SI": LocaleConfig(name="Slovenščina (Slovenian)"), - "sr-SP": LocaleConfig(name="српски (Serbian)"), - "sv-SE": LocaleConfig(name="Svenska (Swedish)"), - "tr-TR": LocaleConfig(name="Türkçe (Turkish)", plural_food_handling=LocalePluralFoodHandling.NEVER), - "uk-UA": LocaleConfig(name="Українська (Ukrainian)"), - "vi-VN": LocaleConfig(name="Tiếng Việt (Vietnamese)", plural_food_handling=LocalePluralFoodHandling.NEVER), - "zh-CN": LocaleConfig(name="简体中文 (Chinese simplified)", plural_food_handling=LocalePluralFoodHandling.NEVER), - "zh-TW": LocaleConfig(name="繁體中文 (Chinese traditional)", plural_food_handling=LocalePluralFoodHandling.NEVER), + "af-ZA": LocaleConfig(key="af-ZA", name="Afrikaans (Afrikaans)"), + "ar-SA": LocaleConfig(key="ar-SA", name="العربية (Arabic)", dir=LocaleTextDirection.RTL), + "bg-BG": LocaleConfig(key="bg-BG", name="Български (Bulgarian)"), + "ca-ES": LocaleConfig(key="ca-ES", name="Català (Catalan)"), + "cs-CZ": LocaleConfig(key="cs-CZ", name="Čeština (Czech)"), + "da-DK": LocaleConfig(key="da-DK", name="Dansk (Danish)"), + "de-DE": LocaleConfig(key="de-DE", name="Deutsch (German)"), + "el-GR": LocaleConfig(key="el-GR", name="Ελληνικά (Greek)"), + "en-GB": LocaleConfig( + key="en-GB", name="British English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT + ), + "en-US": LocaleConfig( + key="en-US", name="American English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT + ), + "es-ES": LocaleConfig(key="es-ES", name="Español (Spanish)"), + "et-EE": LocaleConfig(key="et-EE", name="Eesti (Estonian)"), + "fi-FI": LocaleConfig(key="fi-FI", name="Suomi (Finnish)"), + "fr-BE": LocaleConfig(key="fr-BE", name="Belge (Belgian)"), + "fr-CA": LocaleConfig(key="fr-CA", name="Français canadien (Canadian French)"), + "fr-FR": LocaleConfig(key="fr-FR", name="Français (French)"), + "gl-ES": LocaleConfig(key="gl-ES", name="Galego (Galician)"), + "he-IL": LocaleConfig(key="he-IL", name="עברית (Hebrew)", dir=LocaleTextDirection.RTL), + "hr-HR": LocaleConfig(key="hr-HR", name="Hrvatski (Croatian)"), + "hu-HU": LocaleConfig(key="hu-HU", name="Magyar (Hungarian)"), + "is-IS": LocaleConfig(key="is-IS", name="Íslenska (Icelandic)"), + "it-IT": LocaleConfig(key="it-IT", name="Italiano (Italian)"), + "ja-JP": LocaleConfig(key="ja-JP", name="日本語 (Japanese)", plural_food_handling=LocalePluralFoodHandling.NEVER), + "ko-KR": LocaleConfig(key="ko-KR", name="한국어 (Korean)", plural_food_handling=LocalePluralFoodHandling.NEVER), + "lt-LT": LocaleConfig(key="lt-LT", name="Lietuvių (Lithuanian)"), + "lv-LV": LocaleConfig(key="lv-LV", name="Latviešu (Latvian)"), + "nl-NL": LocaleConfig(key="nl-NL", name="Nederlands (Dutch)"), + "no-NO": LocaleConfig(key="no-NO", name="Norsk (Norwegian)"), + "pl-PL": LocaleConfig(key="pl-PL", name="Polski (Polish)"), + "pt-BR": LocaleConfig(key="pt-BR", name="Português do Brasil (Brazilian Portuguese)"), + "pt-PT": LocaleConfig(key="pt-PT", name="Português (Portuguese)"), + "ro-RO": LocaleConfig(key="ro-RO", name="Română (Romanian)"), + "ru-RU": LocaleConfig(key="ru-RU", name="Pусский (Russian)"), + "sk-SK": LocaleConfig(key="sk-SK", name="Slovenčina (Slovak)"), + "sl-SI": LocaleConfig(key="sl-SI", name="Slovenščina (Slovenian)"), + "sr-SP": LocaleConfig(key="sr-SP", name="српски (Serbian)"), + "sv-SE": LocaleConfig(key="sv-SE", name="Svenska (Swedish)"), + "tr-TR": LocaleConfig(key="tr-TR", name="Türkçe (Turkish)", plural_food_handling=LocalePluralFoodHandling.NEVER), + "uk-UA": LocaleConfig(key="uk-UA", name="Українська (Ukrainian)"), + "vi-VN": LocaleConfig( + key="vi-VN", name="Tiếng Việt (Vietnamese)", plural_food_handling=LocalePluralFoodHandling.NEVER + ), + "zh-CN": LocaleConfig( + key="zh-CN", name="简体中文 (Chinese simplified)", plural_food_handling=LocalePluralFoodHandling.NEVER + ), + "zh-TW": LocaleConfig( + key="zh-TW", name="繁體中文 (Chinese traditional)", plural_food_handling=LocalePluralFoodHandling.NEVER + ), } diff --git a/mealie/repos/repository_units.py b/mealie/repos/repository_units.py index b9e3a1496..3de3b571b 100644 --- a/mealie/repos/repository_units.py +++ b/mealie/repos/repository_units.py @@ -1,17 +1,119 @@ -from pydantic import UUID4 +from collections.abc import Iterable + +from pydantic import UUID4, BaseModel from sqlalchemy import select from mealie.db.models.recipe.ingredient import IngredientUnitModel -from mealie.schema.recipe.recipe_ingredient import IngredientUnit +from mealie.lang.providers import get_locale_context +from mealie.schema.recipe.recipe_ingredient import IngredientUnit, StandardizedUnitType from .repository_generic import GroupRepositoryGeneric class RepositoryUnit(GroupRepositoryGeneric[IngredientUnit, IngredientUnitModel]): + _standardized_unit_map: dict[str, str] | None = None + + @property + def standardized_unit_map(self) -> dict[str, str]: + """A map of potential known units to its standardized name in our seed data""" + + if self._standardized_unit_map is None: + from .seed.seeders import IngredientUnitsSeeder + + ctx = get_locale_context() + if ctx: + locale = ctx[1].key + else: + locale = None + + self._standardized_unit_map = {} + locale_file = IngredientUnitsSeeder.get_file(locale=locale) + for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items(): + for prop in ["name", "plural_name", "abbreviation"]: + val = unit.get(prop) + if val and isinstance(val, str): + self._standardized_unit_map[val.strip().lower()] = unit_key + + return self._standardized_unit_map + def _get_unit(self, id: UUID4) -> IngredientUnitModel: stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id})) return self.session.execute(stmt).scalars().one() + def _add_standardized_unit(self, data: BaseModel | dict) -> dict: + if not isinstance(data, dict): + data = data.model_dump() + + # Don't overwrite user data if it exists + if data.get("standard_quantity") is not None or data.get("standard_unit") is not None: + return data + + # Compare name attrs to translation files and see if there's a match to a known standard unit + for prop in ["name", "plural_name", "abbreviation", "plural_abbreviation"]: + val = data.get(prop) + if not (val and isinstance(val, str)): + continue + + standardized_unit_key = self.standardized_unit_map.get(val.strip().lower()) + if not standardized_unit_key: + continue + + match standardized_unit_key: + case "teaspoon": + data["standard_quantity"] = 1 / 6 + data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE + case "tablespoon": + data["standard_quantity"] = 1 / 2 + data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE + case "cup": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.CUP + case "fluid-ounce": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE + case "pint": + data["standard_quantity"] = 2 + data["standard_unit"] = StandardizedUnitType.CUP + case "quart": + data["standard_quantity"] = 4 + data["standard_unit"] = StandardizedUnitType.CUP + case "gallon": + data["standard_quantity"] = 16 + data["standard_unit"] = StandardizedUnitType.CUP + case "milliliter": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.MILLILITER + case "liter": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.LITER + case "pound": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.POUND + case "ounce": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.OUNCE + case "gram": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.GRAM + case "kilogram": + data["standard_quantity"] = 1 + data["standard_unit"] = StandardizedUnitType.KILOGRAM + case "milligram": + data["standard_quantity"] = 1 / 1000 + data["standard_unit"] = StandardizedUnitType.GRAM + case _: + continue + + return data + + def create(self, data: IngredientUnit | dict) -> IngredientUnit: + data = self._add_standardized_unit(data) + return super().create(data) + + def create_many(self, data: Iterable[IngredientUnit | dict]) -> list[IngredientUnit]: + data = [self._add_standardized_unit(i) for i in data] + return super().create_many(data) + def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None: from_model = self._get_unit(from_unit) to_model = self._get_unit(to_unit) diff --git a/mealie/repos/seed/_abstract_seeder.py b/mealie/repos/seed/_abstract_seeder.py index 7f9ce3d4e..b77fd718c 100644 --- a/mealie/repos/seed/_abstract_seeder.py +++ b/mealie/repos/seed/_abstract_seeder.py @@ -1,3 +1,4 @@ +import json from abc import ABC, abstractmethod from logging import Logger from pathlib import Path @@ -11,6 +12,8 @@ class AbstractSeeder(ABC): Abstract class for seeding data. """ + resources = Path(__file__).parent / "resources" + def __init__(self, db: AllRepositories, logger: Logger | None = None): """ Initialize the abstract seeder. @@ -19,7 +22,14 @@ class AbstractSeeder(ABC): """ self.repos = db self.logger = logger or get_logger("Data Seeder") - self.resources = Path(__file__).parent / "resources" + + @classmethod + @abstractmethod + def get_file(self, locale: str | None = None) -> Path: ... + + @classmethod + def load_file(self, file: Path) -> dict[str, dict]: + return json.loads(file.read_text(encoding="utf-8")) @abstractmethod def seed(self, locale: str | None = None) -> None: ... diff --git a/mealie/repos/seed/seeders.py b/mealie/repos/seed/seeders.py index 72f68c453..f0c1dfd27 100644 --- a/mealie/repos/seed/seeders.py +++ b/mealie/repos/seed/seeders.py @@ -1,4 +1,3 @@ -import json import pathlib from collections.abc import Generator from functools import cached_property @@ -21,9 +20,10 @@ class MultiPurposeLabelSeeder(AbstractSeeder): def service(self): return MultiPurposeLabelService(self.repos) - def get_file(self, locale: str | None = None) -> pathlib.Path: + @classmethod + def get_file(cls, locale: str | None = None) -> pathlib.Path: # Get the labels from the foods seed file now - locale_path = self.resources / "foods" / "locales" / f"{locale}.json" + locale_path = cls.resources / "foods" / "locales" / f"{locale}.json" return locale_path if locale_path.exists() else foods.en_US def get_all_labels(self) -> list[MultiPurposeLabelOut]: @@ -34,7 +34,7 @@ class MultiPurposeLabelSeeder(AbstractSeeder): current_label_names = {label.name for label in self.get_all_labels()} # load from the foods locale file and remove any empty strings - seed_label_names = set(filter(None, json.loads(file.read_text(encoding="utf-8")).keys())) # type: set[str] + seed_label_names = set(filter(None, self.load_file(file).keys())) # type: set[str] # only seed new labels to_seed_labels = seed_label_names - current_label_names for label in to_seed_labels: @@ -53,8 +53,9 @@ class MultiPurposeLabelSeeder(AbstractSeeder): class IngredientUnitsSeeder(AbstractSeeder): - def get_file(self, locale: str | None = None) -> pathlib.Path: - locale_path = self.resources / "units" / "locales" / f"{locale}.json" + @classmethod + def get_file(cls, locale: str | None = None) -> pathlib.Path: + locale_path = cls.resources / "units" / "locales" / f"{locale}.json" return locale_path if locale_path.exists() else units.en_US def get_all_units(self) -> list[IngredientUnit]: @@ -64,7 +65,7 @@ class IngredientUnitsSeeder(AbstractSeeder): file = self.get_file(locale) seen_unit_names = {unit.name for unit in self.get_all_units()} - for unit in json.loads(file.read_text(encoding="utf-8")).values(): + for unit in self.load_file(file).values(): if unit["name"] in seen_unit_names: continue @@ -88,8 +89,9 @@ class IngredientUnitsSeeder(AbstractSeeder): class IngredientFoodsSeeder(AbstractSeeder): - def get_file(self, locale: str | None = None) -> pathlib.Path: - locale_path = self.resources / "foods" / "locales" / f"{locale}.json" + @classmethod + def get_file(cls, locale: str | None = None) -> pathlib.Path: + locale_path = cls.resources / "foods" / "locales" / f"{locale}.json" return locale_path if locale_path.exists() else foods.en_US def get_label(self, value: str) -> MultiPurposeLabelOut | None: @@ -103,7 +105,7 @@ class IngredientFoodsSeeder(AbstractSeeder): # get all current unique foods seen_foods_names = {food.name for food in self.get_all_foods()} - for label, values in json.loads(file.read_text(encoding="utf-8")).items(): + for label, values in self.load_file(file).items(): label_out = self.get_label(label) for food_name, attributes in values["foods"].items(): diff --git a/mealie/schema/recipe/__init__.py b/mealie/schema/recipe/__init__.py index 0d91b89a4..48785e4a4 100644 --- a/mealie/schema/recipe/__init__.py +++ b/mealie/schema/recipe/__init__.py @@ -67,6 +67,7 @@ from .recipe_ingredient import ( RegisteredParser, SaveIngredientFood, SaveIngredientUnit, + StandardizedUnitType, UnitFoodBase, ) from .recipe_notes import RecipeNote @@ -159,6 +160,7 @@ __all__ = [ "RegisteredParser", "SaveIngredientFood", "SaveIngredientUnit", + "StandardizedUnitType", "UnitFoodBase", "RecipeSuggestionQuery", "RecipeSuggestionResponse", diff --git a/mealie/schema/recipe/recipe_ingredient.py b/mealie/schema/recipe/recipe_ingredient.py index 4a32f0ff9..3ccb06a5f 100644 --- a/mealie/schema/recipe/recipe_ingredient.py +++ b/mealie/schema/recipe/recipe_ingredient.py @@ -2,6 +2,7 @@ from __future__ import annotations import datetime import enum +from enum import StrEnum from fractions import Fraction from typing import ClassVar from uuid import UUID, uuid4 @@ -34,6 +35,28 @@ def display_fraction(fraction: Fraction): ) +class StandardizedUnitType(StrEnum): + """ + An arbitrary list of standardized units supported by unit conversions. + The backend doesn't really care what standardized unit you use, as long as it's recognized, + but defining them here keeps it consistant with the frontend. + """ + + # Imperial + FLUID_OUNCE = "fluid_ounce" + CUP = "cup" + + OUNCE = "ounce" + POUND = "pound" + + # Metric + MILLILITER = "milliliter" + LITER = "liter" + + GRAM = "gram" + KILOGRAM = "kilogram" + + class UnitFoodBase(MealieModel): id: UUID4 | None = None name: str @@ -109,9 +132,6 @@ class IngredientFood(CreateIngredientFood): except AttributeError: return v - def is_on_hand(self, household_slug: str) -> bool: - return household_slug in self.households_with_tool - class IngredientFoodPagination(PaginationBase): items: list[IngredientFood] @@ -130,7 +150,21 @@ class CreateIngredientUnit(UnitFoodBase): abbreviation: str = "" plural_abbreviation: str | None = "" use_abbreviation: bool = False + aliases: list[CreateIngredientUnitAlias] = [] + standard_quantity: float | None = None + standard_unit: str | None = None + + @model_validator(mode="after") + def validate_standardization_fields(self): + # If one is set, the other must be set. + # If quantity is <= 0, it's considered not set. + if not self.standard_unit: + self.standard_quantity = self.standard_unit = None + elif not ((self.standard_quantity or 0) > 0): + self.standard_quantity = self.standard_unit = None + + return self class SaveIngredientUnit(CreateIngredientUnit): diff --git a/mealie/schema/recipe/recipe_tool.py b/mealie/schema/recipe/recipe_tool.py index 9cc8ff4f1..9c541cd05 100644 --- a/mealie/schema/recipe/recipe_tool.py +++ b/mealie/schema/recipe/recipe_tool.py @@ -32,9 +32,6 @@ class RecipeToolOut(RecipeToolCreate): except AttributeError: return v - def is_on_hand(self, household_slug: str) -> bool: - return household_slug in self.households_with_tool - @classmethod def loader_options(cls) -> list[LoaderOption]: return [ diff --git a/mealie/services/household_services/shopping_lists.py b/mealie/services/household_services/shopping_lists.py index 4cdb6abf1..831877ca2 100644 --- a/mealie/services/household_services/shopping_lists.py +++ b/mealie/services/household_services/shopping_lists.py @@ -28,6 +28,7 @@ from mealie.schema.recipe.recipe_ingredient import ( ) from mealie.schema.response.pagination import OrderDirection, PaginationQuery from mealie.services.parser_services._base import DataMatcher +from mealie.services.parser_services.parser_utils import UnitConverter, merge_quantity_and_unit class ShoppingListService: @@ -41,8 +42,7 @@ class ShoppingListService: self.list_refs = repos.group_shopping_list_recipe_refs self.data_matcher = DataMatcher(self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD) - @staticmethod - def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool: + def can_merge(self, item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool: """Check to see if this item can be merged with another item""" if any( @@ -50,16 +50,28 @@ class ShoppingListService: item1.checked, item2.checked, item1.food_id != item2.food_id, - item1.unit_id != item2.unit_id, ] ): return False + # check if units match or if they're compatable + if item1.unit_id != item2.unit_id: + item1_unit = item1.unit or self.data_matcher.units_by_id.get(item1.unit_id) + item2_unit = item2.unit or self.data_matcher.units_by_id.get(item2.unit_id) + if not (item1_unit and item1_unit.standard_unit): + return False + if not (item2_unit and item2_unit.standard_unit): + return False + + uc = UnitConverter() + if not uc.can_convert(item1_unit.standard_unit, item2_unit.standard_unit): + return False + # if foods match, we can merge, otherwise compare the notes return bool(item1.food_id) or item1.note == item2.note - @staticmethod def merge_items( + self, from_item: ShoppingListItemCreate | ShoppingListItemUpdateBulk, to_item: ShoppingListItemCreate | ShoppingListItemUpdateBulk | ShoppingListItemOut, ) -> ShoppingListItemUpdate: @@ -69,7 +81,20 @@ class ShoppingListService: Attributes of the `to_item` take priority over the `from_item`, except extras with overlapping keys """ - to_item.quantity += from_item.quantity + to_item_unit = to_item.unit or self.data_matcher.units_by_id.get(to_item.unit_id) + from_item_unit = from_item.unit or self.data_matcher.units_by_id.get(from_item.unit_id) + if to_item_unit and to_item_unit.standard_unit and from_item_unit and from_item_unit.standard_unit: + merged_qty, merged_unit = merge_quantity_and_unit( + from_item.quantity or 0, from_item_unit, to_item.quantity or 0, to_item_unit + ) + to_item.quantity = merged_qty + to_item.unit_id = merged_unit.id + to_item.unit = merged_unit + + else: + # No conversion needed, just sum the quantities + to_item.quantity += from_item.quantity + if to_item.note != from_item.note: to_item.note = " | ".join([note for note in [to_item.note, from_item.note] if note]) diff --git a/mealie/services/parser_services/_base.py b/mealie/services/parser_services/_base.py index 89fc8f290..b3662990d 100644 --- a/mealie/services/parser_services/_base.py +++ b/mealie/services/parser_services/_base.py @@ -29,18 +29,38 @@ class DataMatcher: self._food_fuzzy_match_threshold = food_fuzzy_match_threshold self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold + + self._foods_by_id: dict[UUID4, IngredientFood] | None = None + self._units_by_id: dict[UUID4, IngredientUnit] | None = None + self._foods_by_alias: dict[str, IngredientFood] | None = None self._units_by_alias: dict[str, IngredientUnit] | None = None @property - def foods_by_alias(self) -> dict[str, IngredientFood]: - if self._foods_by_alias is None: + def foods_by_id(self) -> dict[UUID4, IngredientFood]: + if self._foods_by_id is None: foods_repo = self.repos.ingredient_foods query = PaginationQuery(page=1, per_page=-1) all_foods = foods_repo.page_all(query).items + self._foods_by_id = {food.id: food for food in all_foods} + return self._foods_by_id + + @property + def units_by_id(self) -> dict[UUID4, IngredientUnit]: + if self._units_by_id is None: + units_repo = self.repos.ingredient_units + query = PaginationQuery(page=1, per_page=-1) + all_units = units_repo.page_all(query).items + self._units_by_id = {unit.id: unit for unit in all_units} + + return self._units_by_id + + @property + def foods_by_alias(self) -> dict[str, IngredientFood]: + if self._foods_by_alias is None: foods_by_alias: dict[str, IngredientFood] = {} - for food in all_foods: + for food in self.foods_by_id.values(): if food.name: foods_by_alias[IngredientFoodModel.normalize(food.name)] = food if food.plural_name: @@ -57,12 +77,8 @@ class DataMatcher: @property def units_by_alias(self) -> dict[str, IngredientUnit]: if self._units_by_alias is None: - units_repo = self.repos.ingredient_units - query = PaginationQuery(page=1, per_page=-1) - all_units = units_repo.page_all(query).items - units_by_alias: dict[str, IngredientUnit] = {} - for unit in all_units: + for unit in self.units_by_id.values(): if unit.name: units_by_alias[IngredientUnitModel.normalize(unit.name)] = unit if unit.plural_name: diff --git a/mealie/services/parser_services/parser_utils/__init__.py b/mealie/services/parser_services/parser_utils/__init__.py index 481851a81..1593c92a3 100644 --- a/mealie/services/parser_services/parser_utils/__init__.py +++ b/mealie/services/parser_services/parser_utils/__init__.py @@ -1 +1,2 @@ from .string_utils import * +from .unit_utils import * diff --git a/mealie/services/parser_services/parser_utils/unit_utils.py b/mealie/services/parser_services/parser_utils/unit_utils.py new file mode 100644 index 000000000..ef73e7e29 --- /dev/null +++ b/mealie/services/parser_services/parser_utils/unit_utils.py @@ -0,0 +1,146 @@ +from typing import TYPE_CHECKING, Literal, overload + +from pint import Quantity, Unit, UnitRegistry + +if TYPE_CHECKING: + from mealie.schema.recipe.recipe_ingredient import CreateIngredientUnit + + +class UnitNotFound(Exception): + """Raised when trying to access a unit not found in the unit registry.""" + + def __init__(self, message: str = "Unit not found in unit registry"): + self.message = message + super().__init__(self.message) + + def __str__(self): + return f"{self.message}" + + +class UnitConverter: + def __init__(self): + self.ureg = UnitRegistry() + + def _resolve_ounce(self, unit_1: Unit, unit_2: Unit) -> tuple[Unit, Unit]: + """ + Often times "ounce" is used in place of "fluid ounce" in recipes. + When trying to convert/combine ounces with a volume, we can assume it should have been a fluid ounce. + This function will convert ounces to fluid ounces if the other unit is a volume. + """ + + OUNCE = self.ureg("ounce") + FL_OUNCE = self.ureg("fluid_ounce") + VOLUME = "[length] ** 3" + + if unit_1 == OUNCE and unit_2.dimensionality == VOLUME: + return FL_OUNCE, unit_2 + if unit_2 == OUNCE and unit_1.dimensionality == VOLUME: + return unit_1, FL_OUNCE + + return unit_1, unit_2 + + @overload + def parse(self, unit: str | Unit, strict: Literal[False] = False) -> str | Unit: ... + + @overload + def parse(self, unit: str | Unit, strict: Literal[True]) -> Unit: ... + + def parse(self, unit: str | Unit, strict: bool = False) -> str | Unit: + """ + Parse a string unit into a pint.Unit. + + If strict is False (default), returns a pint.Unit if it exists, otherwise returns the original string. + If strict is True, raises UnitNotFound instead of returning a string. + If the input is already a parsed pint.Unit, returns it as-is. + """ + if isinstance(unit, Unit): + return unit + + try: + return self.ureg(unit).units + except Exception as e: + if strict: + raise UnitNotFound(f"Unit '{unit}' not found in unit registry") from e + return unit + + def can_convert(self, unit: str | Unit, to_unit: str | Unit) -> bool: + """Whether or not a given unit can be converted into another unit.""" + + unit = self.parse(unit) + to_unit = self.parse(to_unit) + + if not (isinstance(unit, Unit) and isinstance(to_unit, Unit)): + return False + + unit, to_unit = self._resolve_ounce(unit, to_unit) + return unit.is_compatible_with(to_unit) + + def convert(self, quantity: float, unit: str | Unit, to_unit: str | Unit) -> tuple[float, Unit]: + """ + Convert a quantity and a unit into another unit. + + Returns tuple[quantity, unit] + """ + + unit = self.parse(unit, strict=True) + to_unit = self.parse(to_unit, strict=True) + unit, to_unit = self._resolve_ounce(unit, to_unit) + + qty = quantity * unit + converted = qty.to(to_unit) + return float(converted.magnitude), converted.units + + def merge(self, quantity_1: float, unit_1: str | Unit, quantity_2: float, unit_2: str | Unit) -> tuple[float, Unit]: + """Merge two quantities together""" + + unit_1 = self.parse(unit_1, strict=True) + unit_2 = self.parse(unit_2, strict=True) + unit_1, unit_2 = self._resolve_ounce(unit_1, unit_2) + + q1 = quantity_1 * unit_1 + q2 = quantity_2 * unit_2 + + out: Quantity = q1 + q2 + return float(out.magnitude), out.units + + +def merge_quantity_and_unit[T: CreateIngredientUnit]( + qty_1: float, unit_1: T, qty_2: float, unit_2: T +) -> tuple[float, T]: + """ + Merge a quantity and unit. + + Returns tuple[quantity, unit] + """ + + if not (unit_1.standard_quantity and unit_1.standard_unit and unit_2.standard_quantity and unit_2.standard_unit): + raise ValueError("Both units must contain standardized unit data") + + PINT_UNIT_1_TXT = "_mealie_unit_1" + PINT_UNIT_2_TXT = "_mealie_unit_2" + + uc = UnitConverter() + + # pre-process units to account for ounce -> fluid_ounce conversion + unit_1_standard = uc.parse(unit_1.standard_unit, strict=True) + unit_2_standard = uc.parse(unit_2.standard_unit, strict=True) + unit_1_standard, unit_2_standard = uc._resolve_ounce(unit_1_standard, unit_2_standard) + + # create custon unit definition so pint can handle them natively + uc.ureg.define(f"{PINT_UNIT_1_TXT} = {unit_1.standard_quantity} * {unit_1_standard}") + uc.ureg.define(f"{PINT_UNIT_2_TXT} = {unit_2.standard_quantity} * {unit_2_standard}") + + pint_unit_1 = uc.parse(PINT_UNIT_1_TXT) + pint_unit_2 = uc.parse(PINT_UNIT_2_TXT) + + merged_q, merged_u = uc.merge(qty_1, pint_unit_1, qty_2, pint_unit_2) + + # Convert to the bigger unit if quantity >= 1, else the smaller unit + merged_q, merged_u = uc.convert(merged_q, merged_u, max(pint_unit_1, pint_unit_2)) + if abs(merged_q) < 1: + merged_q, merged_u = uc.convert(merged_q, merged_u, min(pint_unit_1, pint_unit_2)) + + if str(merged_u) == PINT_UNIT_1_TXT: + return merged_q, unit_1 + else: + return merged_q, unit_2 diff --git a/pyproject.toml b/pyproject.toml index c649bef61..caede87ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "typing-extensions==4.15.0", "itsdangerous==2.2.0", "ingredient-parser-nlp==2.5.0", + "pint>=0.25", ] [project.scripts] diff --git a/tests/data/__init__.py b/tests/data/__init__.py index 92c81ab68..dc23abb4a 100644 --- a/tests/data/__init__.py +++ b/tests/data/__init__.py @@ -4,6 +4,9 @@ CWD = Path(__file__).parent locale_dir = CWD / "locale" +backup_version_1d9a002d7234_1 = CWD / "backups/backup-version-1d9a002d7234-1.zip" +"""1d9a002d7234: add referenced_recipe to ingredients""" + backup_version_44e8d670719d_1 = CWD / "backups/backup-version-44e8d670719d-1.zip" """44e8d670719d: add extras to shopping lists, list items, and ingredient foods""" diff --git a/tests/data/backups/backup-version-1d9a002d7234-1.zip b/tests/data/backups/backup-version-1d9a002d7234-1.zip new file mode 100644 index 000000000..b3d64a555 Binary files /dev/null and b/tests/data/backups/backup-version-1d9a002d7234-1.zip differ diff --git a/tests/integration_tests/user_group_tests/test_group_seeder.py b/tests/integration_tests/user_group_tests/test_group_seeder.py index 87c58475b..8788f154e 100644 --- a/tests/integration_tests/user_group_tests/test_group_seeder.py +++ b/tests/integration_tests/user_group_tests/test_group_seeder.py @@ -15,14 +15,12 @@ def test_seed_foods(api_client: TestClient, unique_user: TestUser): CREATED_FOODS = 2687 database = unique_user.repos - # Check that the foods was created foods = database.ingredient_foods.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(foods) == 0 resp = api_client.post(api_routes.groups_seeders_foods, json={"locale": "en-US"}, headers=unique_user.token) assert resp.status_code == 200 - # Check that the foods was created foods = database.ingredient_foods.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(foods) == CREATED_FOODS @@ -31,29 +29,37 @@ def test_seed_units(api_client: TestClient, unique_user: TestUser): CREATED_UNITS = 24 database = unique_user.repos - # Check that the foods was created units = database.ingredient_units.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(units) == 0 resp = api_client.post(api_routes.groups_seeders_units, json={"locale": "en-US"}, headers=unique_user.token) assert resp.status_code == 200 - # Check that the foods was created units = database.ingredient_units.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(units) == CREATED_UNITS + # Check that the "pint" unit was created and includes standardized data + pint_found = False + for unit in units: + if unit.name != "pint": + continue + + pint_found = True + assert unit.standard_quantity == 2 + assert unit.standard_unit == "cup" + + assert pint_found + def test_seed_labels(api_client: TestClient, unique_user: TestUser): CREATED_LABELS = 32 database = unique_user.repos - # Check that the foods was created labels = database.group_multi_purpose_labels.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(labels) == 0 resp = api_client.post(api_routes.groups_seeders_labels, json={"locale": "en-US"}, headers=unique_user.token) assert resp.status_code == 200 - # Check that the foods was created labels = database.group_multi_purpose_labels.page_all(PaginationQuery(page=1, per_page=-1)).items assert len(labels) == CREATED_LABELS diff --git a/tests/integration_tests/user_household_tests/test_group_shopping_list_items.py b/tests/integration_tests/user_household_tests/test_group_shopping_list_items.py index 83aaf8d85..641f37f28 100644 --- a/tests/integration_tests/user_household_tests/test_group_shopping_list_items.py +++ b/tests/integration_tests/user_household_tests/test_group_shopping_list_items.py @@ -7,7 +7,7 @@ from fastapi.testclient import TestClient from pydantic import UUID4 from mealie.schema.household.group_shopping_list import ShoppingListItemOut, ShoppingListOut -from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood +from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientFood from tests import utils from tests.utils import api_routes from tests.utils.factories import random_int, random_string @@ -641,6 +641,96 @@ def test_shopping_list_items_with_zero_quantity( assert len(as_json["listItems"]) == len(normal_items + zero_qty_items) - 1 +def test_shopping_list_merge_standard_unit( + api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut +): + unit_1_cup_data = {"name": random_string(), "standardQuantity": 1, "standardUnit": "cup"} + unit_2_cup_data = {"name": random_string(), "standardQuantity": 2, "standardUnit": "cup"} + unit_1_out = api_client.post(api_routes.units, json=unit_1_cup_data, headers=unique_user.token) + unit_2_out = api_client.post(api_routes.units, json=unit_2_cup_data, headers=unique_user.token) + + unit_1 = IngredientUnit.model_validate(unit_1_out.json()) + unit_2 = IngredientUnit.model_validate(unit_2_out.json()) + + list_item_1_data = create_item(shopping_list.id, unit_id=str(unit_1.id), note="mealie-food") + list_item_2_data = create_item(shopping_list.id, unit_id=str(unit_2.id), note="mealie-food") + response = api_client.post( + api_routes.households_shopping_items_create_bulk, + json=[list_item_1_data, list_item_2_data], + headers=unique_user.token, + ) + + as_json = utils.assert_deserialize(response, 201) + assert len(as_json["createdItems"]) == 1 + + item_out = as_json["createdItems"][0] + + # should use larger "2 cup" unit (a la "pint") + assert item_out["unitId"] == str(unit_2.id) + # calculate quantity by summing base "cup" amount and dividing by 2 (a la pints) + assert item_out["quantity"] == (list_item_1_data["quantity"] + (list_item_2_data["quantity"] * 2)) / 2 + + +def test_shopping_list_merge_standard_unit_different_foods( + api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut +): + unit_1_cup_data = {"name": random_string(), "standardQuantity": 1, "standardUnit": "cup"} + unit_2_cup_data = {"name": random_string(), "standardQuantity": 2, "standardUnit": "cup"} + unit_1_out = api_client.post(api_routes.units, json=unit_1_cup_data, headers=unique_user.token) + unit_2_out = api_client.post(api_routes.units, json=unit_2_cup_data, headers=unique_user.token) + + unit_1 = IngredientUnit.model_validate(unit_1_out.json()) + unit_2 = IngredientUnit.model_validate(unit_2_out.json()) + + list_item_1_data = create_item(shopping_list.id, unit_id=str(unit_1.id), note="mealie-food-1") + list_item_2_data = create_item(shopping_list.id, unit_id=str(unit_2.id), note="mealie-food-2") + response = api_client.post( + api_routes.households_shopping_items_create_bulk, + json=[list_item_1_data, list_item_2_data], + headers=unique_user.token, + ) + + as_json = utils.assert_deserialize(response, 201) + assert len(as_json["createdItems"]) == 2 + for in_data, out_data in zip( + [list_item_1_data, list_item_2_data], [as_json["createdItems"][0], as_json["createdItems"][1]], strict=True + ): + assert in_data["quantity"] == out_data["quantity"] + assert out_data["unit"] + assert in_data["unit_id"] == out_data["unit"]["id"] + assert in_data["note"] == out_data["note"] + + +def test_shopping_list_merge_standard_unit_incompatible_units( + api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut +): + unit_1_data = {"name": random_string(), "standardQuantity": 1, "standardUnit": "cup"} + unit_2_data = {"name": random_string(), "standardQuantity": 2, "standardUnit": "gram"} + unit_1_out = api_client.post(api_routes.units, json=unit_1_data, headers=unique_user.token) + unit_2_out = api_client.post(api_routes.units, json=unit_2_data, headers=unique_user.token) + + unit_1 = IngredientUnit.model_validate(unit_1_out.json()) + unit_2 = IngredientUnit.model_validate(unit_2_out.json()) + + list_item_1_data = create_item(shopping_list.id, unit_id=str(unit_1.id), note="mealie-food") + list_item_2_data = create_item(shopping_list.id, unit_id=str(unit_2.id), note="mealie-food") + response = api_client.post( + api_routes.households_shopping_items_create_bulk, + json=[list_item_1_data, list_item_2_data], + headers=unique_user.token, + ) + + as_json = utils.assert_deserialize(response, 201) + assert len(as_json["createdItems"]) == 2 + for in_data, out_data in zip( + [list_item_1_data, list_item_2_data], [as_json["createdItems"][0], as_json["createdItems"][1]], strict=True + ): + assert in_data["quantity"] == out_data["quantity"] + assert out_data["unit"] + assert in_data["unit_id"] == out_data["unit"]["id"] + assert in_data["note"] == out_data["note"] + + def test_shopping_list_item_extras( api_client: TestClient, unique_user: TestUser, shopping_list: ShoppingListOut ) -> None: diff --git a/tests/unit_tests/ingredient_parser/test_unit_utils.py b/tests/unit_tests/ingredient_parser/test_unit_utils.py new file mode 100644 index 000000000..2c81f43c8 --- /dev/null +++ b/tests/unit_tests/ingredient_parser/test_unit_utils.py @@ -0,0 +1,309 @@ +import pint +import pytest + +from mealie.schema.recipe.recipe_ingredient import CreateIngredientUnit +from mealie.services.parser_services.parser_utils import UnitConverter, UnitNotFound, merge_quantity_and_unit +from tests.utils import random_string + + +def test_uc_parse_string(): + uc = UnitConverter() + parsed = uc.parse("cup") + + assert isinstance(parsed, pint.Unit) + assert (str(parsed)) == "cup" + + +def test_uc_parse_unit(): + uc = UnitConverter() + parsed = uc.parse(uc.parse("cup")) + + assert isinstance(parsed, pint.Unit) + assert (str(parsed)) == "cup" + + +def test_uc_parse_invalid(): + uc = UnitConverter() + input_str = random_string() + parsed = uc.parse(input_str) + + assert not isinstance(parsed, pint.Unit) + assert parsed == input_str + + +def test_uc_parse_invalid_strict(): + uc = UnitConverter() + input_str = random_string() + + with pytest.raises(UnitNotFound): + uc.parse(input_str, strict=True) + + +@pytest.mark.parametrize("pre_parse_1", [True, False]) +@pytest.mark.parametrize("pre_parse_2", [True, False]) +def test_can_convert(pre_parse_1: bool, pre_parse_2: bool): + unit_1 = "cup" + unit_2 = "pint" + + uc = UnitConverter() + if pre_parse_1: + unit_1 = uc.parse(unit_1) + if pre_parse_2: + unit_2 = uc.parse(unit_2) + + assert uc.can_convert(unit_1, unit_2) + + +@pytest.mark.parametrize("pre_parse_1", [True, False]) +@pytest.mark.parametrize("pre_parse_2", [True, False]) +def test_cannot_convert(pre_parse_1: bool, pre_parse_2: bool): + unit_1 = "cup" + unit_2 = "pound" + + uc = UnitConverter() + if pre_parse_1: + unit_1 = uc.parse(unit_1) + if pre_parse_2: + unit_2 = uc.parse(unit_2) + + assert not uc.can_convert(unit_1, unit_2) + + +def test_cannot_convert_invalid_unit(): + uc = UnitConverter() + assert not uc.can_convert("cup", random_string()) + assert not uc.can_convert(random_string(), "cup") + + +def test_can_convert_same_unit(): + uc = UnitConverter() + assert uc.can_convert("cup", "cup") + + +def test_can_convert_volume_ounce(): + uc = UnitConverter() + assert uc.can_convert("ounce", "cup") + assert uc.can_convert("cup", "ounce") + + +def test_convert_simple(): + uc = UnitConverter() + quantity, unit = uc.convert(1, "cup", "pint") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "pint" + assert quantity == 1 / 2 + + +@pytest.mark.parametrize("pre_parse_1", [True, False]) +@pytest.mark.parametrize("pre_parse_2", [True, False]) +def test_convert_pre_parsed(pre_parse_1: bool, pre_parse_2: bool): + unit_1 = "cup" + unit_2 = "pint" + + uc = UnitConverter() + if pre_parse_1: + unit_1 = uc.parse(unit_1) + if pre_parse_2: + unit_2 = uc.parse(unit_2) + + quantity, unit = uc.convert(1, unit_1, unit_2) + assert isinstance(unit, pint.Unit) + assert str(unit) == "pint" + assert quantity == 1 / 2 + + +def test_convert_weight(): + uc = UnitConverter() + quantity, unit = uc.convert(16, "ounce", "pound") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "pound" + assert quantity == 1 + + +def test_convert_zero_quantity(): + uc = UnitConverter() + quantity, unit = uc.convert(0, "cup", "pint") + + assert isinstance(unit, pint.Unit) + assert quantity == 0 + + +def test_convert_invalid_unit(): + uc = UnitConverter() + + with pytest.raises(UnitNotFound): + uc.convert(1, "pound", random_string()) + + +def test_convert_incompatible_units(): + uc = UnitConverter() + + with pytest.raises(pint.errors.DimensionalityError): + uc.convert(1, "pound", "cup") + + +def test_convert_volume_ounce(): + uc = UnitConverter() + quantity, unit = uc.convert(8, "ounce", "cup") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "cup" + assert quantity == 1 + + +def test_merge_same_unit(): + uc = UnitConverter() + quantity, unit = uc.merge(1, "cup", 2, "cup") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "cup" + assert quantity == 3 + + +@pytest.mark.parametrize("pre_parse_1", [True, False]) +@pytest.mark.parametrize("pre_parse_2", [True, False]) +def test_merge_compatible_units(pre_parse_1: bool, pre_parse_2: bool): + unit_1 = "cup" + unit_2 = "pint" + + uc = UnitConverter() + if pre_parse_1: + unit_1 = uc.parse(unit_1) + if pre_parse_2: + unit_2 = uc.parse(unit_2) + + quantity, unit = uc.merge(1, unit_1, 1, unit_2) + assert isinstance(unit, pint.Unit) + # 1 cup + 1 pint = 1 cup + 2 cups = 3 cups + assert quantity == 3 + + +def test_merge_weight_units(): + uc = UnitConverter() + quantity, unit = uc.merge(8, "ounce", 8, "ounce") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "ounce" + assert quantity == 16 + + +def test_merge_different_weight_units(): + uc = UnitConverter() + quantity, unit = uc.merge(1, "pound", 8, "ounce") + + assert isinstance(unit, pint.Unit) + # 1 pound + 8 ounces = 16 ounces + 8 ounces = 24 ounces + assert str(unit) == "pound" + assert quantity == 1.5 + + +def test_merge_zero_quantities(): + uc = UnitConverter() + quantity, unit = uc.merge(0, "cup", 1, "cup") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "cup" + assert quantity == 1 + + +def test_merge_invalid_unit(): + uc = UnitConverter() + + with pytest.raises(UnitNotFound): + uc.merge(1, "pound", 1, random_string()) + + +def test_merge_incompatible_units(): + uc = UnitConverter() + + with pytest.raises(pint.errors.DimensionalityError): + uc.merge(1, "pound", 1, "cup") + + +def test_merge_negative_quantity(): + uc = UnitConverter() + quantity, unit = uc.merge(-1, "cup", 2, "cup") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "cup" + assert quantity == 1 + + +def test_merge_volume_ounce(): + uc = UnitConverter() + quantity, unit = uc.merge(4, "ounce", 1, "cup") + + assert isinstance(unit, pint.Unit) + assert str(unit) == "fluid_ounce" # converted automatically from ounce + assert quantity == 12 + + +def test_merge_quantity_and_unit_simple(): + unit_1 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + unit_2 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + + quantity, unit = merge_quantity_and_unit(1, unit_1, 2, unit_2) + + assert quantity == 3 + assert unit.name == "mealie_cup" + + +def test_merge_quantity_and_unit_invalid(): + unit_1 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + unit_2 = CreateIngredientUnit(name="mealie_random", standard_quantity=1, standard_unit=random_string()) + + with pytest.raises(UnitNotFound): + merge_quantity_and_unit(1, unit_1, 1, unit_2) + + +def test_merge_quantity_and_unit_compatible(): + unit_1 = CreateIngredientUnit(name="mealie_pint", standard_quantity=1, standard_unit="pint") + unit_2 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + + quantity, unit = merge_quantity_and_unit(1, unit_1, 1, unit_2) + + # 1 pint + 1 cup = 2 pints + 1 cup = 3 cups, converted to pint = 1.5 pint + assert quantity == 1.5 + assert unit.name == "mealie_pint" + + +def test_merge_quantity_and_unit_selects_larger_unit(): + unit_1 = CreateIngredientUnit(name="mealie_pint", standard_quantity=1, standard_unit="pint") + unit_2 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + + quantity, unit = merge_quantity_and_unit(2, unit_1, 4, unit_2) + + # 2 pint + 4 cup = 4 cups + 4 cups = 8 cups, should be returned as pint (larger unit) + assert quantity == 4 + assert unit.name == "mealie_pint" + + +def test_merge_quantity_and_unit_selects_smaller_unit(): + unit_1 = CreateIngredientUnit(name="mealie_pint", standard_quantity=1, standard_unit="pint") + unit_2 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + + quantity, unit = merge_quantity_and_unit(0.125, unit_1, 0.5, unit_2) + + # 0.125 pint + 0.5 cup = 0.25 cup + 0.5 cup = 0.75 cup, should be returned as cup (smaller for < 1) + assert quantity == 0.75 + assert unit.name == "mealie_cup" + + +def test_merge_quantity_and_unit_missing_standard_data(): + unit_1 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + unit_2 = CreateIngredientUnit(name="mealie_cup_no_std", standard_quantity=None, standard_unit=None) + + with pytest.raises(ValueError): + merge_quantity_and_unit(1, unit_1, 1, unit_2) + + +def test_merge_quantity_and_unit_volume_ounce(): + unit_1 = CreateIngredientUnit(name="mealie_oz", standard_quantity=1, standard_unit="ounce") + unit_2 = CreateIngredientUnit(name="mealie_cup", standard_quantity=1, standard_unit="cup") + + quantity, unit = merge_quantity_and_unit(8, unit_1, 1, unit_2) + + assert quantity == 2 + assert unit.name == "mealie_cup" diff --git a/tests/unit_tests/repository_tests/test_unit_repository.py b/tests/unit_tests/repository_tests/test_unit_repository.py index d2a0b7dfe..e48168131 100644 --- a/tests/unit_tests/repository_tests/test_unit_repository.py +++ b/tests/unit_tests/repository_tests/test_unit_repository.py @@ -1,11 +1,26 @@ from uuid import UUID +import pytest +from sqlalchemy.orm import Session + +from mealie.repos.all_repositories import AllRepositories, get_repositories from mealie.schema.recipe.recipe import Recipe from mealie.schema.recipe.recipe_ingredient import RecipeIngredient, SaveIngredientUnit -from tests.utils.factories import random_string +from mealie.schema.user.user import GroupBase +from tests.utils.factories import random_int, random_string from tests.utils.fixture_schemas import TestUser +@pytest.fixture() +def unique_local_group_id(unfiltered_database: AllRepositories) -> str: + return str(unfiltered_database.groups.create(GroupBase(name=random_string())).id) + + +@pytest.fixture() +def unique_db(session: Session, unique_local_group_id: str) -> AllRepositories: + return get_repositories(session, group_id=unique_local_group_id) + + def test_unit_merger(unique_user: TestUser): database = unique_user.repos recipe: Recipe | None = None @@ -51,3 +66,79 @@ def test_unit_merger(unique_user: TestUser): for ingredient in recipe.recipe_ingredient: assert ingredient.unit.id == unit_1.id # type: ignore + + +@pytest.mark.parametrize("standard_field", ["name", "plural_name", "abbreviation", "plural_abbreviation"]) +@pytest.mark.parametrize("use_bulk", [True, False]) +def test_auto_inject_standardization(unique_db: AllRepositories, standard_field: str, use_bulk: bool): + unit_in = SaveIngredientUnit(name=random_string(), group_id=unique_db.group_id).model_dump() + unit_in[standard_field] = "gallon" + + if use_bulk: + out_many = unique_db.ingredient_units.create_many([unit_in]) + assert len(out_many) == 1 + unit_out = out_many[0] + else: + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_unit == "cup" + assert unit_out.standard_quantity == 16 + + +def test_dont_auto_inject_random(unique_db: AllRepositories): + unit_in = SaveIngredientUnit(name=random_string(), group_id=unique_db.group_id) + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_quantity is None + assert unit_out.standard_unit is None + + +def test_auto_inject_other_language(unique_db: AllRepositories): + # Inject custom unit map + GALLON = random_string() + unique_db.ingredient_units._standardized_unit_map = {GALLON: "gallon"} + + # Create unit with translated value + unit_in = SaveIngredientUnit(name=GALLON, group_id=unique_db.group_id) + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_unit == "cup" + assert unit_out.standard_quantity == 16 + + +@pytest.mark.parametrize("name", ["custom-mealie-unit", "gallon"]) +def test_user_standardization(unique_db: AllRepositories, name: str): + unit_in = SaveIngredientUnit( + name=name, + group_id=unique_db.group_id, + standard_quantity=random_int(1, 10), + standard_unit=random_string(), + ) + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_quantity == unit_in.standard_quantity + assert unit_out.standard_unit == unit_in.standard_unit + + +def test_ignore_incomplete_standardization(unique_db: AllRepositories): + unit_in = SaveIngredientUnit( + name=random_string(), + group_id=unique_db.group_id, + standard_quantity=random_int(1, 10), + standard_unit=None, + ) + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_quantity is None + assert unit_out.standard_unit is None + + unit_in = SaveIngredientUnit( + name=random_string(), + group_id=unique_db.group_id, + standard_quantity=None, + standard_unit=random_string(), + ) + unit_out = unique_db.ingredient_units.create(unit_in) + + assert unit_out.standard_quantity is None + assert unit_out.standard_unit is None diff --git a/tests/unit_tests/services_tests/backup_v2_tests/test_backup_v2.py b/tests/unit_tests/services_tests/backup_v2_tests/test_backup_v2.py index 754b632fa..22f25d33d 100644 --- a/tests/unit_tests/services_tests/backup_v2_tests/test_backup_v2.py +++ b/tests/unit_tests/services_tests/backup_v2_tests/test_backup_v2.py @@ -217,6 +217,22 @@ def _b9e516e2d3b3_add_household_to_recipe_last_made_household_to_foods_and_tools assert not tool.households_with_tool +def _a39c7f1826e3_add_unit_standardization_fields(session: Session): + groups = session.query(Group).all() + + for group in groups: + # test_data.backup_version_1d9a002d7234_1 has a non-anonymized "pint" unit + # and has not yet run the standardization migration. + pint_units = ( + session.query(IngredientUnitModel) + .filter(IngredientUnitModel.group_id == group.id, IngredientUnitModel.name == "pint") + .all() + ) + for unit in pint_units: + assert unit.standard_quantity == 2 + assert unit.standard_unit == "cup" + + def test_database_restore_data(): """ This tests real user backups to make sure the data is restored correctly. The data has been anonymized, but @@ -227,6 +243,7 @@ def test_database_restore_data(): """ backup_paths = [ + test_data.backup_version_1d9a002d7234_1, test_data.backup_version_44e8d670719d_1, test_data.backup_version_44e8d670719d_2, test_data.backup_version_44e8d670719d_3, @@ -245,6 +262,7 @@ def test_database_restore_data(): _d7c6efd2de42_migrate_favorites_and_ratings_to_user_ratings, _86054b40fd06_added_query_filter_string_to_cookbook_and_mealplan, _b9e516e2d3b3_add_household_to_recipe_last_made_household_to_foods_and_tools, + _a39c7f1826e3_add_unit_standardization_fields, ] settings = get_app_settings() diff --git a/uv.lock b/uv.lock index fd203f486..265499e65 100644 --- a/uv.lock +++ b/uv.lock @@ -850,6 +850,7 @@ dependencies = [ { name = "paho-mqtt" }, { name = "pillow" }, { name = "pillow-heif" }, + { name = "pint" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyhumps" }, @@ -923,6 +924,7 @@ requires-dist = [ { name = "paho-mqtt", specifier = "==1.6.1" }, { name = "pillow", specifier = "==12.1.1" }, { name = "pillow-heif", specifier = "==1.3.0" }, + { name = "pint", specifier = ">=0.25" }, { name = "psycopg2-binary", marker = "extra == 'pgsql'", specifier = "==2.9.11" }, { name = "pydantic", specifier = "==2.12.5" }, { name = "pydantic-settings", specifier = "==2.13.1" },