feat: Unit standardization / conversion (#7121)

This commit is contained in:
Michael Genson
2026-03-09 12:13:41 -05:00
committed by GitHub
parent 96597915ff
commit b5c089f58c
30 changed files with 1203 additions and 86 deletions

View File

@@ -0,0 +1,106 @@
"""add unit standardization fields
Revision ID: a39c7f1826e3
Revises: 1d9a002d7234
Create Date: 2026-02-21 17:59:01.161812
"""
import sqlalchemy as sa
from sqlalchemy import orm
from alembic import op
from mealie.repos.repository_units import RepositoryUnit
from mealie.core.root_logger import get_logger
from mealie.db.models._model_utils.guid import GUID
from mealie.repos.seed.seeders import IngredientUnitsSeeder
from mealie.lang.locale_config import LOCALE_CONFIG
# revision identifiers, used by Alembic.
revision = "a39c7f1826e3"
down_revision: str | None = "1d9a002d7234"
branch_labels: str | tuple[str, ...] | None = None
depends_on: str | tuple[str, ...] | None = None
logger = get_logger()
class SqlAlchemyBase(orm.DeclarativeBase): ...
class IngredientUnitModel(SqlAlchemyBase):
__tablename__ = "ingredient_units"
id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
name: orm.Mapped[str | None] = orm.mapped_column(sa.String)
plural_name: orm.Mapped[str | None] = orm.mapped_column(sa.String)
abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String)
plural_abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String)
standard_quantity: orm.Mapped[float | None] = orm.mapped_column(sa.Float)
standard_unit: orm.Mapped[str | None] = orm.mapped_column(sa.String)
def populate_standards() -> None:
bind = op.get_bind()
session = orm.Session(bind)
# We aren't using most of the functionality of this class, so we pass dummy args
repo = RepositoryUnit(None, None, None, None, group_id=None) # type: ignore
stmt = sa.select(IngredientUnitModel)
units = session.execute(stmt).scalars().all()
if not units:
return
# Manually build repo._standardized_unit_map with all locales
repo._standardized_unit_map = {}
for locale in LOCALE_CONFIG:
locale_file = IngredientUnitsSeeder.get_file(locale)
for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items():
for prop in ["name", "plural_name", "abbreviation"]:
val = unit.get(prop)
if val and isinstance(val, str):
repo._standardized_unit_map[val.strip().lower()] = unit_key
for unit in units:
unit_data = {
"name": unit.name,
"plural_name": unit.plural_name,
"abbreviation": unit.abbreviation,
"plural_abbreviation": unit.plural_abbreviation,
}
standardized_data = repo._add_standardized_unit(unit_data)
std_q = standardized_data.get("standard_quantity")
std_u = standardized_data.get("standard_unit")
if std_q and std_u:
logger.info(f"Found unit '{unit.name}', which is standardized as '{std_q} * {std_u}'")
unit.standard_quantity = std_q
unit.standard_unit = std_u
session.commit()
session.close()
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ingredient_units", schema=None) as batch_op:
batch_op.add_column(sa.Column("standard_quantity", sa.Float(), nullable=True))
batch_op.add_column(sa.Column("standard_unit", sa.String(), nullable=True))
# ### end Alembic commands ###
# Populate standardized units for existing records
try:
populate_standards()
except Exception:
logger.exception("Failed to populate unit standards, skipping...")
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ingredient_units", schema=None) as batch_op:
batch_op.drop_column("standard_unit")
batch_op.drop_column("standard_quantity")
# ### end Alembic commands ###

View File

@@ -52,6 +52,10 @@ class IngredientUnitModel(SqlAlchemyBase, BaseMixins):
cascade="all, delete, delete-orphan",
)
# Standardization
standard_quantity: Mapped[float | None] = mapped_column(Float)
standard_unit: Mapped[str | None] = mapped_column(String)
# Automatically updated by sqlalchemy event, do not write to this manually
name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True)
plural_name_normalized: Mapped[str | None] = mapped_column(sa.String, index=True)

View File

@@ -15,52 +15,63 @@ class LocalePluralFoodHandling(StrEnum):
@dataclass
class LocaleConfig:
key: str
name: str
dir: LocaleTextDirection = LocaleTextDirection.LTR
plural_food_handling: LocalePluralFoodHandling = LocalePluralFoodHandling.ALWAYS
LOCALE_CONFIG: dict[str, LocaleConfig] = {
"af-ZA": LocaleConfig(name="Afrikaans (Afrikaans)"),
"ar-SA": LocaleConfig(name="العربية (Arabic)", dir=LocaleTextDirection.RTL),
"bg-BG": LocaleConfig(name="Български (Bulgarian)"),
"ca-ES": LocaleConfig(name="Català (Catalan)"),
"cs-CZ": LocaleConfig(name="Čeština (Czech)"),
"da-DK": LocaleConfig(name="Dansk (Danish)"),
"de-DE": LocaleConfig(name="Deutsch (German)"),
"el-GR": LocaleConfig(name="Ελληνικά (Greek)"),
"en-GB": LocaleConfig(name="British English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT),
"en-US": LocaleConfig(name="American English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT),
"es-ES": LocaleConfig(name="Español (Spanish)"),
"et-EE": LocaleConfig(name="Eesti (Estonian)"),
"fi-FI": LocaleConfig(name="Suomi (Finnish)"),
"fr-BE": LocaleConfig(name="Belge (Belgian)"),
"fr-CA": LocaleConfig(name="Français canadien (Canadian French)"),
"fr-FR": LocaleConfig(name="Français (French)"),
"gl-ES": LocaleConfig(name="Galego (Galician)"),
"he-IL": LocaleConfig(name="עברית (Hebrew)", dir=LocaleTextDirection.RTL),
"hr-HR": LocaleConfig(name="Hrvatski (Croatian)"),
"hu-HU": LocaleConfig(name="Magyar (Hungarian)"),
"is-IS": LocaleConfig(name="Íslenska (Icelandic)"),
"it-IT": LocaleConfig(name="Italiano (Italian)"),
"ja-JP": LocaleConfig(name="日本語 (Japanese)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"ko-KR": LocaleConfig(name="한국어 (Korean)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"lt-LT": LocaleConfig(name="Lietuvių (Lithuanian)"),
"lv-LV": LocaleConfig(name="Latviešu (Latvian)"),
"nl-NL": LocaleConfig(name="Nederlands (Dutch)"),
"no-NO": LocaleConfig(name="Norsk (Norwegian)"),
"pl-PL": LocaleConfig(name="Polski (Polish)"),
"pt-BR": LocaleConfig(name="Português do Brasil (Brazilian Portuguese)"),
"pt-PT": LocaleConfig(name="Português (Portuguese)"),
"ro-RO": LocaleConfig(name="Română (Romanian)"),
"ru-RU": LocaleConfig(name="Pусский (Russian)"),
"sk-SK": LocaleConfig(name="Slovenčina (Slovak)"),
"sl-SI": LocaleConfig(name="Slovenščina (Slovenian)"),
"sr-SP": LocaleConfig(name="српски (Serbian)"),
"sv-SE": LocaleConfig(name="Svenska (Swedish)"),
"tr-TR": LocaleConfig(name="Türkçe (Turkish)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"uk-UA": LocaleConfig(name="Українська (Ukrainian)"),
"vi-VN": LocaleConfig(name="Tiếng Việt (Vietnamese)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"zh-CN": LocaleConfig(name="简体中文 (Chinese simplified)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"zh-TW": LocaleConfig(name="繁體中文 (Chinese traditional)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"af-ZA": LocaleConfig(key="af-ZA", name="Afrikaans (Afrikaans)"),
"ar-SA": LocaleConfig(key="ar-SA", name="العربية (Arabic)", dir=LocaleTextDirection.RTL),
"bg-BG": LocaleConfig(key="bg-BG", name="Български (Bulgarian)"),
"ca-ES": LocaleConfig(key="ca-ES", name="Català (Catalan)"),
"cs-CZ": LocaleConfig(key="cs-CZ", name="Čeština (Czech)"),
"da-DK": LocaleConfig(key="da-DK", name="Dansk (Danish)"),
"de-DE": LocaleConfig(key="de-DE", name="Deutsch (German)"),
"el-GR": LocaleConfig(key="el-GR", name="Ελληνικά (Greek)"),
"en-GB": LocaleConfig(
key="en-GB", name="British English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT
),
"en-US": LocaleConfig(
key="en-US", name="American English", plural_food_handling=LocalePluralFoodHandling.WITHOUT_UNIT
),
"es-ES": LocaleConfig(key="es-ES", name="Español (Spanish)"),
"et-EE": LocaleConfig(key="et-EE", name="Eesti (Estonian)"),
"fi-FI": LocaleConfig(key="fi-FI", name="Suomi (Finnish)"),
"fr-BE": LocaleConfig(key="fr-BE", name="Belge (Belgian)"),
"fr-CA": LocaleConfig(key="fr-CA", name="Français canadien (Canadian French)"),
"fr-FR": LocaleConfig(key="fr-FR", name="Français (French)"),
"gl-ES": LocaleConfig(key="gl-ES", name="Galego (Galician)"),
"he-IL": LocaleConfig(key="he-IL", name="עברית (Hebrew)", dir=LocaleTextDirection.RTL),
"hr-HR": LocaleConfig(key="hr-HR", name="Hrvatski (Croatian)"),
"hu-HU": LocaleConfig(key="hu-HU", name="Magyar (Hungarian)"),
"is-IS": LocaleConfig(key="is-IS", name="Íslenska (Icelandic)"),
"it-IT": LocaleConfig(key="it-IT", name="Italiano (Italian)"),
"ja-JP": LocaleConfig(key="ja-JP", name="日本語 (Japanese)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"ko-KR": LocaleConfig(key="ko-KR", name="한국어 (Korean)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"lt-LT": LocaleConfig(key="lt-LT", name="Lietuvių (Lithuanian)"),
"lv-LV": LocaleConfig(key="lv-LV", name="Latviešu (Latvian)"),
"nl-NL": LocaleConfig(key="nl-NL", name="Nederlands (Dutch)"),
"no-NO": LocaleConfig(key="no-NO", name="Norsk (Norwegian)"),
"pl-PL": LocaleConfig(key="pl-PL", name="Polski (Polish)"),
"pt-BR": LocaleConfig(key="pt-BR", name="Português do Brasil (Brazilian Portuguese)"),
"pt-PT": LocaleConfig(key="pt-PT", name="Português (Portuguese)"),
"ro-RO": LocaleConfig(key="ro-RO", name="Română (Romanian)"),
"ru-RU": LocaleConfig(key="ru-RU", name="Pусский (Russian)"),
"sk-SK": LocaleConfig(key="sk-SK", name="Slovenčina (Slovak)"),
"sl-SI": LocaleConfig(key="sl-SI", name="Slovenščina (Slovenian)"),
"sr-SP": LocaleConfig(key="sr-SP", name="српски (Serbian)"),
"sv-SE": LocaleConfig(key="sv-SE", name="Svenska (Swedish)"),
"tr-TR": LocaleConfig(key="tr-TR", name="Türkçe (Turkish)", plural_food_handling=LocalePluralFoodHandling.NEVER),
"uk-UA": LocaleConfig(key="uk-UA", name="Українська (Ukrainian)"),
"vi-VN": LocaleConfig(
key="vi-VN", name="Tiếng Việt (Vietnamese)", plural_food_handling=LocalePluralFoodHandling.NEVER
),
"zh-CN": LocaleConfig(
key="zh-CN", name="简体中文 (Chinese simplified)", plural_food_handling=LocalePluralFoodHandling.NEVER
),
"zh-TW": LocaleConfig(
key="zh-TW", name="繁體中文 (Chinese traditional)", plural_food_handling=LocalePluralFoodHandling.NEVER
),
}

View File

@@ -1,17 +1,119 @@
from pydantic import UUID4
from collections.abc import Iterable
from pydantic import UUID4, BaseModel
from sqlalchemy import select
from mealie.db.models.recipe.ingredient import IngredientUnitModel
from mealie.schema.recipe.recipe_ingredient import IngredientUnit
from mealie.lang.providers import get_locale_context
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, StandardizedUnitType
from .repository_generic import GroupRepositoryGeneric
class RepositoryUnit(GroupRepositoryGeneric[IngredientUnit, IngredientUnitModel]):
_standardized_unit_map: dict[str, str] | None = None
@property
def standardized_unit_map(self) -> dict[str, str]:
"""A map of potential known units to its standardized name in our seed data"""
if self._standardized_unit_map is None:
from .seed.seeders import IngredientUnitsSeeder
ctx = get_locale_context()
if ctx:
locale = ctx[1].key
else:
locale = None
self._standardized_unit_map = {}
locale_file = IngredientUnitsSeeder.get_file(locale=locale)
for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items():
for prop in ["name", "plural_name", "abbreviation"]:
val = unit.get(prop)
if val and isinstance(val, str):
self._standardized_unit_map[val.strip().lower()] = unit_key
return self._standardized_unit_map
def _get_unit(self, id: UUID4) -> IngredientUnitModel:
stmt = select(self.model).filter_by(**self._filter_builder(**{"id": id}))
return self.session.execute(stmt).scalars().one()
def _add_standardized_unit(self, data: BaseModel | dict) -> dict:
if not isinstance(data, dict):
data = data.model_dump()
# Don't overwrite user data if it exists
if data.get("standard_quantity") is not None or data.get("standard_unit") is not None:
return data
# Compare name attrs to translation files and see if there's a match to a known standard unit
for prop in ["name", "plural_name", "abbreviation", "plural_abbreviation"]:
val = data.get(prop)
if not (val and isinstance(val, str)):
continue
standardized_unit_key = self.standardized_unit_map.get(val.strip().lower())
if not standardized_unit_key:
continue
match standardized_unit_key:
case "teaspoon":
data["standard_quantity"] = 1 / 6
data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE
case "tablespoon":
data["standard_quantity"] = 1 / 2
data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE
case "cup":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.CUP
case "fluid-ounce":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.FLUID_OUNCE
case "pint":
data["standard_quantity"] = 2
data["standard_unit"] = StandardizedUnitType.CUP
case "quart":
data["standard_quantity"] = 4
data["standard_unit"] = StandardizedUnitType.CUP
case "gallon":
data["standard_quantity"] = 16
data["standard_unit"] = StandardizedUnitType.CUP
case "milliliter":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.MILLILITER
case "liter":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.LITER
case "pound":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.POUND
case "ounce":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.OUNCE
case "gram":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.GRAM
case "kilogram":
data["standard_quantity"] = 1
data["standard_unit"] = StandardizedUnitType.KILOGRAM
case "milligram":
data["standard_quantity"] = 1 / 1000
data["standard_unit"] = StandardizedUnitType.GRAM
case _:
continue
return data
def create(self, data: IngredientUnit | dict) -> IngredientUnit:
data = self._add_standardized_unit(data)
return super().create(data)
def create_many(self, data: Iterable[IngredientUnit | dict]) -> list[IngredientUnit]:
data = [self._add_standardized_unit(i) for i in data]
return super().create_many(data)
def merge(self, from_unit: UUID4, to_unit: UUID4) -> IngredientUnit | None:
from_model = self._get_unit(from_unit)
to_model = self._get_unit(to_unit)

View File

@@ -1,3 +1,4 @@
import json
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
@@ -11,6 +12,8 @@ class AbstractSeeder(ABC):
Abstract class for seeding data.
"""
resources = Path(__file__).parent / "resources"
def __init__(self, db: AllRepositories, logger: Logger | None = None):
"""
Initialize the abstract seeder.
@@ -19,7 +22,14 @@ class AbstractSeeder(ABC):
"""
self.repos = db
self.logger = logger or get_logger("Data Seeder")
self.resources = Path(__file__).parent / "resources"
@classmethod
@abstractmethod
def get_file(self, locale: str | None = None) -> Path: ...
@classmethod
def load_file(self, file: Path) -> dict[str, dict]:
return json.loads(file.read_text(encoding="utf-8"))
@abstractmethod
def seed(self, locale: str | None = None) -> None: ...

View File

@@ -1,4 +1,3 @@
import json
import pathlib
from collections.abc import Generator
from functools import cached_property
@@ -21,9 +20,10 @@ class MultiPurposeLabelSeeder(AbstractSeeder):
def service(self):
return MultiPurposeLabelService(self.repos)
def get_file(self, locale: str | None = None) -> pathlib.Path:
@classmethod
def get_file(cls, locale: str | None = None) -> pathlib.Path:
# Get the labels from the foods seed file now
locale_path = self.resources / "foods" / "locales" / f"{locale}.json"
locale_path = cls.resources / "foods" / "locales" / f"{locale}.json"
return locale_path if locale_path.exists() else foods.en_US
def get_all_labels(self) -> list[MultiPurposeLabelOut]:
@@ -34,7 +34,7 @@ class MultiPurposeLabelSeeder(AbstractSeeder):
current_label_names = {label.name for label in self.get_all_labels()}
# load from the foods locale file and remove any empty strings
seed_label_names = set(filter(None, json.loads(file.read_text(encoding="utf-8")).keys())) # type: set[str]
seed_label_names = set(filter(None, self.load_file(file).keys())) # type: set[str]
# only seed new labels
to_seed_labels = seed_label_names - current_label_names
for label in to_seed_labels:
@@ -53,8 +53,9 @@ class MultiPurposeLabelSeeder(AbstractSeeder):
class IngredientUnitsSeeder(AbstractSeeder):
def get_file(self, locale: str | None = None) -> pathlib.Path:
locale_path = self.resources / "units" / "locales" / f"{locale}.json"
@classmethod
def get_file(cls, locale: str | None = None) -> pathlib.Path:
locale_path = cls.resources / "units" / "locales" / f"{locale}.json"
return locale_path if locale_path.exists() else units.en_US
def get_all_units(self) -> list[IngredientUnit]:
@@ -64,7 +65,7 @@ class IngredientUnitsSeeder(AbstractSeeder):
file = self.get_file(locale)
seen_unit_names = {unit.name for unit in self.get_all_units()}
for unit in json.loads(file.read_text(encoding="utf-8")).values():
for unit in self.load_file(file).values():
if unit["name"] in seen_unit_names:
continue
@@ -88,8 +89,9 @@ class IngredientUnitsSeeder(AbstractSeeder):
class IngredientFoodsSeeder(AbstractSeeder):
def get_file(self, locale: str | None = None) -> pathlib.Path:
locale_path = self.resources / "foods" / "locales" / f"{locale}.json"
@classmethod
def get_file(cls, locale: str | None = None) -> pathlib.Path:
locale_path = cls.resources / "foods" / "locales" / f"{locale}.json"
return locale_path if locale_path.exists() else foods.en_US
def get_label(self, value: str) -> MultiPurposeLabelOut | None:
@@ -103,7 +105,7 @@ class IngredientFoodsSeeder(AbstractSeeder):
# get all current unique foods
seen_foods_names = {food.name for food in self.get_all_foods()}
for label, values in json.loads(file.read_text(encoding="utf-8")).items():
for label, values in self.load_file(file).items():
label_out = self.get_label(label)
for food_name, attributes in values["foods"].items():

View File

@@ -67,6 +67,7 @@ from .recipe_ingredient import (
RegisteredParser,
SaveIngredientFood,
SaveIngredientUnit,
StandardizedUnitType,
UnitFoodBase,
)
from .recipe_notes import RecipeNote
@@ -159,6 +160,7 @@ __all__ = [
"RegisteredParser",
"SaveIngredientFood",
"SaveIngredientUnit",
"StandardizedUnitType",
"UnitFoodBase",
"RecipeSuggestionQuery",
"RecipeSuggestionResponse",

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import datetime
import enum
from enum import StrEnum
from fractions import Fraction
from typing import ClassVar
from uuid import UUID, uuid4
@@ -34,6 +35,28 @@ def display_fraction(fraction: Fraction):
)
class StandardizedUnitType(StrEnum):
"""
An arbitrary list of standardized units supported by unit conversions.
The backend doesn't really care what standardized unit you use, as long as it's recognized,
but defining them here keeps it consistant with the frontend.
"""
# Imperial
FLUID_OUNCE = "fluid_ounce"
CUP = "cup"
OUNCE = "ounce"
POUND = "pound"
# Metric
MILLILITER = "milliliter"
LITER = "liter"
GRAM = "gram"
KILOGRAM = "kilogram"
class UnitFoodBase(MealieModel):
id: UUID4 | None = None
name: str
@@ -109,9 +132,6 @@ class IngredientFood(CreateIngredientFood):
except AttributeError:
return v
def is_on_hand(self, household_slug: str) -> bool:
return household_slug in self.households_with_tool
class IngredientFoodPagination(PaginationBase):
items: list[IngredientFood]
@@ -130,7 +150,21 @@ class CreateIngredientUnit(UnitFoodBase):
abbreviation: str = ""
plural_abbreviation: str | None = ""
use_abbreviation: bool = False
aliases: list[CreateIngredientUnitAlias] = []
standard_quantity: float | None = None
standard_unit: str | None = None
@model_validator(mode="after")
def validate_standardization_fields(self):
# If one is set, the other must be set.
# If quantity is <= 0, it's considered not set.
if not self.standard_unit:
self.standard_quantity = self.standard_unit = None
elif not ((self.standard_quantity or 0) > 0):
self.standard_quantity = self.standard_unit = None
return self
class SaveIngredientUnit(CreateIngredientUnit):

View File

@@ -32,9 +32,6 @@ class RecipeToolOut(RecipeToolCreate):
except AttributeError:
return v
def is_on_hand(self, household_slug: str) -> bool:
return household_slug in self.households_with_tool
@classmethod
def loader_options(cls) -> list[LoaderOption]:
return [

View File

@@ -28,6 +28,7 @@ from mealie.schema.recipe.recipe_ingredient import (
)
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
from mealie.services.parser_services._base import DataMatcher
from mealie.services.parser_services.parser_utils import UnitConverter, merge_quantity_and_unit
class ShoppingListService:
@@ -41,8 +42,7 @@ class ShoppingListService:
self.list_refs = repos.group_shopping_list_recipe_refs
self.data_matcher = DataMatcher(self.repos, food_fuzzy_match_threshold=self.DEFAULT_FOOD_FUZZY_MATCH_THRESHOLD)
@staticmethod
def can_merge(item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
def can_merge(self, item1: ShoppingListItemBase, item2: ShoppingListItemBase) -> bool:
"""Check to see if this item can be merged with another item"""
if any(
@@ -50,16 +50,28 @@ class ShoppingListService:
item1.checked,
item2.checked,
item1.food_id != item2.food_id,
item1.unit_id != item2.unit_id,
]
):
return False
# check if units match or if they're compatable
if item1.unit_id != item2.unit_id:
item1_unit = item1.unit or self.data_matcher.units_by_id.get(item1.unit_id)
item2_unit = item2.unit or self.data_matcher.units_by_id.get(item2.unit_id)
if not (item1_unit and item1_unit.standard_unit):
return False
if not (item2_unit and item2_unit.standard_unit):
return False
uc = UnitConverter()
if not uc.can_convert(item1_unit.standard_unit, item2_unit.standard_unit):
return False
# if foods match, we can merge, otherwise compare the notes
return bool(item1.food_id) or item1.note == item2.note
@staticmethod
def merge_items(
self,
from_item: ShoppingListItemCreate | ShoppingListItemUpdateBulk,
to_item: ShoppingListItemCreate | ShoppingListItemUpdateBulk | ShoppingListItemOut,
) -> ShoppingListItemUpdate:
@@ -69,7 +81,20 @@ class ShoppingListService:
Attributes of the `to_item` take priority over the `from_item`, except extras with overlapping keys
"""
to_item.quantity += from_item.quantity
to_item_unit = to_item.unit or self.data_matcher.units_by_id.get(to_item.unit_id)
from_item_unit = from_item.unit or self.data_matcher.units_by_id.get(from_item.unit_id)
if to_item_unit and to_item_unit.standard_unit and from_item_unit and from_item_unit.standard_unit:
merged_qty, merged_unit = merge_quantity_and_unit(
from_item.quantity or 0, from_item_unit, to_item.quantity or 0, to_item_unit
)
to_item.quantity = merged_qty
to_item.unit_id = merged_unit.id
to_item.unit = merged_unit
else:
# No conversion needed, just sum the quantities
to_item.quantity += from_item.quantity
if to_item.note != from_item.note:
to_item.note = " | ".join([note for note in [to_item.note, from_item.note] if note])

View File

@@ -29,18 +29,38 @@ class DataMatcher:
self._food_fuzzy_match_threshold = food_fuzzy_match_threshold
self._unit_fuzzy_match_threshold = unit_fuzzy_match_threshold
self._foods_by_id: dict[UUID4, IngredientFood] | None = None
self._units_by_id: dict[UUID4, IngredientUnit] | None = None
self._foods_by_alias: dict[str, IngredientFood] | None = None
self._units_by_alias: dict[str, IngredientUnit] | None = None
@property
def foods_by_alias(self) -> dict[str, IngredientFood]:
if self._foods_by_alias is None:
def foods_by_id(self) -> dict[UUID4, IngredientFood]:
if self._foods_by_id is None:
foods_repo = self.repos.ingredient_foods
query = PaginationQuery(page=1, per_page=-1)
all_foods = foods_repo.page_all(query).items
self._foods_by_id = {food.id: food for food in all_foods}
return self._foods_by_id
@property
def units_by_id(self) -> dict[UUID4, IngredientUnit]:
if self._units_by_id is None:
units_repo = self.repos.ingredient_units
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
self._units_by_id = {unit.id: unit for unit in all_units}
return self._units_by_id
@property
def foods_by_alias(self) -> dict[str, IngredientFood]:
if self._foods_by_alias is None:
foods_by_alias: dict[str, IngredientFood] = {}
for food in all_foods:
for food in self.foods_by_id.values():
if food.name:
foods_by_alias[IngredientFoodModel.normalize(food.name)] = food
if food.plural_name:
@@ -57,12 +77,8 @@ class DataMatcher:
@property
def units_by_alias(self) -> dict[str, IngredientUnit]:
if self._units_by_alias is None:
units_repo = self.repos.ingredient_units
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
units_by_alias: dict[str, IngredientUnit] = {}
for unit in all_units:
for unit in self.units_by_id.values():
if unit.name:
units_by_alias[IngredientUnitModel.normalize(unit.name)] = unit
if unit.plural_name:

View File

@@ -1 +1,2 @@
from .string_utils import *
from .unit_utils import *

View File

@@ -0,0 +1,146 @@
from typing import TYPE_CHECKING, Literal, overload
from pint import Quantity, Unit, UnitRegistry
if TYPE_CHECKING:
from mealie.schema.recipe.recipe_ingredient import CreateIngredientUnit
class UnitNotFound(Exception):
"""Raised when trying to access a unit not found in the unit registry."""
def __init__(self, message: str = "Unit not found in unit registry"):
self.message = message
super().__init__(self.message)
def __str__(self):
return f"{self.message}"
class UnitConverter:
def __init__(self):
self.ureg = UnitRegistry()
def _resolve_ounce(self, unit_1: Unit, unit_2: Unit) -> tuple[Unit, Unit]:
"""
Often times "ounce" is used in place of "fluid ounce" in recipes.
When trying to convert/combine ounces with a volume, we can assume it should have been a fluid ounce.
This function will convert ounces to fluid ounces if the other unit is a volume.
"""
OUNCE = self.ureg("ounce")
FL_OUNCE = self.ureg("fluid_ounce")
VOLUME = "[length] ** 3"
if unit_1 == OUNCE and unit_2.dimensionality == VOLUME:
return FL_OUNCE, unit_2
if unit_2 == OUNCE and unit_1.dimensionality == VOLUME:
return unit_1, FL_OUNCE
return unit_1, unit_2
@overload
def parse(self, unit: str | Unit, strict: Literal[False] = False) -> str | Unit: ...
@overload
def parse(self, unit: str | Unit, strict: Literal[True]) -> Unit: ...
def parse(self, unit: str | Unit, strict: bool = False) -> str | Unit:
"""
Parse a string unit into a pint.Unit.
If strict is False (default), returns a pint.Unit if it exists, otherwise returns the original string.
If strict is True, raises UnitNotFound instead of returning a string.
If the input is already a parsed pint.Unit, returns it as-is.
"""
if isinstance(unit, Unit):
return unit
try:
return self.ureg(unit).units
except Exception as e:
if strict:
raise UnitNotFound(f"Unit '{unit}' not found in unit registry") from e
return unit
def can_convert(self, unit: str | Unit, to_unit: str | Unit) -> bool:
"""Whether or not a given unit can be converted into another unit."""
unit = self.parse(unit)
to_unit = self.parse(to_unit)
if not (isinstance(unit, Unit) and isinstance(to_unit, Unit)):
return False
unit, to_unit = self._resolve_ounce(unit, to_unit)
return unit.is_compatible_with(to_unit)
def convert(self, quantity: float, unit: str | Unit, to_unit: str | Unit) -> tuple[float, Unit]:
"""
Convert a quantity and a unit into another unit.
Returns tuple[quantity, unit]
"""
unit = self.parse(unit, strict=True)
to_unit = self.parse(to_unit, strict=True)
unit, to_unit = self._resolve_ounce(unit, to_unit)
qty = quantity * unit
converted = qty.to(to_unit)
return float(converted.magnitude), converted.units
def merge(self, quantity_1: float, unit_1: str | Unit, quantity_2: float, unit_2: str | Unit) -> tuple[float, Unit]:
"""Merge two quantities together"""
unit_1 = self.parse(unit_1, strict=True)
unit_2 = self.parse(unit_2, strict=True)
unit_1, unit_2 = self._resolve_ounce(unit_1, unit_2)
q1 = quantity_1 * unit_1
q2 = quantity_2 * unit_2
out: Quantity = q1 + q2
return float(out.magnitude), out.units
def merge_quantity_and_unit[T: CreateIngredientUnit](
qty_1: float, unit_1: T, qty_2: float, unit_2: T
) -> tuple[float, T]:
"""
Merge a quantity and unit.
Returns tuple[quantity, unit]
"""
if not (unit_1.standard_quantity and unit_1.standard_unit and unit_2.standard_quantity and unit_2.standard_unit):
raise ValueError("Both units must contain standardized unit data")
PINT_UNIT_1_TXT = "_mealie_unit_1"
PINT_UNIT_2_TXT = "_mealie_unit_2"
uc = UnitConverter()
# pre-process units to account for ounce -> fluid_ounce conversion
unit_1_standard = uc.parse(unit_1.standard_unit, strict=True)
unit_2_standard = uc.parse(unit_2.standard_unit, strict=True)
unit_1_standard, unit_2_standard = uc._resolve_ounce(unit_1_standard, unit_2_standard)
# create custon unit definition so pint can handle them natively
uc.ureg.define(f"{PINT_UNIT_1_TXT} = {unit_1.standard_quantity} * {unit_1_standard}")
uc.ureg.define(f"{PINT_UNIT_2_TXT} = {unit_2.standard_quantity} * {unit_2_standard}")
pint_unit_1 = uc.parse(PINT_UNIT_1_TXT)
pint_unit_2 = uc.parse(PINT_UNIT_2_TXT)
merged_q, merged_u = uc.merge(qty_1, pint_unit_1, qty_2, pint_unit_2)
# Convert to the bigger unit if quantity >= 1, else the smaller unit
merged_q, merged_u = uc.convert(merged_q, merged_u, max(pint_unit_1, pint_unit_2))
if abs(merged_q) < 1:
merged_q, merged_u = uc.convert(merged_q, merged_u, min(pint_unit_1, pint_unit_2))
if str(merged_u) == PINT_UNIT_1_TXT:
return merged_q, unit_1
else:
return merged_q, unit_2