feat: Advanced Query Filter Record Ordering (#2530)

* added support for multiple order_by strs

* refactored qf to expose nested attr logic

* added nested attr support to order_by

* added tests

* changed unique user to be function-level

* updated docs

* added support for null handling

* updated docs

* undid fixture changes

* fix leaky tests

* added advanced shopping list item test

---------

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson
2023-09-14 09:09:05 -05:00
committed by GitHub
parent 2c5e5a8421
commit aec4cb4f31
6 changed files with 483 additions and 66 deletions

View File

@@ -7,14 +7,14 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import Select, case, delete, func, select
from sqlalchemy import Select, case, delete, func, nulls_first, nulls_last, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
from mealie.schema.response.query_search import SearchFilter
@@ -372,32 +372,65 @@ class RepositoryGeneric(Generic[Schema, Model]):
pagination.page = 1
if pagination.order_by:
if order_attr := getattr(self.model, pagination.order_by, None):
# queries handle uppercase and lowercase differently, which is undesirable
if isinstance(order_attr.type, sqltypes.String):
order_attr = func.lower(order_attr)
if pagination.order_direction == OrderDirection.asc:
order_attr = order_attr.asc()
elif pagination.order_direction == OrderDirection.desc:
order_attr = order_attr.desc()
query = query.order_by(order_attr)
elif pagination.order_by == "random":
# randomize outside of database, since not all db's can set random seeds
# this solution is db-independent & stable to paging
temp_query = query.with_only_columns(self.model.id)
allids = self.session.execute(temp_query).scalars().all() # fast because id is indexed
order = list(range(len(allids)))
random.seed(pagination.pagination_seed)
random.shuffle(order)
random_dict = dict(zip(allids, order, strict=True))
case_stmt = case(random_dict, value=self.model.id)
query = query.order_by(case_stmt)
query = self.add_order_by_to_query(query, pagination)
return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages
def add_order_by_to_query(self, query: Select, pagination: PaginationQuery) -> Select:
    """
    Apply the ordering described by `pagination` to `query` and return the new query.

    `pagination.order_by` may be:
      - empty, in which case the query is returned untouched
      - the literal "random", which shuffles the matching ids using `pagination.pagination_seed`
      - a comma-separated list of attribute strings, each optionally followed by
        ":<direction>" (e.g. "name:asc, created_at:desc"); entries without an explicit
        direction use `pagination.order_direction`

    Attribute strings are resolved through `QueryFilter.get_model_and_model_attr_from_attr_string`,
    which is handed the query so it can apply any joins it needs.

    Raises an `HTTPException` with status 400 if any order_by entry is invalid.
    """
    if not pagination.order_by:
        return query

    if pagination.order_by == "random":
        # randomize outside of database, since not all db's can set random seeds
        # this solution is db-independent & stable to paging
        temp_query = query.with_only_columns(self.model.id)
        allids = self.session.execute(temp_query).scalars().all()  # fast because id is indexed
        order = list(range(len(allids)))
        random.seed(pagination.pagination_seed)
        random.shuffle(order)
        random_dict = dict(zip(allids, order, strict=True))
        case_stmt = case(random_dict, value=self.model.id)
        return query.order_by(case_stmt)

    for order_by_val in pagination.order_by.split(","):
        try:
            order_by_val = order_by_val.strip()
            if ":" in order_by_val:
                # more than one ":" fails to unpack with a ValueError, which is reported as a 400 below
                order_by, order_dir_val = order_by_val.split(":")
                order_dir = OrderDirection(order_dir_val)
            else:
                order_by = order_by_val
                order_dir = pagination.order_direction

            _, order_attr, query = QueryFilter.get_model_and_model_attr_from_attr_string(
                order_by, self.model, query=query
            )

            # queries handle uppercase and lowercase differently, which is undesirable
            # this has to happen before the direction is applied: the type check and lower()
            # must see the plain column, not the already-ordered expression
            if isinstance(order_attr.type, sqltypes.String):
                order_attr = func.lower(order_attr)

            if order_dir is OrderDirection.asc:
                order_attr = order_attr.asc()
            elif order_dir is OrderDirection.desc:
                order_attr = order_attr.desc()

            # null placement wraps the finished ordering expression
            if pagination.order_by_null_position is OrderByNullPosition.first:
                order_attr = nulls_first(order_attr)
            elif pagination.order_by_null_position is OrderByNullPosition.last:
                order_attr = nulls_last(order_attr)

            query = query.order_by(order_attr)

        except ValueError as e:
            raise HTTPException(
                status_code=400,
                detail=f'Invalid order_by statement "{pagination.order_by}": "{order_by_val}" is invalid',
            ) from e

    return query
def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select:
    """
    Return `query` with the `search` text applied.

    Builds a `SearchFilter` from this repository's session, the search text and the schema's
    `_normalize_search` value, then delegates to its `filter_query_by_search`.
    """
    return SearchFilter(self.session, search, schema._normalize_search).filter_query_by_search(
        query, schema, self.model
    )

View File

@@ -16,6 +16,11 @@ class OrderDirection(str, enum.Enum):
desc = "desc"
class OrderByNullPosition(str, enum.Enum):
    """Where NULL values are placed when query results are ordered."""

    # NULL values are ordered before non-NULL values
    first = "first"
    # NULL values are ordered after non-NULL values
    last = "last"
class RecipeSearchQuery(MealieModel):
cookbook: UUID4 | str | None
require_all_categories: bool = False
@@ -30,6 +35,7 @@ class PaginationQuery(MealieModel):
page: int = 1
per_page: int = 50
order_by: str = "created_at"
order_by_null_position: OrderByNullPosition | None = None
order_direction: OrderDirection = OrderDirection.desc
query_filter: str | None = None
pagination_seed: str | None = None

View File

@@ -13,9 +13,10 @@ from sqlalchemy import ColumnElement, Select, and_, inspect, or_
from sqlalchemy.orm import InstrumentedAttribute, Mapper
from sqlalchemy.sql import sqltypes
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.db.models._model_utils.guid import GUID
Model = TypeVar("Model")
Model = TypeVar("Model", bound=SqlAlchemyBase)
class RelationalKeyword(Enum):
@@ -238,6 +239,53 @@ class QueryFilter:
if i == len(group) - 1:
return consolidated_group_builder.self_group()
@classmethod
def get_model_and_model_attr_from_attr_string(
    cls, attr_string: str, model: type[Model], *, query: Select | None = None
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, Select | None]:
    """
    Take an attribute string and traverse a database model and its relationships to get the desired
    model and model attribute. Optionally provide a query to apply the necessary table joins.

    If the attribute string is invalid, raises a `ValueError`.

    For instance, the attribute string "user.name" on `RecipeModel`
    will return the `User` model's `name` attribute.

    Works with shallow attributes (e.g. "slug" from `RecipeModel`)
    and arbitrarily deep ones (e.g. "recipe.group.preferences" on `RecipeTimelineEvent`).

    Returns a tuple of (model owning the final attribute, the final attribute, the query).
    The query is `None` when none was provided; otherwise it has one outer join applied
    for every relationship that was traversed.
    """
    # `str.split` never returns an empty list, so the emptiness check has to be on the string itself
    if not attr_string:
        raise ValueError("invalid query string: attribute name cannot be empty")

    attribute_chain = attr_string.split(".")
    last_index = len(attribute_chain) - 1

    model_attr: InstrumentedAttribute | None = None
    current_model: SqlAlchemyBase = model  # type: ignore
    for i, attribute_link in enumerate(attribute_chain):
        try:
            model_attr = getattr(current_model, attribute_link)

            # at the end of the chain there are no more relationships to inspect
            if i == last_index:
                break

            # confirm the link really is a relationship *before* joining on it, so that a
            # non-relationship link in the middle of the chain is reported as a ValueError
            mapper: Mapper = inspect(current_model)
            relationship = mapper.relationships[attribute_link]
        except (AttributeError, KeyError) as e:
            raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e

        if query is not None:
            # we use outer joins to not unintentionally filter out values
            query = query.join(model_attr, isouter=True)

        current_model = relationship.mapper.class_

    if model_attr is None:
        raise ValueError(f"invalid attribute string: '{attr_string}'")

    return current_model, model_attr, query
def filter_query(self, query: Select, model: type[Model]) -> Select:
# join tables and build model chain
attr_model_map: dict[int, Any] = {}
@@ -246,29 +294,10 @@ class QueryFilter:
if not isinstance(component, QueryFilterComponent):
continue
attribute_chain = component.attribute_name.split(".")
if not attribute_chain:
raise ValueError("invalid query string: attribute name cannot be empty")
current_model = model
for j, attribute_link in enumerate(attribute_chain):
try:
model_attr = getattr(current_model, attribute_link)
# at the end of the chain there are no more relationships to inspect
if j == len(attribute_chain) - 1:
break
query = query.join(model_attr)
mapper: Mapper = inspect(current_model)
relationship = mapper.relationships[attribute_link]
current_model = relationship.mapper.class_
except (AttributeError, KeyError) as e:
raise ValueError(
f"invalid query string: '{component.attribute_name}' does not exist on this schema"
) from e
attr_model_map[i] = current_model
nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string(
component.attribute_name, model, query=query
)
attr_model_map[i] = nested_model
# build query filter
partial_group: list[ColumnElement] = []