feat: Upgrade to Pydantic V2 (#3134)

* bumped pydantic
This commit is contained in:
Michael Genson
2024-02-11 10:47:37 -06:00
committed by GitHub
parent 248459671e
commit 7a107584c7
129 changed files with 1138 additions and 833 deletions

View File

@@ -106,7 +106,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
except AttributeError:
self.logger.info(f'Attempted to sort by unknown sort property "{order_by}"; ignoring')
result = self.session.execute(q.offset(start).limit(limit)).unique().scalars().all()
return [eff_schema.from_orm(x) for x in result]
return [eff_schema.model_validate(x) for x in result]
def multi_query(
self,
@@ -129,7 +129,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
q = q.offset(start).limit(limit)
result = self.session.execute(q).unique().scalars().all()
return [eff_schema.from_orm(x) for x in result]
return [eff_schema.model_validate(x) for x in result]
def _query_one(self, match_value: str | int | UUID4, match_key: str | None = None) -> Model:
"""
@@ -161,11 +161,11 @@ class RepositoryGeneric(Generic[Schema, Model]):
if not result:
return None
return eff_schema.from_orm(result)
return eff_schema.model_validate(result)
def create(self, data: Schema | BaseModel | dict) -> Schema:
try:
data = data if isinstance(data, dict) else data.dict()
data = data if isinstance(data, dict) else data.model_dump()
new_document = self.model(session=self.session, **data)
self.session.add(new_document)
self.session.commit()
@@ -175,12 +175,12 @@ class RepositoryGeneric(Generic[Schema, Model]):
self.session.refresh(new_document)
return self.schema.from_orm(new_document)
return self.schema.model_validate(new_document)
def create_many(self, data: Iterable[Schema | dict]) -> list[Schema]:
new_documents = []
for document in data:
document = document if isinstance(document, dict) else document.dict()
document = document if isinstance(document, dict) else document.model_dump()
new_document = self.model(session=self.session, **document)
new_documents.append(new_document)
@@ -190,7 +190,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
for created_document in new_documents:
self.session.refresh(created_document)
return [self.schema.from_orm(x) for x in new_documents]
return [self.schema.model_validate(x) for x in new_documents]
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
"""Update a database entry.
@@ -202,18 +202,18 @@ class RepositoryGeneric(Generic[Schema, Model]):
Returns:
dict: Returns a dictionary representation of the database entry
"""
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
new_data = new_data if isinstance(new_data, dict) else new_data.model_dump()
entry = self._query_one(match_value=match_value)
entry.update(session=self.session, **new_data)
self.session.commit()
return self.schema.from_orm(entry)
return self.schema.model_validate(entry)
def update_many(self, data: Iterable[Schema | dict]) -> list[Schema]:
document_data_by_id: dict[str, dict] = {}
for document in data:
document_data = document if isinstance(document, dict) else document.dict()
document_data = document if isinstance(document, dict) else document.model_dump()
document_data_by_id[document_data["id"]] = document_data
documents_to_update_query = self._query().filter(self.model.id.in_(list(document_data_by_id.keys())))
@@ -226,14 +226,14 @@ class RepositoryGeneric(Generic[Schema, Model]):
updated_documents.append(document_to_update)
self.session.commit()
return [self.schema.from_orm(x) for x in updated_documents]
return [self.schema.model_validate(x) for x in updated_documents]
def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> Schema:
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
new_data = new_data if isinstance(new_data, dict) else new_data.model_dump()
entry = self._query_one(match_value=match_value)
entry_as_dict = self.schema.from_orm(entry).dict()
entry_as_dict = self.schema.model_validate(entry).model_dump()
entry_as_dict.update(new_data)
return self.update(match_value, entry_as_dict)
@@ -242,7 +242,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
match_key = match_key or self.primary_key
result = self._query_one(value, match_key)
results_as_model = self.schema.from_orm(result)
results_as_model = self.schema.model_validate(result)
try:
self.session.delete(result)
@@ -256,7 +256,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
def delete_many(self, values: Iterable) -> Schema:
query = self._query().filter(self.model.id.in_(values)) # type: ignore
results = self.session.execute(query).unique().scalars().all()
results_as_model = [self.schema.from_orm(result) for result in results]
results_as_model = [self.schema.model_validate(result) for result in results]
try:
# we create a delete statement for each row
@@ -295,7 +295,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
return self.session.scalar(q)
else:
q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match)
return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()]
return [eff_schema.model_validate(x) for x in self.session.execute(q).scalars().all()]
def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]:
"""
@@ -309,7 +309,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
"""
eff_schema = override or self.schema
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
pagination_result = pagination.copy()
pagination_result = pagination.model_copy()
q = self._query(override_schema=eff_schema, with_options=False)
fltr = self._filter_builder()
@@ -336,7 +336,7 @@ class RepositoryGeneric(Generic[Schema, Model]):
per_page=pagination_result.per_page,
total=count,
total_pages=total_pages,
items=[eff_schema.from_orm(s) for s in data],
items=[eff_schema.model_validate(s) for s in data],
)
def add_pagination_to_query(self, query: Select, pagination: PaginationQuery) -> tuple[Select, int, int]:

View File

@@ -23,7 +23,7 @@ from .repository_generic import RepositoryGeneric
class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]):
def create(self, data: GroupBase | dict) -> GroupInDB:
if isinstance(data, GroupBase):
data = data.dict()
data = data.model_dump()
max_attempts = 10
original_name = cast(str, data["name"])
@@ -61,7 +61,7 @@ class RepositoryGroup(RepositoryGeneric[GroupInDB, Group]):
dbgroup = self.session.execute(select(self.model).filter_by(name=name)).scalars().one_or_none()
if dbgroup is None:
return None
return self.schema.from_orm(dbgroup)
return self.schema.model_validate(dbgroup)
def get_by_slug_or_id(self, slug_or_id: str | UUID) -> GroupInDB | None:
if isinstance(slug_or_id, str):

View File

@@ -28,4 +28,4 @@ class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules
rules = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in rules]
return [self.schema.model_validate(x) for x in rules]

View File

@@ -17,4 +17,4 @@ class RepositoryMeals(RepositoryGeneric[ReadPlanEntry, GroupMealPlan]):
today = date.today()
stmt = select(GroupMealPlan).filter(GroupMealPlan.date == today, GroupMealPlan.group_id == group_id)
plans = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in plans]
return [self.schema.model_validate(x) for x in plans]

View File

@@ -58,7 +58,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.offset(start)
.limit(limit)
)
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [eff_schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
stmt = (
select(self.model)
@@ -67,7 +67,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.offset(start)
.limit(limit)
)
return [eff_schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [eff_schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def update_image(self, slug: str, _: str | None = None) -> int:
entry: RecipeModel = self._query_one(match_value=slug)
@@ -160,7 +160,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
search: str | None = None,
) -> RecipePagination:
# Copy this, because calling methods (e.g. tests) might rely on it not getting mutated
pagination_result = pagination.copy()
pagination_result = pagination.model_copy()
q = select(self.model)
args = [
@@ -216,7 +216,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
self.session.rollback()
raise e
items = [RecipeSummary.from_orm(item) for item in data]
items = [RecipeSummary.model_validate(item) for item in data]
return RecipePagination(
page=pagination_result.page,
per_page=pagination_result.per_page,
@@ -236,7 +236,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.join(RecipeModel.recipe_category)
.filter(RecipeModel.recipe_category.any(Category.id.in_(ids)))
)
return [RecipeSummary.from_orm(x) for x in self.session.execute(stmt).unique().scalars().all()]
return [RecipeSummary.model_validate(x) for x in self.session.execute(stmt).unique().scalars().all()]
def _build_recipe_filter(
self,
@@ -298,7 +298,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
require_all_tools=require_all_tools,
)
stmt = select(RecipeModel).filter(*fltr)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_random_by_categories_and_tags(
self, categories: list[RecipeCategory], tags: list[RecipeTag]
@@ -316,7 +316,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
stmt = (
select(RecipeModel).filter(and_(*filters)).order_by(func.random()).limit(1) # Postgres and SQLite specific
)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_random(self, limit=1) -> list[Recipe]:
stmt = (
@@ -325,14 +325,14 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.order_by(func.random()) # Postgres and SQLite specific
.limit(limit)
)
return [self.schema.from_orm(x) for x in self.session.execute(stmt).scalars().all()]
return [self.schema.model_validate(x) for x in self.session.execute(stmt).scalars().all()]
def get_by_slug(self, group_id: UUID4, slug: str, limit=1) -> Recipe | None:
stmt = select(RecipeModel).filter(RecipeModel.group_id == group_id, RecipeModel.slug == slug)
dbrecipe = self.session.execute(stmt).scalars().one_or_none()
if dbrecipe is None:
return None
return self.schema.from_orm(dbrecipe)
return self.schema.model_validate(dbrecipe)
def all_ids(self, group_id: UUID4) -> Sequence[UUID4]:
stmt = select(RecipeModel.id).filter(RecipeModel.group_id == group_id)

View File

@@ -18,7 +18,7 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
def update_password(self, id, password: str):
entry = self._query_one(match_value=id)
if settings.IS_DEMO:
user_to_update = self.schema.from_orm(entry)
user_to_update = self.schema.model_validate(entry)
if user_to_update.is_default_user:
# do not update the default user in demo mode
return user_to_update
@@ -26,7 +26,7 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
entry.update_password(password)
self.session.commit()
return self.schema.from_orm(entry)
return self.schema.model_validate(entry)
def create(self, user: PrivateUser | dict): # type: ignore
new_user = super().create(user)
@@ -66,9 +66,9 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
def get_by_username(self, username: str) -> PrivateUser | None:
stmt = select(User).filter(User.username == username)
dbuser = self.session.execute(stmt).scalars().one_or_none()
return None if dbuser is None else self.schema.from_orm(dbuser)
return None if dbuser is None else self.schema.model_validate(dbuser)
def get_locked_users(self) -> list[PrivateUser]:
stmt = select(User).filter(User.locked_at != None) # noqa E711
results = self.session.execute(stmt).scalars().all()
return [self.schema.from_orm(x) for x in results]
return [self.schema.model_validate(x) for x in results]