mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-11-01 10:43:25 -04:00
feat: Add Households to Mealie (#3970)
This commit is contained in:
@@ -19,6 +19,8 @@ from mealie.schema.response.pagination import OrderByNullPosition, OrderDirectio
|
||||
from mealie.schema.response.query_filter import QueryFilter
|
||||
from mealie.schema.response.query_search import SearchFilter
|
||||
|
||||
from ._utils import NOT_SET, NotSet
|
||||
|
||||
Schema = TypeVar("Schema", bound=MealieModel)
|
||||
Model = TypeVar("Model", bound=SqlAlchemyBase)
|
||||
|
||||
@@ -33,11 +35,18 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
Generic ([Model]): Represents the SqlAlchemyModel Model
|
||||
"""
|
||||
|
||||
user_id: UUID4 | None = None
|
||||
group_id: UUID4 | None = None
|
||||
session: Session
|
||||
|
||||
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
|
||||
_group_id: UUID4 | None = None
|
||||
_household_id: UUID4 | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
primary_key: str,
|
||||
sql_model: type[Model],
|
||||
schema: type[Schema],
|
||||
) -> None:
|
||||
self.session = session
|
||||
self.primary_key = primary_key
|
||||
self.model = sql_model
|
||||
@@ -45,13 +54,13 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
self.logger = get_logger()
|
||||
|
||||
def by_user(self: T, user_id: UUID4) -> T:
|
||||
self.user_id = user_id
|
||||
return self
|
||||
@property
|
||||
def group_id(self) -> UUID4 | None:
|
||||
return self._group_id
|
||||
|
||||
def by_group(self: T, group_id: UUID4) -> T:
|
||||
self.group_id = group_id
|
||||
return self
|
||||
@property
|
||||
def household_id(self) -> UUID4 | None:
|
||||
return self._household_id
|
||||
|
||||
def _log_exception(self, e: Exception) -> None:
|
||||
self.logger.error(f"Error processing query for Repo model={self.model.__name__} schema={self.schema.__name__}")
|
||||
@@ -70,6 +79,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
|
||||
if self.group_id:
|
||||
dct["group_id"] = self.group_id
|
||||
if self.household_id:
|
||||
dct["household_id"] = self.household_id
|
||||
|
||||
return {**dct, **kwargs}
|
||||
|
||||
@@ -341,7 +352,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
self.logger.error(e)
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
count_query = select(func.count()).select_from(query)
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
count = self.session.scalar(count_query)
|
||||
if not count:
|
||||
count = 0
|
||||
@@ -373,15 +384,15 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
order_dir: OrderDirection,
|
||||
order_by_null: OrderByNullPosition | None,
|
||||
) -> Select:
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if order_dir is OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif order_dir is OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if order_by_null is OrderByNullPosition.first:
|
||||
order_attr = nulls_first(order_attr)
|
||||
elif order_by_null is OrderByNullPosition.last:
|
||||
@@ -435,3 +446,40 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
|
||||
search_filter = SearchFilter(self.session, search, schema._normalize_search)
|
||||
return search_filter.filter_query_by_search(query, schema, self.model)
|
||||
|
||||
|
||||
class GroupRepositoryGeneric(RepositoryGeneric[Schema, Model]):
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
primary_key: str,
|
||||
sql_model: type[Model],
|
||||
schema: type[Schema],
|
||||
*,
|
||||
group_id: UUID4 | None | NotSet,
|
||||
) -> None:
|
||||
super().__init__(session, primary_key, sql_model, schema)
|
||||
if group_id is NOT_SET:
|
||||
raise ValueError("group_id must be set")
|
||||
self._group_id = group_id if group_id else None
|
||||
|
||||
|
||||
class HouseholdRepositoryGeneric(RepositoryGeneric[Schema, Model]):
|
||||
def __init__(
|
||||
self,
|
||||
session: Session,
|
||||
primary_key: str,
|
||||
sql_model: type[Model],
|
||||
schema: type[Schema],
|
||||
*,
|
||||
group_id: UUID4 | None | NotSet,
|
||||
household_id: UUID4 | None | NotSet,
|
||||
) -> None:
|
||||
super().__init__(session, primary_key, sql_model, schema)
|
||||
if group_id is NOT_SET:
|
||||
raise ValueError("group_id must be set")
|
||||
self._group_id = group_id if group_id else None
|
||||
|
||||
if household_id is NOT_SET:
|
||||
raise ValueError("household_id must be set")
|
||||
self._household_id = household_id if household_id else None
|
||||
|
||||
Reference in New Issue
Block a user