Remove all sqlalchemy lazy-loading from app (#2260)

* Remove some implicit lazy-loads from user serialization

* implement full backup restore across different database versions

* rework all custom getter dicts to not leak lazy loads

* remove some occurances of lazy-loading

* remove a lot of lazy loading from recipes

* add more eager loading
remove loading options from repository
remove raiseload for checking

* fix failing test

* do not apply loader options for paging counts

* try using selectinload a bit more instead of joinedload

* linter fixes
This commit is contained in:
Sören
2023-03-24 17:27:26 +01:00
committed by GitHub
parent fae62ecb19
commit 4b426ddf2f
23 changed files with 351 additions and 142 deletions

View File

@@ -7,16 +7,16 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import Select, delete, func, select
from sqlalchemy.orm.interfaces import LoaderOption
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
Schema = TypeVar("Schema", bound=BaseModel)
Schema = TypeVar("Schema", bound=MealieModel)
Model = TypeVar("Model", bound=SqlAlchemyBase)
T = TypeVar("T", bound="RepositoryGeneric")
@@ -54,8 +54,13 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger.error(f"Error processing query for Repo model={self.model.__name__} schema={self.schema.__name__}")
self.logger.error(e)
def _query(self):
return select(self.model)
def _query(self, override_schema: type[MealieModel] | None = None, with_options=True):
q = select(self.model)
if with_options:
schema = override_schema or self.schema
return q.options(*schema.loader_options())
else:
return q
def _filter_builder(self, **kwargs) -> dict[str, Any]:
dct = {}
@@ -83,7 +88,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
fltr = self._filter_builder()
q = self._query().filter_by(**fltr)
q = self._query(override_schema=eff_schema).filter_by(**fltr)
if order_by:
try:
@@ -98,7 +103,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
except AttributeError:
self.logger.info(f'Attempted to sort by unknown sort property "{order_by}"; ignoring')
result = self.session.execute(q.offset(start).limit(limit)).scalars().all()
result = self.session.execute(q.offset(start).limit(limit)).unique().scalars().all()
return [eff_schema.from_orm(x) for x in result]
def multi_query(
@@ -113,7 +118,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
eff_schema = override_schema or self.schema
fltr = self._filter_builder(**query_by)
q = self._query().filter_by(**fltr)
q = self._query(override_schema=eff_schema).filter_by(**fltr)
if order_by:
if order_attr := getattr(self.model, str(order_by)):
@@ -121,7 +126,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
q = q.order_by(order_attr)
q = q.offset(start).limit(limit)
result = self.session.execute(q).scalars().all()
result = self.session.execute(q).unique().scalars().all()
return [eff_schema.from_orm(x) for x in result]
def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model:
@@ -133,14 +138,15 @@ class RepositoryGeneric(Generic[Schema, Model]):
match_key = self.primary_key
fltr = self._filter_builder(**{match_key: match_value})
return self.session.execute(self._query().filter_by(**fltr)).scalars().one()
return self.session.execute(self._query().filter_by(**fltr)).unique().scalars().one()
def get_one(
self, value: str | int | UUID4, key: str | None = None, any_case=False, override_schema=None
) -> Schema | None:
key = key or self.primary_key
eff_schema = override_schema or self.schema
q = self._query()
q = self._query(override_schema=eff_schema)
if any_case:
search_attr = getattr(self.model, key)
@@ -148,12 +154,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
else:
q = q.filter_by(**self._filter_builder(**{key: value}))
result = self.session.execute(q).scalars().one_or_none()
result = self.session.execute(q).unique().scalars().one_or_none()
if not result:
return None
eff_schema = override_schema or self.schema
return eff_schema.from_orm(result)
def create(self, data: Schema | BaseModel | dict) -> Schema:
@@ -205,7 +210,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
document_data_by_id[document_data["id"]] = document_data
documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
documents_to_update = self.session.execute(documents_to_update_query).scalars().all()
documents_to_update = self.session.execute(documents_to_update_query).unique().scalars().all()
updated_documents = []
for document_to_update in documents_to_update:
@@ -229,7 +234,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def delete(self, value, match_key: str | None = None) -> Schema:
match_key = match_key or self.primary_key
result = self.session.execute(self._query().filter_by(**{match_key: value})).scalars().one()
result = self._query_one(value, match_key)
results_as_model = self.schema.from_orm(result)
try:
@@ -243,7 +248,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def delete_many(self, values: Iterable) -> Schema:
query = self._query().filter(self.model.id.in_(values)) # type: ignore
results = self.session.execute(query).scalars().all()
results = self.session.execute(query).unique().scalars().all()
results_as_model = [self.schema.from_orm(result) for result in results]
try:
@@ -282,13 +287,9 @@ class RepositoryGeneric(Generic[Schema, Model]):
q = select(func.count(self.model.id)).filter(attribute_name == attr_match)
return self.session.scalar(q)
else:
q = self._query().filter(attribute_name == attr_match)
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
def paging_query_options(self) -> list[LoaderOption]:
# Override this in subclasses to specify joinedloads or similar for page_all
return []
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
"""
pagination is a method to interact with the filtered database table and return a paginated result
@@ -301,12 +302,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
"""
eff_schema = override or self.schema
q = self._query().options(*self.paging_query_options())
q = self._query(override_schema=eff_schema, with_options=False)
fltr = self._filter_builder()
q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
# Apply options late, so they do not get used for counting
q = q.options(*eff_schema.loader_options())
try:
data = self.session.execute(q).unique().scalars().all()
except Exception as e: