mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-11-01 02:33:22 -04:00
wip: pagination-repository (#1316)
* bump mypy * add pagination + refactor generic repo * add pagination test * remove all query object
This commit is contained in:
@@ -1,90 +1,75 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Generic, TypeVar, Union
|
||||
|
||||
from pydantic import UUID4, BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import load_only
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
D = TypeVar("D")
|
||||
from mealie.core.root_logger import get_logger
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
|
||||
|
||||
Schema = TypeVar("Schema", bound=BaseModel)
|
||||
Model = TypeVar("Model")
|
||||
|
||||
|
||||
class RepositoryGeneric(Generic[T, D]):
|
||||
class RepositoryGeneric(Generic[Schema, Model]):
|
||||
"""A Generic BaseAccess Model method to perform common operations on the database
|
||||
|
||||
Args:
|
||||
Generic ([T]): Represents the Pydantic Model
|
||||
Generic ([D]): Represents the SqlAlchemyModel Model
|
||||
Generic ([Schema]): Represents the Pydantic Model
|
||||
Generic ([Model]): Represents the SqlAlchemyModel Model
|
||||
"""
|
||||
|
||||
def __init__(self, session: Session, primary_key: str, sql_model: type[D], schema: type[T]) -> None:
|
||||
user_id: UUID4 = None
|
||||
group_id: UUID4 = None
|
||||
|
||||
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
|
||||
self.session = session
|
||||
self.primary_key = primary_key
|
||||
self.sql_model = sql_model
|
||||
self.model = sql_model
|
||||
self.schema = schema
|
||||
self.observers: list = []
|
||||
|
||||
self.limit_by_group = False
|
||||
self.user_id: UUID4 = None
|
||||
self.logger = get_logger()
|
||||
|
||||
self.limit_by_user = False
|
||||
self.group_id: UUID4 = None
|
||||
|
||||
def subscribe(self, func: Callable) -> None:
|
||||
self.observers.append(func)
|
||||
|
||||
def by_user(self, user_id: UUID4) -> "RepositoryGeneric[T, D]":
|
||||
self.limit_by_user = True
|
||||
def by_user(self, user_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
|
||||
self.user_id = user_id
|
||||
return self
|
||||
|
||||
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[T, D]":
|
||||
self.limit_by_group = True
|
||||
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[Schema, Model]":
|
||||
self.group_id = group_id
|
||||
return self
|
||||
|
||||
def _log_exception(self, e: Exception) -> None:
|
||||
self.logger.error(f"Error processing query for Repo model={self.model.__name__} schema={self.schema.__name__}")
|
||||
self.logger.error(e)
|
||||
|
||||
def _query(self):
|
||||
return self.session.query(self.model)
|
||||
|
||||
def _filter_builder(self, **kwargs) -> dict[str, Any]:
|
||||
dct = {}
|
||||
|
||||
if self.limit_by_user:
|
||||
if self.user_id:
|
||||
dct["user_id"] = self.user_id
|
||||
|
||||
if self.limit_by_group:
|
||||
if self.group_id:
|
||||
dct["group_id"] = self.group_id
|
||||
|
||||
return {**dct, **kwargs}
|
||||
|
||||
# TODO: Run Observer in Async Background Task
|
||||
def update_observers(self) -> None:
|
||||
if self.observers:
|
||||
for observer in self.observers:
|
||||
observer()
|
||||
def get_all(self, limit: int = None, order_by: str = None, start=0, override=None) -> list[Schema]:
|
||||
# sourcery skip: remove-unnecessary-cast
|
||||
eff_schema = override or self.schema
|
||||
|
||||
def get_all(self, limit: int = None, order_by: str = None, start=0, override_schema=None) -> list[T]:
|
||||
eff_schema = override_schema or self.schema
|
||||
fltr = self._filter_builder()
|
||||
|
||||
filter = self._filter_builder()
|
||||
q = self._query().filter_by(**fltr)
|
||||
|
||||
order_attr = None
|
||||
if order_by:
|
||||
order_attr = getattr(self.sql_model, str(order_by))
|
||||
order_attr = order_attr.desc()
|
||||
if order_attr := getattr(self.model, str(order_by)):
|
||||
order_attr = order_attr.desc()
|
||||
q = q.order_by(order_attr)
|
||||
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.sql_model)
|
||||
.order_by(order_attr)
|
||||
.filter_by(**filter)
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.all()
|
||||
]
|
||||
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.sql_model).filter_by(**filter).offset(start).limit(limit).all()
|
||||
]
|
||||
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
|
||||
|
||||
def multi_query(
|
||||
self,
|
||||
@@ -93,55 +78,21 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
limit: int = None,
|
||||
override_schema=None,
|
||||
order_by: str = None,
|
||||
) -> list[T]:
|
||||
) -> list[Schema]:
|
||||
# sourcery skip: remove-unnecessary-cast
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
filer = self._filter_builder(**query_by)
|
||||
fltr = self._filter_builder(**query_by)
|
||||
q = self._query().filter_by(**fltr)
|
||||
|
||||
order_attr = None
|
||||
if order_by:
|
||||
order_attr = getattr(self.sql_model, str(order_by))
|
||||
order_attr = order_attr.desc()
|
||||
if order_attr := getattr(self.model, str(order_by)):
|
||||
order_attr = order_attr.desc()
|
||||
q = q.order_by(order_attr)
|
||||
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.sql_model)
|
||||
.filter_by(**filer)
|
||||
.order_by(order_attr)
|
||||
.offset(start)
|
||||
.limit(limit)
|
||||
.all()
|
||||
]
|
||||
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
|
||||
|
||||
def get_all_limit_columns(self, fields: list[str], limit: int = None) -> list[D]:
|
||||
"""Queries the database for the selected model. Restricts return responses to the
|
||||
keys specified under "fields"
|
||||
|
||||
Args:
|
||||
session (Session): Database Session Object
|
||||
fields (list[str]): list of column names to query
|
||||
limit (int): A limit of values to return
|
||||
|
||||
Returns:
|
||||
list[SqlAlchemyBase]: Returns a list of ORM objects
|
||||
"""
|
||||
return self.session.query(self.sql_model).options(load_only(*fields)).limit(limit).all()
|
||||
|
||||
def get_all_primary_keys(self) -> list[str]:
|
||||
"""Queries the database of the selected model and returns a list
|
||||
of all primary_key values
|
||||
|
||||
Args:
|
||||
session (Session): Database Session object
|
||||
|
||||
Returns:
|
||||
list[str]:
|
||||
"""
|
||||
results = self.session.query(self.sql_model).options(load_only(str(self.primary_key)))
|
||||
results_as_dict = [x.dict() for x in results]
|
||||
return [x.get(self.primary_key) for x in results_as_dict]
|
||||
|
||||
def _query_one(self, match_value: str | int | UUID4, match_key: str = None) -> D:
|
||||
def _query_one(self, match_value: str | int | UUID4, match_key: str = None) -> Model:
|
||||
"""
|
||||
Query the sql database for one item an return the sql alchemy model
|
||||
object. If no match key is provided the primary_key attribute will be used.
|
||||
@@ -150,18 +101,18 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
match_key = self.primary_key
|
||||
|
||||
fltr = self._filter_builder(**{match_key: match_value})
|
||||
return self.session.query(self.sql_model).filter_by(**fltr).one()
|
||||
return self._query().filter_by(**fltr).one()
|
||||
|
||||
def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> T | None:
|
||||
def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> Schema | None:
|
||||
key = key or self.primary_key
|
||||
|
||||
q = self.session.query(self.sql_model)
|
||||
q = self.session.query(self.model)
|
||||
|
||||
if any_case:
|
||||
search_attr = getattr(self.sql_model, key)
|
||||
q = q.filter(func.lower(search_attr) == str(value).lower()).filter_by(**self._filter_builder())
|
||||
search_attr = getattr(self.model, key)
|
||||
q = q.where(func.lower(search_attr) == str(value).lower()).filter_by(**self._filter_builder())
|
||||
else:
|
||||
q = self.session.query(self.sql_model).filter_by(**self._filter_builder(**{key: value}))
|
||||
q = q.filter_by(**self._filter_builder(**{key: value}))
|
||||
|
||||
result = q.one_or_none()
|
||||
|
||||
@@ -173,32 +124,20 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
|
||||
def get(
|
||||
self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None
|
||||
) -> T | list[T] | None:
|
||||
"""Retrieves an entry from the database by matching a key/value pair. If no
|
||||
key is provided the class objects primary key will be used to match against.
|
||||
|
||||
|
||||
Args:
|
||||
match_value (str): A value used to match against the key/value in the database
|
||||
match_key (str, optional): They key to match the value against. Defaults to None.
|
||||
limit (int, optional): A limit to returned responses. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
dict or list[dict]:
|
||||
|
||||
"""
|
||||
) -> Schema | list[Schema] | None:
|
||||
self.logger.info("DEPRECATED: use get_one or get_all instead")
|
||||
match_key = match_key or self.primary_key
|
||||
|
||||
if any_case:
|
||||
search_attr = getattr(self.sql_model, match_key)
|
||||
search_attr = getattr(self.model, match_key)
|
||||
result = (
|
||||
self.session.query(self.sql_model)
|
||||
self.session.query(self.model)
|
||||
.filter(func.lower(search_attr) == match_value.lower()) # type: ignore
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
result = self.session.query(self.sql_model).filter_by(**{match_key: match_value}).limit(limit).all()
|
||||
result = self.session.query(self.model).filter_by(**{match_key: match_value}).limit(limit).all()
|
||||
|
||||
eff_schema = override_schema or self.schema
|
||||
|
||||
@@ -210,28 +149,29 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
|
||||
return [eff_schema.from_orm(x) for x in result]
|
||||
|
||||
def create(self, document: T | BaseModel | dict) -> T:
|
||||
"""Creates a new database entry for the given SQL Alchemy Model.
|
||||
|
||||
Args:
|
||||
session (Session): A Database Session
|
||||
document (dict): A python dictionary representing the data structure
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representation of the database entry
|
||||
"""
|
||||
document = document if isinstance(document, dict) else document.dict()
|
||||
new_document = self.sql_model(session=self.session, **document) # type: ignore
|
||||
def create(self, data: Schema | BaseModel | dict) -> Schema:
|
||||
data = data if isinstance(data, dict) else data.dict()
|
||||
new_document = self.model(session=self.session, **data) # type: ignore
|
||||
self.session.add(new_document)
|
||||
self.session.commit()
|
||||
self.session.refresh(new_document)
|
||||
|
||||
if self.observers:
|
||||
self.update_observers()
|
||||
|
||||
return self.schema.from_orm(new_document)
|
||||
|
||||
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T:
|
||||
def create_many(self, data: list[Schema | dict]) -> list[Schema]:
|
||||
new_documents = []
|
||||
for document in data:
|
||||
document = document if isinstance(document, dict) else document.dict()
|
||||
new_document = self.model(session=self.session, **document) # type: ignore
|
||||
new_documents.append(new_document)
|
||||
|
||||
self.session.add_all(new_documents)
|
||||
self.session.commit()
|
||||
self.session.refresh(new_documents)
|
||||
|
||||
return [self.schema.from_orm(x) for x in new_documents]
|
||||
|
||||
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
|
||||
"""Update a database entry.
|
||||
Args:
|
||||
session (Session): Database Session
|
||||
@@ -246,30 +186,23 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
entry = self._query_one(match_value=match_value)
|
||||
entry.update(session=self.session, **new_data) # type: ignore
|
||||
|
||||
if self.observers:
|
||||
self.update_observers()
|
||||
|
||||
self.session.commit()
|
||||
return self.schema.from_orm(entry)
|
||||
|
||||
def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T | None:
|
||||
def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
|
||||
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
|
||||
|
||||
entry = self._query_one(match_value=match_value)
|
||||
|
||||
if not entry:
|
||||
# TODO: Should raise exception
|
||||
return None
|
||||
|
||||
entry_as_dict = self.schema.from_orm(entry).dict()
|
||||
entry_as_dict.update(new_data)
|
||||
|
||||
return self.update(match_value, entry_as_dict)
|
||||
|
||||
def delete(self, value, match_key: str | None = None) -> T:
|
||||
def delete(self, value, match_key: str | None = None) -> Schema:
|
||||
match_key = match_key or self.primary_key
|
||||
|
||||
result = self.session.query(self.sql_model).filter_by(**{match_key: value}).one()
|
||||
result = self._query().filter_by(**{match_key: value}).one()
|
||||
results_as_model = self.schema.from_orm(result)
|
||||
|
||||
try:
|
||||
@@ -279,23 +212,17 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
self.session.rollback()
|
||||
raise e
|
||||
|
||||
if self.observers:
|
||||
self.update_observers()
|
||||
|
||||
return results_as_model
|
||||
|
||||
def delete_all(self) -> None:
|
||||
self.session.query(self.sql_model).delete()
|
||||
self._query().delete()
|
||||
self.session.commit()
|
||||
|
||||
if self.observers:
|
||||
self.update_observers()
|
||||
|
||||
def count_all(self, match_key=None, match_value=None) -> int:
|
||||
if None in [match_key, match_value]:
|
||||
return self.session.query(self.sql_model).count()
|
||||
return self._query().count()
|
||||
else:
|
||||
return self.session.query(self.sql_model).filter_by(**{match_key: match_value}).count()
|
||||
return self._query().filter_by(**{match_key: match_value}).count()
|
||||
|
||||
def _count_attribute(
|
||||
self,
|
||||
@@ -303,27 +230,57 @@ class RepositoryGeneric(Generic[T, D]):
|
||||
attr_match: str = None,
|
||||
count=True,
|
||||
override_schema=None,
|
||||
) -> Union[int, list[T]]:
|
||||
) -> Union[int, list[Schema]]: # sourcery skip: assign-if-exp
|
||||
eff_schema = override_schema or self.schema
|
||||
# attr_filter = getattr(self.sql_model, attribute_name)
|
||||
|
||||
q = self._query().filter(attribute_name == attr_match)
|
||||
|
||||
if count:
|
||||
return self.session.query(self.sql_model).filter(attribute_name == attr_match).count() # noqa: 711
|
||||
return q.count()
|
||||
else:
|
||||
return [
|
||||
eff_schema.from_orm(x)
|
||||
for x in self.session.query(self.sql_model).filter(attribute_name == attr_match).all() # noqa: 711
|
||||
]
|
||||
return [eff_schema.from_orm(x) for x in q.all()]
|
||||
|
||||
def create_many(self, documents: list[T | dict]) -> list[T]:
|
||||
new_documents = []
|
||||
for document in documents:
|
||||
document = document if isinstance(document, dict) else document.dict()
|
||||
new_document = self.sql_model(session=self.session, **document) # type: ignore
|
||||
new_documents.append(new_document)
|
||||
def pagination(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
|
||||
"""
|
||||
pagination is a method to interact with the filtered database table and return a paginated result
|
||||
using the PaginationBase that provides several data points that are needed to manage pagination
|
||||
on the client side. This method does utilize the _filter_build method to ensure that the results
|
||||
are filtered by the user and group id when applicable.
|
||||
|
||||
self.session.add_all(new_documents)
|
||||
self.session.commit()
|
||||
self.session.refresh(new_documents)
|
||||
NOTE: When you provide an override you'll need to manually type the result of this method
|
||||
as the override, as the type system, is not able to infer the result of this method.
|
||||
"""
|
||||
eff_schema = override or self.schema
|
||||
|
||||
return [self.schema.from_orm(x) for x in new_documents]
|
||||
q = self.session.query(self.model)
|
||||
|
||||
fltr = self._filter_builder()
|
||||
q = q.filter_by(**fltr)
|
||||
|
||||
count = q.count()
|
||||
|
||||
if pagination.order_by:
|
||||
if order_attr := getattr(self.model, pagination.order_by, None):
|
||||
if pagination.order_direction == OrderDirection.asc:
|
||||
order_attr = order_attr.asc()
|
||||
elif pagination.order_direction == OrderDirection.desc:
|
||||
order_attr = order_attr.desc()
|
||||
|
||||
q = q.order_by(order_attr)
|
||||
|
||||
q = q.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page)
|
||||
|
||||
try:
|
||||
data = q.all()
|
||||
except Exception as e:
|
||||
self._log_exception(e)
|
||||
self.session.rollback()
|
||||
raise e
|
||||
|
||||
return PaginationBase(
|
||||
page=pagination.page,
|
||||
per_page=pagination.per_page,
|
||||
total=count,
|
||||
total_pages=int(count / pagination.per_page) + 1,
|
||||
data=[eff_schema.from_orm(s) for s in data],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user