mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-01-05 08:31:25 -05:00
feat: User-specific Recipe Ratings (#3345)
This commit is contained in:
@@ -31,6 +31,7 @@ from mealie.db.models.recipe.tool import Tool
|
||||
from mealie.db.models.server.task import ServerTaskModel
|
||||
from mealie.db.models.users import LongLiveToken, User
|
||||
from mealie.db.models.users.password_reset import PasswordResetModel
|
||||
from mealie.db.models.users.user_to_recipe import UserToRecipe
|
||||
from mealie.repos.repository_foods import RepositoryFood
|
||||
from mealie.repos.repository_meal_plan_rules import RepositoryMealPlanRules
|
||||
from mealie.repos.repository_units import RepositoryUnit
|
||||
@@ -58,6 +59,7 @@ from mealie.schema.recipe.recipe_timeline_events import RecipeTimelineEventOut
|
||||
from mealie.schema.reports.reports import ReportEntryOut, ReportOut
|
||||
from mealie.schema.server import ServerTask
|
||||
from mealie.schema.user import GroupInDB, LongLiveTokenInDB, PrivateUser
|
||||
from mealie.schema.user.user import UserRatingOut
|
||||
from mealie.schema.user.user_passwords import PrivatePasswordResetToken
|
||||
|
||||
from .repository_generic import RepositoryGeneric
|
||||
@@ -65,7 +67,7 @@ from .repository_group import RepositoryGroup
|
||||
from .repository_meals import RepositoryMeals
|
||||
from .repository_recipes import RepositoryRecipes
|
||||
from .repository_shopping_list import RepositoryShoppingList
|
||||
from .repository_users import RepositoryUsers
|
||||
from .repository_users import RepositoryUserRatings, RepositoryUsers
|
||||
|
||||
PK_ID = "id"
|
||||
PK_SLUG = "slug"
|
||||
@@ -143,6 +145,10 @@ class AllRepositories:
|
||||
def users(self) -> RepositoryUsers:
|
||||
return RepositoryUsers(self.session, PK_ID, User, PrivateUser)
|
||||
|
||||
@cached_property
|
||||
def user_ratings(self) -> RepositoryUserRatings:
|
||||
return RepositoryUserRatings(self.session, PK_ID, UserToRecipe, UserRatingOut)
|
||||
|
||||
@cached_property
|
||||
def api_tokens(self) -> RepositoryGeneric[LongLiveTokenInDB, LongLiveToken]:
|
||||
return RepositoryGeneric(self.session, PK_ID, LongLiveToken, LongLiveTokenInDB)
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Generic, TypeVar
|
||||
from fastapi import HTTPException
|
||||
from pydantic import UUID4, BaseModel
|
||||
from sqlalchemy import Select, case, delete, func, nulls_first, nulls_last, select
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
@@ -67,9 +68,6 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
def _filter_builder(self, **kwargs) -> dict[str, Any]:
|
||||
dct = {}
|
||||
|
||||
if self.user_id:
|
||||
dct["user_id"] = self.user_id
|
||||
|
||||
if self.group_id:
|
||||
dct["group_id"] = self.group_id
|
||||
|
||||
@@ -287,7 +285,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
pagination is a method to interact with the filtered database table and return a paginated result
|
||||
using the PaginationBase that provides several data points that are needed to manage pagination
|
||||
on the client side. This method does utilize the _filter_build method to ensure that the results
|
||||
are filtered by the user and group id when applicable.
|
||||
are filtered by the group id when applicable.
|
||||
|
||||
NOTE: When you provide an override you'll need to manually type the result of this method
|
||||
as the override, as the type system is not able to infer the result of this method.
|
||||
@@ -368,6 +366,29 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
query = self.add_order_by_to_query(query, pagination)
|
||||
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
|
||||
|
||||
def add_order_attr_to_query(
|
||||
self,
|
||||
query: Select,
|
||||
order_attr: InstrumentedAttribute,
|
||||
order_dir: OrderDirection,
|
||||
order_by_null: OrderByNullPosition | None,
|
||||
) -> Select:
|
||||
if order_dir is OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif order_dir is OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if order_by_null is OrderByNullPosition.first:
|
||||
order_attr = nulls_first(order_attr)
|
||||
elif order_by_null is OrderByNullPosition.last:
|
||||
order_attr = nulls_last(order_attr)
|
||||
|
||||
return query.order_by(order_attr)
|
||||
|
||||
def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
|
||||
if not pagination.order_by:
|
||||
return query
|
||||
@@ -399,21 +420,9 @@ class RepositoryGeneric(Generic[Schema, Model]):
|
||||
order_by, self.model, query=query
|
||||
)
|
||||
|
||||
if order_dir is OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif order_dir is OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
# queries handle uppercase and lowercase differently, which is undesirable
|
||||
if isinstance(order_attr.type, sqltypes.String):
|
||||
order_attr = func.lower(order_attr)
|
||||
|
||||
if pagination.order_by_null_position is OrderByNullPosition.first:
|
||||
order_attr = nulls_first(order_attr)
|
||||
elif pagination.order_by_null_position is OrderByNullPosition.last:
|
||||
order_attr = nulls_last(order_attr)
|
||||
|
||||
query = query.order_by(order_attr)
|
||||
query = self.add_order_attr_to_query(
|
||||
query, order_attr, order_dir, pagination.order_by_null_position
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -3,11 +3,11 @@ from collections.abc import Sequence
|
||||
from random import randint
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import UUID4
|
||||
from slugify import slugify
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import InstrumentedAttribute, joinedload
|
||||
|
||||
from mealie.db.models.recipe.category import Category
|
||||
from mealie.db.models.recipe.ingredient import RecipeIngredientModel
|
||||
@@ -15,11 +15,12 @@ from mealie.db.models.recipe.recipe import RecipeModel
|
||||
from mealie.db.models.recipe.settings import RecipeSettings
|
||||
from mealie.db.models.recipe.tag import Tag
|
||||
from mealie.db.models.recipe.tool import Tool
|
||||
from mealie.db.models.users.user_to_recipe import UserToRecipe
|
||||
from mealie.schema.cookbook.cookbook import ReadCookBook
|
||||
from mealie.schema.recipe import Recipe
|
||||
from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool
|
||||
from mealie.schema.recipe.recipe_category import CategoryBase, TagBase
|
||||
from mealie.schema.response.pagination import PaginationQuery
|
||||
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
|
||||
|
||||
from ..db.models._model_base import SqlAlchemyBase
|
||||
from ..schema._mealie.mealie_model import extract_uuids
|
||||
@@ -51,7 +52,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
if order_by:
|
||||
order_attr = getattr(self.model, str(order_by))
|
||||
stmt = (
|
||||
select(self.model)
|
||||
sa.select(self.model)
|
||||
.join(RecipeSettings)
|
||||
.filter(RecipeSettings.public == True) # noqa: E712
|
||||
.order_by(order_attr.desc())
|
||||
@@ -61,7 +62,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
return [eff_schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
stmt = (
|
||||
select(self.model)
|
||||
sa.select(self.model)
|
||||
.join(RecipeSettings)
|
||||
.filter(RecipeSettings.public == True) # noqa: E712
|
||||
.offset(start)
|
||||
@@ -121,7 +122,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
order_attr = order_attr.asc()
|
||||
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
sa.select(RecipeModel)
|
||||
.options(*args)
|
||||
.filter(RecipeModel.group_id == group_id)
|
||||
.order_by(order_attr)
|
||||
@@ -145,9 +146,54 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
ids.append(i_as_uuid)
|
||||
except ValueError:
|
||||
slugs.append(i)
|
||||
additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all()
|
||||
additional_ids = self.session.execute(sa.select(model.id).filter(model.slug.in_(slugs))).scalars().all()
|
||||
return ids + additional_ids
|
||||
|
||||
def add_order_attr_to_query(
|
||||
self,
|
||||
query: sa.Select,
|
||||
order_attr: InstrumentedAttribute,
|
||||
order_dir: OrderDirection,
|
||||
order_by_null: OrderByNullPosition | None,
|
||||
) -> sa.Select:
|
||||
"""Special handling for ordering recipes by rating"""
|
||||
column_name = order_attr.key
|
||||
if column_name != "rating" or not self.user_id:
|
||||
return super().add_order_attr_to_query(query, order_attr, order_dir, order_by_null)
|
||||
|
||||
# calculate the effictive rating for the user by using the user's rating if it exists,
|
||||
# falling back to the recipe's rating if it doesn't
|
||||
effective_rating_column_name = "_effective_rating"
|
||||
query = query.add_columns(
|
||||
sa.case(
|
||||
(
|
||||
sa.exists().where(
|
||||
UserToRecipe.recipe_id == self.model.id,
|
||||
UserToRecipe.user_id == self.user_id,
|
||||
UserToRecipe.rating is not None,
|
||||
UserToRecipe.rating > 0,
|
||||
),
|
||||
sa.select(UserToRecipe.rating)
|
||||
.where(UserToRecipe.recipe_id == self.model.id, UserToRecipe.user_id == self.user_id)
|
||||
.scalar_subquery(),
|
||||
),
|
||||
else_=self.model.rating,
|
||||
).label(effective_rating_column_name)
|
||||
)
|
||||
|
||||
order_attr = effective_rating_column_name
|
||||
if order_dir is OrderDirection.asc:
|
||||
order_attr = sa.asc(order_attr)
|
||||
elif order_dir is OrderDirection.desc:
|
||||
order_attr = sa.desc(order_attr)
|
||||
|
||||
if order_by_null is OrderByNullPosition.first:
|
||||
order_attr = sa.nulls_first(order_attr)
|
||||
elif order_by_null is OrderByNullPosition.last:
|
||||
order_attr = sa.nulls_last(order_attr)
|
||||
|
||||
return query.order_by(order_attr)
|
||||
|
||||
def page_all( # type: ignore
|
||||
self,
|
||||
pagination: PaginationQuery,
|
||||
@@ -165,7 +211,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
) -> RecipePagination:
|
||||
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
|
||||
pagination_result = pagination.model_copy()
|
||||
q = select(self.model)
|
||||
q = sa.select(self.model)
|
||||
|
||||
args = [
|
||||
joinedload(RecipeModel.recipe_category),
|
||||
@@ -236,7 +282,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
|
||||
ids = [x.id for x in categories]
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
sa.select(RecipeModel)
|
||||
.join(RecipeModel.recipe_category)
|
||||
.filter(RecipeModel.recipe_category.any(Category.id.in_(ids)))
|
||||
)
|
||||
@@ -301,7 +347,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
require_all_tags=require_all_tags,
|
||||
require_all_tools=require_all_tools,
|
||||
)
|
||||
stmt = select(RecipeModel).filter(*fltr)
|
||||
stmt = sa.select(RecipeModel).filter(*fltr)
|
||||
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_random_by_categories_and_tags(
|
||||
@@ -318,26 +364,29 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
|
||||
|
||||
filters = self._build_recipe_filter(extract_uuids(categories), extract_uuids(tags)) # type: ignore
|
||||
stmt = (
|
||||
select(RecipeModel).filter(and_(*filters)).order_by(func.random()).limit(1) # Postgres and SQLite specific
|
||||
sa.select(RecipeModel)
|
||||
.filter(sa.and_(*filters))
|
||||
.order_by(sa.func.random())
|
||||
.limit(1) # Postgres and SQLite specific
|
||||
)
|
||||
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_random(self, limit=1) -> list[Recipe]:
|
||||
stmt = (
|
||||
select(RecipeModel)
|
||||
sa.select(RecipeModel)
|
||||
.filter(RecipeModel.group_id == self.group_id)
|
||||
.order_by(func.random()) # Postgres and SQLite specific
|
||||
.order_by(sa.func.random()) # Postgres and SQLite specific
|
||||
.limit(limit)
|
||||
)
|
||||
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
|
||||
|
||||
def get_by_slug(self, group_id: UUID4, slug: str, limit=1) -> Recipe | None:
|
||||
stmt = select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
|
||||
def get_by_slug(self, group_id: UUID4, slug: str) -> Recipe | None:
|
||||
stmt = sa.select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
|
||||
dbrecipe = self.session.execute(stmt).scalars().one_or_none()
|
||||
if dbrecipe is None:
|
||||
return None
|
||||
return self.schema.model_validate(dbrecipe)
|
||||
|
||||
def all_ids(self, group_id: UUID4) -> Sequence[UUID4]:
|
||||
stmt = select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
|
||||
stmt = sa.select(RecipeModel.id).filter(RecipeModel.group_id == group_id)
|
||||
return self.session.execute(stmt).scalars().all()
|
||||
|
||||
@@ -6,7 +6,8 @@ from sqlalchemy import select
|
||||
|
||||
from mealie.assets import users as users_assets
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.schema.user.user import PrivateUser
|
||||
from mealie.db.models.users.user_to_recipe import UserToRecipe
|
||||
from mealie.schema.user.user import PrivateUser, UserRatingOut
|
||||
|
||||
from ..db.models.users import User
|
||||
from .repository_generic import RepositoryGeneric
|
||||
@@ -72,3 +73,26 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
|
||||
stmt = select(User).filter(User.locked_at != None) # noqa E711
|
||||
results = self.session.execute(stmt).scalars().all()
|
||||
return [self.schema.model_validate(x) for x in results]
|
||||
|
||||
|
||||
class RepositoryUserRatings(RepositoryGeneric[UserRatingOut, UserToRecipe]):
|
||||
def get_by_user(self, user_id: UUID4, favorites_only=False) -> list[UserRatingOut]:
|
||||
stmt = select(UserToRecipe).filter(UserToRecipe.user_id == user_id)
|
||||
if favorites_only:
|
||||
stmt = stmt.filter(UserToRecipe.is_favorite)
|
||||
|
||||
results = self.session.execute(stmt).scalars().all()
|
||||
return [self.schema.model_validate(x) for x in results]
|
||||
|
||||
def get_by_recipe(self, recipe_id: UUID4, favorites_only=False) -> list[UserRatingOut]:
|
||||
stmt = select(UserToRecipe).filter(UserToRecipe.recipe_id == recipe_id)
|
||||
if favorites_only:
|
||||
stmt = stmt.filter(UserToRecipe.is_favorite)
|
||||
|
||||
results = self.session.execute(stmt).scalars().all()
|
||||
return [self.schema.model_validate(x) for x in results]
|
||||
|
||||
def get_by_user_and_recipe(self, user_id: UUID4, recipe_id: UUID4) -> UserRatingOut | None:
|
||||
stmt = select(UserToRecipe).filter(UserToRecipe.user_id == user_id, UserToRecipe.recipe_id == recipe_id)
|
||||
result = self.session.execute(stmt).scalars().one_or_none()
|
||||
return None if result is None else self.schema.model_validate(result)
|
||||
|
||||
Reference in New Issue
Block a user