prs-fleshgolem-2070: feat: sqlalchemy 2.0 (#2096)

* upgrade sqlalchemy to 2.0

* rewrite all db models to sqla 2.0 mapping api

* fix some importing and typing weirdness

* fix types of a lot of nullable columns

* remove get_ref methods

* fix issues found by tests

* rewrite all queries in repository_recipe to 2.0 style

* rewrite all repository queries to 2.0 api

* rewrite all remaining queries to 2.0 api

* remove now-unneeded __allow_unmapped__ flag

* remove and fix some unneeded cases of "# type: ignore"

* fix formatting

* bump black version

* run black

* can this please be the last one. okay. just. okay.

* fix repository errors

* remove return

* drop open API validator

---------

Co-authored-by: Sören Busch <fleshgolem@gmx.net>
This commit is contained in:
Hayden
2023-02-06 18:43:12 -09:00
committed by GitHub
parent 91cd00976a
commit 9e77a9f367
86 changed files with 1776 additions and 1572 deletions

View File

@@ -6,21 +6,19 @@ from typing import Any, Generic, TypeVar
from fastapi import HTTPException
from pydantic import UUID4, BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Query
from sqlalchemy import Select, delete, func, select
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import sqltypes
from mealie.core.root_logger import get_logger
from mealie.schema.response.pagination import (
OrderDirection,
PaginationBase,
PaginationQuery,
)
from mealie.db.models._model_base import SqlAlchemyBase
from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery
from mealie.schema.response.query_filter import QueryFilter
Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model")
Model = TypeVar("Model", bound=SqlAlchemyBase)
T = TypeVar("T", bound="RepositoryGeneric")
class RepositoryGeneric(Generic[Schema, Model]):
@@ -33,6 +31,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
user_id: UUID4 | None = None
group_id: UUID4 | None = None
session: Session
def __init__(self, session: Session, primary_key: str, sql_model: type[Model], schema: type[Schema]) -> None:
self.session = session
@@ -42,11 +41,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger = get_logger()
def by_user(self, user_id: UUID4) -> RepositoryGeneric[Schema, Model]:
def by_user(self: T, user_id: UUID4) -> T:
self.user_id = user_id
return self
def by_group(self, group_id: UUID4) -> RepositoryGeneric[Schema, Model]:
def by_group(self: T, group_id: UUID4) -> T:
self.group_id = group_id
return self
@@ -55,7 +54,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger.error(e)
def _query(self):
return self.session.query(self.model)
return select(self.model)
def _filter_builder(self, **kwargs) -> dict[str, Any]:
dct = {}
@@ -98,8 +97,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
except AttributeError:
self.logger.info(f'Attempted to sort by unknown sort property "{order_by}"; ignoring')
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
result = self.session.execute(q.offset(start).limit(limit)).scalars().all()
return [eff_schema.from_orm(x) for x in result]
def multi_query(
self,
@@ -120,7 +119,9 @@ class RepositoryGeneric(Generic[Schema, Model]):
order_attr = order_attr.desc()
q = q.order_by(order_attr)
return [eff_schema.from_orm(x) for x in q.offset(start).limit(limit).all()]
q = q.offset(start).limit(limit)
result = self.session.execute(q).scalars().all()
return [eff_schema.from_orm(x) for x in result]
def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model:
"""
@@ -131,14 +132,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
match_key = self.primary_key
fltr = self._filter_builder(**{match_key: match_value})
return self._query().filter_by(**fltr).one()
return self.session.execute(self._query().filter_by(**fltr)).scalars().one()
def get_one(
self, value: str | int | UUID4, key: str | None = None, any_case=False, override_schema=None
) -> Schema | None:
key = key or self.primary_key
q = self.session.query(self.model)
q = self._query()
if any_case:
search_attr = getattr(self.model, key)
@@ -146,7 +147,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
else:
q = q.filter_by(**self._filter_builder(**{key: value}))
result = q.one_or_none()
result = self.session.execute(q).scalars().one_or_none()
if not result:
return None
@@ -156,7 +157,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def create(self, data: Schema | BaseModel | dict) -> Schema:
data = data if isinstance(data, dict) else data.dict()
new_document = self.model(session=self.session, **data) # type: ignore
new_document = self.model(session=self.session, **data)
self.session.add(new_document)
self.session.commit()
self.session.refresh(new_document)
@@ -167,7 +168,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
new_documents = []
for document in data:
document = document if isinstance(document, dict) else document.dict()
new_document = self.model(session=self.session, **document) # type: ignore
new_document = self.model(session=self.session, **document)
new_documents.append(new_document)
self.session.add_all(new_documents)
@@ -191,7 +192,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(match_value=match_value)
entry.update(session=self.session, **new_data) # type: ignore
entry.update(session=self.session, **new_data)
self.session.commit()
return self.schema.from_orm(entry)
@@ -202,7 +203,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
document_data = document if isinstance(document, dict) else document.dict()
document_data_by_id[document_data["id"]] = document_data
documents_to_update = self._query().filter(self.model.id.in_(list(document_data_by_id.keys()))) # type: ignore
documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
documents_to_update = self.session.execute(documents_to_update_query).scalars().all()
updated_documents = []
for document_to_update in documents_to_update:
@@ -226,7 +228,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def delete(self, value, match_key: str | None = None) -> Schema:
match_key = match_key or self.primary_key
result = self._query().filter_by(**{match_key: value}).one()
result = self.session.execute(self._query().filter_by(**{match_key: value})).scalars().one()
results_as_model = self.schema.from_orm(result)
try:
@@ -239,7 +241,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
return results_as_model
def delete_many(self, values: Iterable) -> Schema:
results = self._query().filter(self.model.id.in_(values)) # type: ignore
query = self._query().filter(self.model.id.in_(values)) # type: ignore
results = self.session.execute(query).scalars().all()
results_as_model = [self.schema.from_orm(result) for result in results]
try:
@@ -256,14 +259,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
return results_as_model # type: ignore
def delete_all(self) -> None:
self._query().delete()
delete(self.model)
self.session.commit()
def count_all(self, match_key=None, match_value=None) -> int:
if None in [match_key, match_value]:
return self._query().count()
else:
return self._query().filter_by(**{match_key: match_value}).count()
q = select(func.count(self.model.id))
if None not in [match_key, match_value]:
q = q.filter_by(**{match_key: match_value})
return self.session.scalar(q)
def _count_attribute(
self,
@@ -274,12 +277,12 @@ class RepositoryGeneric(Generic[Schema, Model]):
) -> int | list[Schema]: # sourcery skip: assign-if-exp
eff_schema = override_schema or self.schema
q = self._query().filter(attribute_name == attr_match)
if count:
return q.count()
q = select(func.count(self.model.id)).filter(attribute_name == attr_match)
return self.session.scalar(q)
else:
return [eff_schema.from_orm(x) for x in q.all()]
q = self._query().filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]:
"""
@@ -293,14 +296,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
"""
eff_schema = override or self.schema
q = self.session.query(self.model)
q = self._query()
fltr = self._filter_builder()
q = q.filter_by(**fltr)
q, count, total_pages = self.add_pagination_to_query(q, pagination)
try:
data = q.all()
data = self.session.execute(q).scalars().all()
except Exception as e:
self._log_exception(e)
self.session.rollback()
@@ -314,7 +317,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
items=[eff_schema.from_orm(s) for s in data],
)
def add_pagination_to_query(self, query: Query, pagination: PaginationQuery) -> tuple[Query, int, int]:
def add_pagination_to_query(self, query: Select, pagination: PaginationQuery) -> tuple[Select, int, int]:
"""
Adds pagination data to an existing query.
@@ -333,7 +336,8 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.logger.error(e)
raise HTTPException(status_code=400, detail=str(e)) from e
count = query.count()
count_query = select(func.count()).select_from(query)
count = self.session.scalar(count_query)
# interpret -1 as "get_all"
if pagination.per_page == -1: