diff --git a/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py b/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py index 9bc93dee9..c0898392a 100644 --- a/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py +++ b/mealie/alembic/versions/2026-02-21-17.59.01_a39c7f1826e3_add_unit_standardization_fields.py @@ -7,8 +7,13 @@ Create Date: 2026-02-21 17:59:01.161812 """ import sqlalchemy as sa +from sqlalchemy import orm from alembic import op - +from mealie.repos.repository_units import RepositoryUnit +from mealie.core.root_logger import get_logger +from mealie.db.models._model_utils.guid import GUID +from mealie.repos.seed.seeders import IngredientUnitsSeeder +from mealie.lang.locale_config import LOCALE_CONFIG # revision identifiers, used by Alembic. revision = "a39c7f1826e3" @@ -16,6 +21,66 @@ down_revision: str | None = "1d9a002d7234" branch_labels: str | tuple[str, ...] | None = None depends_on: str | tuple[str, ...] | None = None +logger = get_logger() + + +class SqlAlchemyBase(orm.DeclarativeBase): ... + + +class IngredientUnitModel(SqlAlchemyBase): + __tablename__ = "ingredient_units" + + id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate) + name: orm.Mapped[str | None] = orm.mapped_column(sa.String) + plural_name: orm.Mapped[str | None] = orm.mapped_column(sa.String) + abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String) + plural_abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String) + standard_quantity: orm.Mapped[float | None] = orm.mapped_column(sa.Float) + standard_unit: orm.Mapped[str | None] = orm.mapped_column(sa.String) + + +def populate_standards() -> None: + bind = op.get_bind() + + session = orm.Session(bind) + + # We aren't using most of the functionality of this class, so we pass dummy args + repo = RepositoryUnit(None, None, None, None, group_id=None) # type: ignore + + stmt = sa.select(IngredientUnitModel) + units = session.execute(stmt).scalars().all() + if not units: + return + + # Manually build repo._standardized_unit_map with all locales + repo._standardized_unit_map = {} + for locale in LOCALE_CONFIG: + locale_file = IngredientUnitsSeeder.get_file(locale) + for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items(): + for prop in ["name", "plural_name", "abbreviation"]: + val = unit.get(prop) + if val and isinstance(val, str): + repo._standardized_unit_map[val.strip().lower()] = unit_key + + for unit in units: + unit_data = { + "name": unit.name, + "plural_name": unit.plural_name, + "abbreviation": unit.abbreviation, + "plural_abbreviation": unit.plural_abbreviation, + } + + standardized_data = repo._add_standardized_unit(unit_data) + std_q = standardized_data.get("standard_quantity") + std_u = standardized_data.get("standard_unit") + if std_q and std_u: + logger.info(f"Found unit '{unit.name}', which is standardized as '{std_q} * {std_u}'") + unit.standard_quantity = std_q + unit.standard_unit = std_u + + session.commit() + session.close() + def upgrade(): # ### commands auto generated by Alembic - please adjust! ### @@ -25,6 +90,12 @@ def upgrade(): # ### end Alembic commands ### + # Populate standardized units for existing records + try: + populate_standards() + except Exception: + logger.exception("Failed to populate unit standards, skipping...") + def downgrade(): # ### commands auto generated by Alembic - please adjust! ###