add auto-standardization to migration

This commit is contained in:
Michael Genson
2026-02-22 02:56:48 +00:00
parent fe9dadefea
commit 3af9b05bd8

View File

@@ -7,8 +7,13 @@ Create Date: 2026-02-21 17:59:01.161812
"""
import sqlalchemy as sa
from sqlalchemy import orm
from alembic import op
from mealie.repos.repository_units import RepositoryUnit
from mealie.core.root_logger import get_logger
from mealie.db.models._model_utils.guid import GUID
from mealie.repos.seed.seeders import IngredientUnitsSeeder
from mealie.lang.locale_config import LOCALE_CONFIG
# revision identifiers, used by Alembic.
revision = "a39c7f1826e3"
@@ -16,6 +21,66 @@ down_revision: str | None = "1d9a002d7234"
branch_labels: str | tuple[str, ...] | None = None
depends_on: str | tuple[str, ...] | None = None
logger = get_logger()
class SqlAlchemyBase(orm.DeclarativeBase): ...
class IngredientUnitModel(SqlAlchemyBase):
__tablename__ = "ingredient_units"
id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
name: orm.Mapped[str | None] = orm.mapped_column(sa.String)
plural_name: orm.Mapped[str | None] = orm.mapped_column(sa.String)
abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String)
plural_abbreviation: orm.Mapped[str | None] = orm.mapped_column(sa.String)
standard_quantity: orm.Mapped[float | None] = orm.mapped_column(sa.Float)
standard_unit: orm.Mapped[str | None] = orm.mapped_column(sa.String)
def populate_standards() -> None:
bind = op.get_bind()
session = orm.Session(bind)
# We aren't using most of the functionality of this class, so we pass dummy args
repo = RepositoryUnit(None, None, None, None, group_id=None) # type: ignore
stmt = sa.select(IngredientUnitModel)
units = session.execute(stmt).scalars().all()
if not units:
return
# Manually build repo._standardized_unit_map with all locales
repo._standardized_unit_map = {}
for locale in LOCALE_CONFIG:
locale_file = IngredientUnitsSeeder.get_file(locale)
for unit_key, unit in IngredientUnitsSeeder.load_file(locale_file).items():
for prop in ["name", "plural_name", "abbreviation"]:
val = unit.get(prop)
if val and isinstance(val, str):
repo._standardized_unit_map[val.strip().lower()] = unit_key
for unit in units:
unit_data = {
"name": unit.name,
"plural_name": unit.plural_name,
"abbreviation": unit.abbreviation,
"plural_abbreviation": unit.plural_abbreviation,
}
standardized_data = repo._add_standardized_unit(unit_data)
std_q = standardized_data.get("standard_quantity")
std_u = standardized_data.get("standard_unit")
if std_q and std_u:
logger.info(f"Found unit '{unit.name}', which is standardized as '{std_q} * {std_u}'")
unit.standard_quantity = std_q
unit.standard_unit = std_u
session.commit()
session.close()
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
@@ -25,6 +90,12 @@ def upgrade():
# ### end Alembic commands ###
# Populate standardized units for existing records
try:
populate_standards()
except Exception:
logger.exception("Failed to populate unit standards, skipping...")
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###