diff --git a/mealie/repos/repository_generic.py b/mealie/repos/repository_generic.py index e0c9e5d8f..76a5b10a4 100644 --- a/mealie/repos/repository_generic.py +++ b/mealie/repos/repository_generic.py @@ -364,14 +364,16 @@ class RepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase]: self.logger.error(e) raise HTTPException(status_code=400, detail=str(e)) from e - count_query = select(func.count()).select_from(query.subquery()) + count_query = select(func.count()).select_from(query.order_by(None).distinct().subquery()) count = self.session.scalar(count_query) if not count: count = 0 # interpret -1 as "get_all" + limit: int | None = pagination.per_page if pagination.per_page == -1: pagination.per_page = count + limit = None try: total_pages = ceil(count / pagination.per_page) @@ -387,7 +389,11 @@ class RepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase]: pagination.page = 1 query = self.add_order_by_to_query(query, pagination) - return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages + + if limit is not None: + query = query.limit(limit) + + return query.offset((pagination.page - 1) * pagination.per_page), count, total_pages def add_order_attr_to_query( self, diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py index 64c987a10..ecee158a8 100644 --- a/tests/unit_tests/repository_tests/test_pagination.py +++ b/tests/unit_tests/repository_tests/test_pagination.py @@ -131,6 +131,29 @@ def test_pagination_response_and_metadata(unique_user: TestUser): assert last_page_of_results.items[-1] == all_results.items[-1] +def test_pagination_total_calculation(unique_user: TestUser): + db = unique_user.repos + unique_category_1, unused_category = ( + db.categories.create(CategorySave(group_id=unique_user.group_id, name=random_string(10))) for _ in range(2) + ) + recipe_1, recipe_2 = ( + db.recipes.create(Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())) + for _ in range(2) + ) + recipe_1.recipe_category = [unique_category_1] + + db.recipes.update(recipe_1.slug, recipe_1) + db.recipes.update(recipe_2.slug, recipe_2) + + query = PaginationQuery(page=1, per_page=64, query_filter=f"recipeCategory.name NOT IN [{unique_category_1.name}]") + result = db.recipes.page_all(query) + assert result.total == 1 + + query = PaginationQuery(page=1, per_page=64, query_filter=f"recipeCategory.name NOT IN [{unused_category.name}]") + result = db.recipes.page_all(query) + assert result.total == 2 + + def test_pagination_guides(unique_user: TestUser): database = unique_user.repos group = database.groups.get_one(unique_user.group_id) @@ -239,19 +262,19 @@ def test_pagination_filter_string_case_insensitive( units_repo.delete(upper_unit.id) -def test_pagination_filter_null(unique_user: TestUser): - database = unique_user.repos +def test_pagination_filter_null(unique_user_fn_scoped: TestUser): + database = unique_user_fn_scoped.repos recipe_not_made_1 = database.recipes.create( Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, + user_id=unique_user_fn_scoped.user_id, + group_id=unique_user_fn_scoped.group_id, name=random_string(), ) ) recipe_not_made_2 = database.recipes.create( Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, + user_id=unique_user_fn_scoped.user_id, + group_id=unique_user_fn_scoped.group_id, name=random_string(), ) ) @@ -259,8 +282,8 @@ def test_pagination_filter_null(unique_user: TestUser): # give one recipe a last made date recipe_made = database.recipes.create( Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, + user_id=unique_user_fn_scoped.user_id, + group_id=unique_user_fn_scoped.group_id, name=random_string(), last_made=datetime.now(UTC), ) @@ -370,6 +393,7 @@ def test_pagination_filter_not_in_m2m(unique_user: TestUser): db.recipes.update(recipe_2.slug, recipe_2) query = PaginationQuery(page=1, per_page=-1, query_filter=f"recipeCategory.name NOT IN [{unique_category_1.name}]") + recipe_results = db.recipes.page_all(query).items recipe_results_ids = {recipe.id for recipe in recipe_results} assert recipe_1.id not in recipe_results_ids