improve developer tooling (backend) (#1051)

* add basic pre-commit file

* add flake8

* add isort

* add pep585-upgrade (typing upgrades)

* use namespace for import

* add mypy

* update ci for backend

* flake8 scope

* fix version format

* update makefile

* disable strict option (temporary)

* fix mypy issues

* upgrade type hints (pre-commit)

* add vscode typing check

* add types to dev deps

* remote container draft

* update setup script

* update compose version

* run setup on create

* dev containers update

* remove unused pages

* update setup tips

* expose ports

* Update pre-commit to include flask8-print (#1053)

* Add in flake8-print to pre-commit

* pin version of flake8-print

* formatting

* update getting strated docs

* add mypy to pre-commit

* purge .mypy_cache on clean

* drop mypy

Co-authored-by: zackbcom <zackbcom@users.noreply.github.com>
This commit is contained in:
Hayden
2022-03-15 15:01:56 -08:00
committed by GitHub
parent e109391e9a
commit 3c2744a3da
105 changed files with 723 additions and 437 deletions

View File

@@ -1,5 +1,6 @@
import shutil
import tempfile
from collections.abc import AsyncGenerator, Callable, Generator
from pathlib import Path
from typing import Optional
from uuid import uuid4
@@ -94,10 +95,11 @@ def validate_long_live_token(session: Session, client_token: str, id: int) -> Pr
tokens: list[LongLiveTokenInDB] = repos.api_tokens.get(id, "user_id", limit=9999)
for token in tokens:
token: LongLiveTokenInDB
if token.token == client_token:
return token.user
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
def validate_file_token(token: Optional[str] = None) -> Path:
credentials_exception = HTTPException(
@@ -133,7 +135,7 @@ def validate_recipe_token(token: Optional[str] = None) -> str:
return slug
async def temporary_zip_path() -> Path:
async def temporary_zip_path() -> AsyncGenerator[Path, None]:
app_dirs.TEMP_DIR.mkdir(exist_ok=True, parents=True)
temp_path = app_dirs.TEMP_DIR.joinpath("my_zip_archive.zip")
@@ -143,7 +145,7 @@ async def temporary_zip_path() -> Path:
temp_path.unlink(missing_ok=True)
async def temporary_dir() -> Path:
async def temporary_dir() -> AsyncGenerator[Path, None]:
temp_path = app_dirs.TEMP_DIR.joinpath(uuid4().hex)
temp_path.mkdir(exist_ok=True, parents=True)
@@ -153,12 +155,12 @@ async def temporary_dir() -> Path:
shutil.rmtree(temp_path)
def temporary_file(ext: str = "") -> Path:
def temporary_file(ext: str = "") -> Callable[[], Generator[tempfile._TemporaryFileWrapper, None, None]]:
"""
Returns a temporary file with the specified extension
"""
def func() -> Path:
def func():
temp_path = app_dirs.TEMP_DIR.joinpath(uuid4().hex + ext)
temp_path.touch()

View File

@@ -20,7 +20,7 @@ class LoggerConfig:
format: str
date_format: str
logger_file: str
level: str = logging.INFO
level: int = logging.INFO
@lru_cache

View File

@@ -36,7 +36,7 @@ def create_recipe_slug_token(file_path: str) -> str:
return create_access_token(token_data, expires_delta=timedelta(minutes=30))
def user_from_ldap(db: AllRepositories, session, username: str, password: str) -> PrivateUser:
def user_from_ldap(db: AllRepositories, session, username: str, password: str) -> PrivateUser | bool:
"""Given a username and password, tries to authenticate by BINDing to an
LDAP server

View File

@@ -35,7 +35,7 @@ class PostgresProvider(AbstractDBProvider, BaseSettings):
POSTGRES_USER: str = "mealie"
POSTGRES_PASSWORD: str = "mealie"
POSTGRES_SERVER: str = "postgres"
POSTGRES_PORT: str = 5432
POSTGRES_PORT: str = "5432"
POSTGRES_DB: str = "mealie"
@property

View File

@@ -2,7 +2,7 @@ import secrets
from pathlib import Path
from typing import Optional
from pydantic import BaseSettings
from pydantic import BaseSettings, NoneStr
from .db_providers import AbstractDBProvider, db_provider_factory
@@ -33,26 +33,26 @@ class AppSettings(BaseSettings):
SECRET: str
@property
def DOCS_URL(self) -> str:
def DOCS_URL(self) -> str | None:
return "/docs" if self.API_DOCS else None
@property
def REDOC_URL(self) -> str:
def REDOC_URL(self) -> str | None:
return "/redoc" if self.API_DOCS else None
# ===============================================
# Database Configuration
DB_ENGINE: str = "sqlite" # Options: 'sqlite', 'postgres'
DB_PROVIDER: AbstractDBProvider = None
DB_PROVIDER: Optional[AbstractDBProvider] = None
@property
def DB_URL(self) -> str:
return self.DB_PROVIDER.db_url
def DB_URL(self) -> str | None:
return self.DB_PROVIDER.db_url if self.DB_PROVIDER else None
@property
def DB_URL_PUBLIC(self) -> str:
return self.DB_PROVIDER.db_url_public
def DB_URL_PUBLIC(self) -> str | None:
return self.DB_PROVIDER.db_url_public if self.DB_PROVIDER else None
DEFAULT_GROUP: str = "Home"
DEFAULT_EMAIL: str = "changeme@email.com"
@@ -88,9 +88,9 @@ class AppSettings(BaseSettings):
# LDAP Configuration
LDAP_AUTH_ENABLED: bool = False
LDAP_SERVER_URL: str = None
LDAP_BIND_TEMPLATE: str = None
LDAP_ADMIN_FILTER: str = None
LDAP_SERVER_URL: NoneStr = None
LDAP_BIND_TEMPLATE: NoneStr = None
LDAP_ADMIN_FILTER: NoneStr = None
@property
def LDAP_ENABLED(self) -> bool:

View File

@@ -24,7 +24,7 @@ def sql_global_init(db_url: str):
return SessionLocal, engine
SessionLocal, engine = sql_global_init(settings.DB_URL)
SessionLocal, engine = sql_global_init(settings.DB_URL) # type: ignore
def create_session() -> Session:

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from pathlib import Path
from typing import Callable
from sqlalchemy import engine

View File

@@ -1,5 +1,5 @@
from .group import *
from .labels import *
from .recipe.recipe import *
from .recipe.recipe import * # type: ignore
from .server import *
from .users import *

View File

@@ -24,7 +24,7 @@ class BaseMixins:
@classmethod
def get_ref(cls, match_value: str, match_attr: str = None, session: Session = None):
match_attr = match_attr = cls.Config.get_attr
match_attr = match_attr or cls.Config.get_attr # type: ignore
if match_value is None or session is None:
return None

View File

@@ -1,7 +1,7 @@
from functools import wraps
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, NoneStr
from sqlalchemy.orm import MANYTOMANY, MANYTOONE, ONETOMANY, Session
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.mapper import Mapper
@@ -21,7 +21,7 @@ class AutoInitConfig(BaseModel):
Config class for `auto_init` decorator.
"""
get_attr: str = None
get_attr: NoneStr = None
exclude: set = Field(default_factory=_default_exclusion)
# auto_create: bool = False
@@ -83,12 +83,14 @@ def handle_one_to_many_list(session: Session, get_attr, relation_cls, all_elemen
elem_id = elem.get(get_attr, None) if isinstance(elem, dict) else elem
existing_elem = session.query(relation_cls).filter_by(**{get_attr: elem_id}).one_or_none()
if existing_elem is None:
elems_to_create.append(elem)
is_dict = isinstance(elem, dict)
if existing_elem is None and is_dict:
elems_to_create.append(elem) # type: ignore
continue
elif isinstance(elem, dict):
for key, value in elem.items():
elif is_dict:
for key, value in elem.items(): # type: ignore
if key not in cfg.exclude:
setattr(existing_elem, key, value)

View File

@@ -1,5 +1,6 @@
import inspect
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
def get_valid_call(func: Callable, args_dict) -> dict:
@@ -8,7 +9,7 @@ def get_valid_call(func: Callable, args_dict) -> dict:
the original dictionary will be returned.
"""
def get_valid_args(func: Callable) -> tuple:
def get_valid_args(func: Callable) -> list[str]:
"""
Returns a tuple of valid arguemnts for the supplied function.
"""

View File

@@ -78,8 +78,8 @@ class Group(SqlAlchemyBase, BaseMixins):
def __init__(self, **_) -> None:
pass
@staticmethod
def get_ref(session: Session, name: str):
@staticmethod # TODO: Remove this
def get_ref(session: Session, name: str): # type: ignore
settings = get_app_settings()
item = session.query(Group).filter(Group.name == name).one_or_none()

View File

@@ -63,8 +63,8 @@ class Category(SqlAlchemyBase, BaseMixins):
self.name = name.strip()
self.slug = slugify(name)
@classmethod
def get_ref(cls, match_value: str, session=None):
@classmethod # TODO: Remove this
def get_ref(cls, match_value: str, session=None): # type: ignore
if not session or not match_value:
return None
@@ -76,4 +76,4 @@ class Category(SqlAlchemyBase, BaseMixins):
return result
else:
logger.debug("Category doesn't exists, creating Category")
return Category(name=match_value)
return Category(name=match_value) # type: ignore

View File

@@ -22,5 +22,5 @@ class RecipeComment(SqlAlchemyBase, BaseMixins):
def __init__(self, **_) -> None:
pass
def update(self, text, **_) -> None:
def update(self, text, **_) -> None: # type: ignore
self.text = text

View File

@@ -1,5 +1,4 @@
import datetime
from datetime import date
import sqlalchemy as sa
import sqlalchemy.orm as orm
@@ -107,7 +106,7 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
extras: list[ApiExtras] = orm.relationship("ApiExtras", cascade="all, delete-orphan")
# Time Stamp Properties
date_added = sa.Column(sa.Date, default=date.today)
date_added = sa.Column(sa.Date, default=datetime.date.today)
date_updated = sa.Column(sa.DateTime)
# Shopping List Refs

View File

@@ -50,8 +50,8 @@ class Tag(SqlAlchemyBase, BaseMixins):
self.name = name.strip()
self.slug = slugify(self.name)
@classmethod
def get_ref(cls, match_value: str, session=None):
@classmethod # TODO: Remove this
def get_ref(cls, match_value: str, session=None): # type: ignore
if not session or not match_value:
return None
@@ -62,4 +62,4 @@ class Tag(SqlAlchemyBase, BaseMixins):
return result
else:
logger.debug("Category doesn't exists, creating Category")
return Tag(name=match_value)
return Tag(name=match_value) # type: ignore

View File

@@ -124,6 +124,6 @@ class User(SqlAlchemyBase, BaseMixins):
self.can_invite = can_invite
self.can_organize = can_organize
@staticmethod
def get_ref(session, id: str):
@staticmethod # TODO: Remove This
def get_ref(session, id: str): # type: ignore
return session.query(User).filter(User.id == id).one()

View File

@@ -19,11 +19,11 @@ def get_format(image: Path) -> str:
def sizeof_fmt(file_path: Path, decimal_places=2):
if not file_path.exists():
return "(File Not Found)"
size = file_path.stat().st_size
size: int | float = file_path.stat().st_size
for unit in ["B", "kB", "MB", "GB", "TB", "PB"]:
if size < 1024.0 or unit == "PiB":
if size < 1024 or unit == "PiB":
break
size /= 1024.0
size /= 1024
return f"{size:.{decimal_places}f} {unit}"

View File

@@ -1,5 +1,5 @@
from typing import Any, Callable, Generic, TypeVar, Union
from uuid import UUID
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union
from pydantic import UUID4, BaseModel
from sqlalchemy import func
@@ -18,7 +18,7 @@ class RepositoryGeneric(Generic[T, D]):
Generic ([D]): Represents the SqlAlchemyModel Model
"""
def __init__(self, session: Session, primary_key: Union[str, int], sql_model: D, schema: T) -> None:
def __init__(self, session: Session, primary_key: str, sql_model: type[D], schema: type[T]) -> None:
self.session = session
self.primary_key = primary_key
self.sql_model = sql_model
@@ -26,10 +26,10 @@ class RepositoryGeneric(Generic[T, D]):
self.observers: list = []
self.limit_by_group = False
self.user_id = None
self.user_id: UUID4 = None
self.limit_by_user = False
self.group_id = None
self.group_id: UUID4 = None
def subscribe(self, func: Callable) -> None:
self.observers.append(func)
@@ -39,7 +39,7 @@ class RepositoryGeneric(Generic[T, D]):
self.user_id = user_id
return self
def by_group(self, group_id: UUID) -> "RepositoryGeneric[T, D]":
def by_group(self, group_id: UUID4) -> "RepositoryGeneric[T, D]":
self.limit_by_group = True
self.group_id = group_id
return self
@@ -88,7 +88,7 @@ class RepositoryGeneric(Generic[T, D]):
def multi_query(
self,
query_by: dict[str, str],
query_by: dict[str, str | bool | int | UUID4],
start=0,
limit: int = None,
override_schema=None,
@@ -152,7 +152,7 @@ class RepositoryGeneric(Generic[T, D]):
filter = self._filter_builder(**{match_key: match_value})
return self.session.query(self.sql_model).filter_by(**filter).one()
def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> T:
def get_one(self, value: str | int | UUID4, key: str = None, any_case=False, override_schema=None) -> T | None:
key = key or self.primary_key
q = self.session.query(self.sql_model)
@@ -166,14 +166,14 @@ class RepositoryGeneric(Generic[T, D]):
result = q.one_or_none()
if not result:
return
return None
eff_schema = override_schema or self.schema
return eff_schema.from_orm(result)
def get(
self, match_value: str | int | UUID4, match_key: str = None, limit=1, any_case=False, override_schema=None
) -> T | list[T]:
) -> T | list[T] | None:
"""Retrieves an entry from the database by matching a key/value pair. If no
key is provided the class objects primary key will be used to match against.
@@ -193,7 +193,7 @@ class RepositoryGeneric(Generic[T, D]):
search_attr = getattr(self.sql_model, match_key)
result = (
self.session.query(self.sql_model)
.filter(func.lower(search_attr) == match_value.lower())
.filter(func.lower(search_attr) == match_value.lower()) # type: ignore
.limit(limit)
.all()
)
@@ -210,7 +210,7 @@ class RepositoryGeneric(Generic[T, D]):
return [eff_schema.from_orm(x) for x in result]
def create(self, document: T) -> T:
def create(self, document: T | BaseModel) -> T:
"""Creates a new database entry for the given SQL Alchemy Model.
Args:
@@ -221,7 +221,7 @@ class RepositoryGeneric(Generic[T, D]):
dict: A dictionary representation of the database entry
"""
document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=self.session, **document)
new_document = self.sql_model(session=self.session, **document) # type: ignore
self.session.add(new_document)
self.session.commit()
self.session.refresh(new_document)
@@ -231,7 +231,7 @@ class RepositoryGeneric(Generic[T, D]):
return self.schema.from_orm(new_document)
def update(self, match_value: str | int | UUID4, new_data: dict) -> T:
def update(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T:
"""Update a database entry.
Args:
session (Session): Database Session
@@ -244,7 +244,7 @@ class RepositoryGeneric(Generic[T, D]):
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(match_value=match_value)
entry.update(session=self.session, **new_data)
entry.update(session=self.session, **new_data) # type: ignore
if self.observers:
self.update_observers()
@@ -252,13 +252,14 @@ class RepositoryGeneric(Generic[T, D]):
self.session.commit()
return self.schema.from_orm(entry)
def patch(self, match_value: str | int | UUID4, new_data: dict) -> T:
def patch(self, match_value: str | int | UUID4, new_data: dict | BaseModel) -> T | None:
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(match_value=match_value)
if not entry:
return
# TODO: Should raise exception
return None
entry_as_dict = self.schema.from_orm(entry).dict()
entry_as_dict.update(new_data)
@@ -300,7 +301,7 @@ class RepositoryGeneric(Generic[T, D]):
attr_match: str = None,
count=True,
override_schema=None,
) -> Union[int, T]:
) -> Union[int, list[T]]:
eff_schema = override_schema or self.schema
# attr_filter = getattr(self.sql_model, attribute_name)
@@ -316,7 +317,7 @@ class RepositoryGeneric(Generic[T, D]):
new_documents = []
for document in documents:
document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=self.session, **document)
new_document = self.sql_model(session=self.session, **document) # type: ignore
new_documents.append(new_document)
self.session.add_all(new_documents)

View File

@@ -10,7 +10,7 @@ from .repository_generic import RepositoryGeneric
class RepositoryMealPlanRules(RepositoryGeneric[PlanRulesOut, GroupMealPlanRules]):
def by_group(self, group_id: UUID) -> "RepositoryMealPlanRules":
return super().by_group(group_id)
return super().by_group(group_id) # type: ignore
def get_rules(self, day: PlanRulesDay, entry_type: PlanRulesType) -> list[PlanRulesOut]:
qry = self.session.query(GroupMealPlanRules).filter(

View File

@@ -9,10 +9,10 @@ from .repository_generic import RepositoryGeneric
class RepositoryMeals(RepositoryGeneric[ReadPlanEntry, GroupMealPlan]):
def get_slice(self, start: date, end: date, group_id: UUID) -> list[ReadPlanEntry]:
start = start.strftime("%Y-%m-%d")
end = end.strftime("%Y-%m-%d")
start_str = start.strftime("%Y-%m-%d")
end_str = end.strftime("%Y-%m-%d")
qry = self.session.query(GroupMealPlan).filter(
GroupMealPlan.date.between(start, end),
GroupMealPlan.date.between(start_str, end_str),
GroupMealPlan.group_id == group_id,
)

View File

@@ -18,7 +18,7 @@ from .repository_generic import RepositoryGeneric
class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
def by_group(self, group_id: UUID) -> "RepositoryRecipes":
return super().by_group(group_id)
return super().by_group(group_id) # type: ignore
def get_all_public(self, limit: int = None, order_by: str = None, start=0, override_schema=None):
eff_schema = override_schema or self.schema
@@ -47,14 +47,14 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.all()
]
def update_image(self, slug: str, _: str = None) -> str:
def update_image(self, slug: str, _: str = None) -> int:
entry: RecipeModel = self._query_one(match_value=slug)
entry.image = randint(0, 255)
self.session.commit()
return entry.image
def count_uncategorized(self, count=True, override_schema=None) -> int:
def count_uncategorized(self, count=True, override_schema=None):
return self._count_attribute(
attribute_name=RecipeModel.recipe_category,
attr_match=None,
@@ -62,7 +62,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
override_schema=override_schema,
)
def count_untagged(self, count=True, override_schema=None) -> int:
def count_untagged(self, count=True, override_schema=None):
return self._count_attribute(
attribute_name=RecipeModel.tags,
attr_match=None,
@@ -105,7 +105,9 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]):
.all()
]
def get_random_by_categories_and_tags(self, categories: list[RecipeCategory], tags: list[RecipeTag]) -> Recipe:
def get_random_by_categories_and_tags(
self, categories: list[RecipeCategory], tags: list[RecipeTag]
) -> list[Recipe]:
"""
get_random_by_categories returns a single random Recipe that contains every category provided
in the list. This uses a function built in to Postgres and SQLite to get a random row limited

View File

@@ -7,5 +7,5 @@ from .repository_generic import RepositoryGeneric
class RepositoryShoppingList(RepositoryGeneric[ShoppingListOut, ShoppingList]):
def update(self, item_id: UUID4, data: ShoppingListUpdate) -> ShoppingListOut:
def update(self, item_id: UUID4, data: ShoppingListUpdate) -> ShoppingListOut: # type: ignore
return super().update(item_id, data)

View File

@@ -16,7 +16,7 @@ class RepositoryUsers(RepositoryGeneric[PrivateUser, User]):
return self.schema.from_orm(entry)
def create(self, user: PrivateUser):
def create(self, user: PrivateUser | dict):
new_user = super().create(user)
# Select Random Image

View File

@@ -1,5 +1,5 @@
import json
from typing import Generator
from collections.abc import Generator
from mealie.schema.labels import MultiPurposeLabelSave
from mealie.schema.recipe.recipe_ingredient import SaveIngredientFood, SaveIngredientUnit

View File

@@ -1,6 +1,5 @@
from abc import ABC
from functools import cached_property
from typing import Type
from fastapi import Depends
@@ -29,7 +28,7 @@ class BaseUserController(ABC):
deps: SharedDependencies = Depends(SharedDependencies.user)
def registered_exceptions(self, ex: Type[Exception]) -> str:
def registered_exceptions(self, ex: type[Exception]) -> str:
registered = {
**mealie_registered_exceptions(self.deps.t),
}

View File

@@ -4,7 +4,8 @@ This file contains code taken from fastapi-utils project. The code is licensed u
See their repository for details -> https://github.com/dmontagu/fastapi-utils
"""
import inspect
from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, cast, get_type_hints
from collections.abc import Callable
from typing import Any, TypeVar, Union, cast, get_type_hints
from fastapi import APIRouter, Depends
from fastapi.routing import APIRoute
@@ -18,7 +19,7 @@ INCLUDE_INIT_PARAMS_KEY = "__include_init_params__"
RETURN_TYPES_FUNC_KEY = "__return_types_func__"
def controller(router: APIRouter, *urls: str) -> Callable[[Type[T]], Type[T]]:
def controller(router: APIRouter, *urls: str) -> Callable[[type[T]], type[T]]:
"""
This function returns a decorator that converts the decorated into a class-based view for the provided router.
Any methods of the decorated class that are decorated as endpoints using the router provided to this function
@@ -28,14 +29,14 @@ def controller(router: APIRouter, *urls: str) -> Callable[[Type[T]], Type[T]]:
https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/#the-cbv-decorator
"""
def decorator(cls: Type[T]) -> Type[T]:
def decorator(cls: type[T]) -> type[T]:
# Define cls as cbv class exclusively when using the decorator
return _cbv(router, cls, *urls)
return decorator
def _cbv(router: APIRouter, cls: Type[T], *urls: str, instance: Any = None) -> Type[T]:
def _cbv(router: APIRouter, cls: type[T], *urls: str, instance: Any = None) -> type[T]:
"""
Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated
function calls that will properly inject an instance of `cls`.
@@ -45,7 +46,7 @@ def _cbv(router: APIRouter, cls: Type[T], *urls: str, instance: Any = None) -> T
return cls
def _init_cbv(cls: Type[Any], instance: Any = None) -> None:
def _init_cbv(cls: type[Any], instance: Any = None) -> None:
"""
Idempotently modifies the provided `cls`, performing the following modifications:
* The `__init__` function is updated to set any class-annotated dependencies as instance attributes
@@ -60,7 +61,7 @@ def _init_cbv(cls: Type[Any], instance: Any = None) -> None:
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]
dependency_names: List[str] = []
dependency_names: list[str] = []
for name, hint in get_type_hints(cls).items():
if is_classvar(hint):
continue
@@ -88,7 +89,7 @@ def _init_cbv(cls: Type[Any], instance: Any = None) -> None:
setattr(cls, CBV_CLASS_KEY, True)
def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
def _register_endpoints(router: APIRouter, cls: type[Any], *urls: str) -> None:
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
for url in urls:
@@ -97,7 +98,7 @@ def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
for route in router.routes:
assert isinstance(route, APIRoute)
route_methods: Any = route.methods
cast(Tuple[Any], route_methods)
cast(tuple[Any], route_methods)
router_roles.append((route.path, tuple(route_methods)))
if len(set(router_roles)) != len(router_roles):
@@ -110,7 +111,7 @@ def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
}
prefix_length = len(router.prefix)
routes_to_append: List[Tuple[int, Union[Route, WebSocketRoute]]] = []
routes_to_append: list[tuple[int, Union[Route, WebSocketRoute]]] = []
for _, func in function_members:
index_route = numbered_routes_by_endpoint.get(func)
@@ -138,9 +139,9 @@ def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
router.include_router(cbv_router, prefix=cbv_prefix)
def _allocate_routes_by_method_name(router: APIRouter, url: str, function_members: List[Tuple[str, Any]]) -> None:
def _allocate_routes_by_method_name(router: APIRouter, url: str, function_members: list[tuple[str, Any]]) -> None:
# sourcery skip: merge-nested-ifs
existing_routes_endpoints: List[Tuple[Any, str]] = [
existing_routes_endpoints: list[tuple[Any, str]] = [
(route.endpoint, route.path) for route in router.routes if isinstance(route, APIRoute)
]
for name, func in function_members:
@@ -165,13 +166,13 @@ def _allocate_routes_by_method_name(router: APIRouter, url: str, function_member
api_resource(func)
def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
def _update_cbv_route_endpoint_signature(cls: type[Any], route: Union[Route, WebSocketRoute]) -> None:
"""
Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly.
"""
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
old_parameters: List[inspect.Parameter] = list(old_signature.parameters.values())
old_parameters: list[inspect.Parameter] = list(old_signature.parameters.values())
old_first_parameter = old_parameters[0]
new_first_parameter = old_first_parameter.replace(default=Depends(cls))
new_parameters = [new_first_parameter] + [

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from logging import Logger
from typing import Callable, Generic, Type, TypeVar
from typing import Generic, TypeVar
from fastapi import HTTPException, status
from pydantic import UUID4, BaseModel
@@ -26,14 +27,14 @@ class CrudMixins(Generic[C, R, U]):
"""
repo: RepositoryGeneric
exception_msgs: Callable[[Type[Exception]], str] | None
exception_msgs: Callable[[type[Exception]], str] | None
default_message: str = "An unexpected error occurred."
def __init__(
self,
repo: RepositoryGeneric,
logger: Logger,
exception_msgs: Callable[[Type[Exception]], str] = None,
exception_msgs: Callable[[type[Exception]], str] = None,
default_message: str = None,
) -> None:
@@ -83,7 +84,7 @@ class CrudMixins(Generic[C, R, U]):
return item
def update_one(self, data: U, item_id: int | str | UUID4) -> R:
item: R = self.repo.get_one(item_id)
item = self.repo.get_one(item_id)
if not item:
raise HTTPException(

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from fastapi import APIRouter, Depends
@@ -10,7 +10,7 @@ class AdminAPIRouter(APIRouter):
def __init__(
self,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
prefix: str = "",
):
super().__init__(tags=tags, prefix=prefix, dependencies=[Depends(get_admin_user)])
@@ -21,7 +21,7 @@ class UserAPIRouter(APIRouter):
def __init__(
self,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
prefix: str = "",
):
super().__init__(tags=tags, prefix=prefix, dependencies=[Depends(get_current_user)])

View File

@@ -52,16 +52,15 @@ def get_token(data: CustomOAuth2Form = Depends(), session: Session = Depends(gen
email = data.username
password = data.password
user: PrivateUser = authenticate_user(session, email, password)
user = authenticate_user(session, email, password) # type: ignore
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
)
duration = timedelta(days=14) if data.remember_me else None
access_token = security.create_access_token(dict(sub=str(user.id)), duration)
access_token = security.create_access_token(dict(sub=str(user.id)), duration) # type: ignore
return MealieAuthToken.respond(access_token)

View File

@@ -1,5 +1,4 @@
from functools import cached_property
from typing import Type
from fastapi import APIRouter, HTTPException
from pydantic import UUID4
@@ -24,7 +23,7 @@ class GroupCookbookController(BaseUserController):
def repo(self):
return self.deps.repos.cookbooks.by_group(self.group_id)
def registered_exceptions(self, ex: Type[Exception]) -> str:
def registered_exceptions(self, ex: type[Exception]) -> str:
registered = {
**mealie_registered_exceptions(self.deps.t),
}

View File

@@ -1,5 +1,4 @@
from functools import cached_property
from typing import Type
from fastapi import APIRouter
from pydantic import UUID4
@@ -19,7 +18,7 @@ class GroupReportsController(BaseUserController):
def repo(self):
return self.deps.repos.group_reports.by_group(self.deps.acting_user.group_id)
def registered_exceptions(self, ex: Type[Exception]) -> str:
def registered_exceptions(self, ex: type[Exception]) -> str:
return {
**mealie_registered_exceptions(self.deps.t),
}.get(ex, "An unexpected error occurred.")

View File

@@ -1,6 +1,5 @@
from datetime import date, timedelta
from functools import cached_property
from typing import Type
from fastapi import APIRouter, HTTPException
@@ -24,7 +23,7 @@ class GroupMealplanController(BaseUserController):
def repo(self) -> RepositoryMeals:
return self.repos.meals.by_group(self.group_id)
def registered_exceptions(self, ex: Type[Exception]) -> str:
def registered_exceptions(self, ex: type[Exception]) -> str:
registered = {
**mealie_registered_exceptions(self.deps.t),
}
@@ -58,7 +57,7 @@ class GroupMealplanController(BaseUserController):
)
recipe_repo = self.repos.recipes.by_group(self.group_id)
random_recipes: Recipe = []
random_recipes: list[Recipe] = []
if not rules: # If no rules are set, return any random recipe from the group
random_recipes = recipe_repo.get_random()

View File

@@ -1,4 +1,5 @@
import shutil
from pathlib import Path
from fastapi import Depends, File, Form
from fastapi.datastructures import UploadFile
@@ -8,7 +9,13 @@ from mealie.routes._base import BaseUserController, controller
from mealie.routes._base.routers import UserAPIRouter
from mealie.schema.group.group_migration import SupportedMigrations
from mealie.schema.reports.reports import ReportSummary
from mealie.services.migrations import ChowdownMigrator, MealieAlphaMigrator, NextcloudMigrator, PaprikaMigrator
from mealie.services.migrations import (
BaseMigrator,
ChowdownMigrator,
MealieAlphaMigrator,
NextcloudMigrator,
PaprikaMigrator,
)
router = UserAPIRouter(prefix="/groups/migrations", tags=["Group: Migrations"])
@@ -21,7 +28,7 @@ class GroupMigrationController(BaseUserController):
add_migration_tag: bool = Form(False),
migration_type: SupportedMigrations = Form(...),
archive: UploadFile = File(...),
temp_path: str = Depends(temporary_zip_path),
temp_path: Path = Depends(temporary_zip_path),
):
# Save archive to temp_path
with temp_path.open("wb") as buffer:
@@ -36,6 +43,8 @@ class GroupMigrationController(BaseUserController):
"add_migration_tag": add_migration_tag,
}
migrator: BaseMigrator
match migration_type:
case SupportedMigrations.chowdown:
migrator = ChowdownMigrator(**args)

View File

@@ -23,7 +23,6 @@ def register_debug_handler(app: FastAPI):
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
log_wrapper(request, exc)
content = {"status_code": status.HTTP_422_UNPROCESSABLE_ENTITY, "message": exc_str, "data": None}

View File

@@ -173,7 +173,7 @@ class RecipeController(BaseRecipeController):
task.append_log(f"Error: Failed to create recipe from url: {b.url}")
task.append_log(f"Error: {e}")
self.deps.logger.error(f"Failed to create recipe from url: {b.url}")
self.deps.error(e)
self.deps.logger.error(e)
database.server_tasks.update(task.id, task)
task.set_finished()
@@ -225,12 +225,13 @@ class RecipeController(BaseRecipeController):
return self.mixins.get_one(slug)
@router.post("", status_code=201, response_model=str)
def create_one(self, data: CreateRecipe) -> str:
def create_one(self, data: CreateRecipe) -> str | None:
"""Takes in a JSON string and loads data into the database as a new entry"""
try:
return self.service.create_one(data).slug
except Exception as e:
self.handle_exceptions(e)
return None
@router.put("/{slug}")
def update_one(self, slug: str, data: Recipe):
@@ -263,7 +264,7 @@ class RecipeController(BaseRecipeController):
# Image and Assets
@router.post("/{slug}/image", tags=["Recipe: Images and Assets"])
def scrape_image_url(self, slug: str, url: CreateRecipeByUrl) -> str:
def scrape_image_url(self, slug: str, url: CreateRecipeByUrl):
recipe = self.mixins.get_one(slug)
data_service = RecipeDataService(recipe.id)
data_service.scrape_image(url.url)
@@ -303,7 +304,7 @@ class RecipeController(BaseRecipeController):
if not dest.is_file():
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR)
recipe: Recipe = self.mixins.get_one(slug)
recipe = self.mixins.get_one(slug)
recipe.assets.append(asset_in)
self.mixins.update_one(recipe, slug)

View File

View File

@@ -0,0 +1,3 @@
from typing import Optional
NoneFloat = Optional[float]

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel
@@ -22,7 +22,7 @@ class ImportJob(BackupOptions):
class CreateBackup(BaseModel):
tag: Optional[str]
options: BackupOptions
templates: Optional[List[str]]
templates: Optional[list[str]]
class BackupFile(BaseModel):
@@ -32,5 +32,5 @@ class BackupFile(BaseModel):
class AllBackups(BaseModel):
imports: List[BackupFile]
templates: List[str]
imports: list[BackupFile]
templates: list[str]

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import List
from pydantic.main import BaseModel
@@ -17,7 +16,7 @@ class MigrationFile(BaseModel):
class Migrations(BaseModel):
type: str
files: List[MigrationFile] = []
files: list[MigrationFile] = []
class MigrationImport(RecipeImport):

View File

@@ -1,5 +1,5 @@
from fastapi_camelcase import CamelModel
from pydantic import UUID4
from pydantic import UUID4, NoneStr
# =============================================================================
# Group Events Notifier Options
@@ -68,7 +68,7 @@ class GroupEventNotifierSave(GroupEventNotifierCreate):
class GroupEventNotifierUpdate(GroupEventNotifierSave):
id: UUID4
apprise_url: str = None
apprise_url: NoneStr = None
class GroupEventNotifierOut(CamelModel):

View File

@@ -1,6 +1,7 @@
from uuid import UUID
from fastapi_camelcase import CamelModel
from pydantic import NoneStr
class CreateInviteToken(CamelModel):
@@ -29,4 +30,4 @@ class EmailInvitation(CamelModel):
class EmailInitationResponse(CamelModel):
success: bool
error: str = None
error: NoneStr = None

View File

@@ -18,7 +18,7 @@ def mapper(source: U, dest: T, **_) -> T:
return dest
def cast(source: U, dest: T, **kwargs) -> T:
def cast(source: U, dest: type[T], **kwargs) -> T:
create_data = {field: getattr(source, field) for field in source.__fields__ if field in dest.__fields__}
create_data.update(kwargs or {})
return dest(**create_data)

View File

@@ -3,13 +3,13 @@ from .recipe import *
from .recipe_asset import *
from .recipe_bulk_actions import *
from .recipe_category import *
from .recipe_comments import *
from .recipe_comments import * # type: ignore
from .recipe_image_types import *
from .recipe_ingredient import *
from .recipe_notes import *
from .recipe_nutrition import *
from .recipe_settings import *
from .recipe_share_token import *
from .recipe_share_token import * # type: ignore
from .recipe_step import *
from .recipe_tool import *
from .request_helpers import *

View File

@@ -7,6 +7,8 @@ from uuid import UUID, uuid4
from fastapi_camelcase import CamelModel
from pydantic import UUID4, Field
from mealie.schema._mealie.types import NoneFloat
class UnitFoodBase(CamelModel):
name: str
@@ -23,7 +25,7 @@ class SaveIngredientFood(CreateIngredientFood):
class IngredientFood(CreateIngredientFood):
id: UUID4
label: MultiPurposeLabelSummary = None
label: Optional[MultiPurposeLabelSummary] = None
class Config:
orm_mode = True
@@ -63,12 +65,12 @@ class RecipeIngredient(CamelModel):
class IngredientConfidence(CamelModel):
average: float = None
comment: float = None
name: float = None
unit: float = None
quantity: float = None
food: float = None
average: NoneFloat = None
comment: NoneFloat = None
name: NoneFloat = None
unit: NoneFloat = None
quantity: NoneFloat = None
food: NoneFloat = None
class ParsedIngredient(CamelModel):

View File

@@ -1,4 +1,4 @@
from typing import List
import typing
from fastapi_camelcase import CamelModel
from pydantic import UUID4
@@ -22,7 +22,7 @@ class RecipeTool(RecipeToolCreate):
class RecipeToolResponse(RecipeTool):
recipes: List["Recipe"] = []
recipes: typing.List["Recipe"] = []
class Config:
orm_mode = True

View File

@@ -11,4 +11,4 @@ class Token(BaseModel):
class TokenData(BaseModel):
user_id: Optional[UUID4]
username: Optional[constr(to_lower=True, strip_whitespace=True)] = None
username: Optional[constr(to_lower=True, strip_whitespace=True)] = None # type: ignore

View File

@@ -1,13 +1,13 @@
from fastapi_camelcase import CamelModel
from pydantic import validator
from pydantic.types import constr
from pydantic.types import NoneStr, constr
class CreateUserRegistration(CamelModel):
group: str = None
group_token: str = None
email: constr(to_lower=True, strip_whitespace=True)
username: constr(to_lower=True, strip_whitespace=True)
group: NoneStr = None
group_token: NoneStr = None
email: constr(to_lower=True, strip_whitespace=True) # type: ignore
username: constr(to_lower=True, strip_whitespace=True) # type: ignore
password: str
password_confirm: str
advanced: bool = False

View File

@@ -53,7 +53,7 @@ class GroupBase(CamelModel):
class UserBase(CamelModel):
username: Optional[str]
full_name: Optional[str] = None
email: constr(to_lower=True, strip_whitespace=True)
email: constr(to_lower=True, strip_whitespace=True) # type: ignore
admin: bool = False
group: Optional[str]
advanced: bool = False
@@ -107,7 +107,7 @@ class UserOut(UserBase):
class UserFavorites(UserBase):
favorite_recipes: list[RecipeSummary] = []
favorite_recipes: list[RecipeSummary] = [] # type: ignore
class Config:
orm_mode = True

View File

@@ -39,7 +39,7 @@ class ExportDatabase:
try:
self.templates = [app_dirs.TEMPLATE_DIR.joinpath(x) for x in templates]
except Exception:
self.templates = False
self.templates = []
logger.info("No Jinja2 Templates Registered for Export")
required_dirs = [

View File

@@ -1,8 +1,8 @@
import json
import shutil
import zipfile
from collections.abc import Callable
from pathlib import Path
from typing import Callable
from pydantic.main import BaseModel
from sqlalchemy.orm.session import Session
@@ -140,7 +140,7 @@ class ImportDatabase:
if image_dir.exists(): # Migrate from before v0.5.0
for image in image_dir.iterdir():
item: Recipe = successful_imports.get(image.stem)
item: Recipe = successful_imports.get(image.stem) # type: ignore
if item:
dest_dir = item.image_dir
@@ -294,7 +294,7 @@ def import_database(
settings_report = import_session.import_settings() if import_settings else []
group_report = import_session.import_groups() if import_groups else []
user_report = import_session.import_users() if import_users else []
notification_report = []
notification_report: list = []
import_session.clean_up()

View File

@@ -6,7 +6,7 @@ from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import MetaData, create_engine
from sqlalchemy.engine import base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
from mealie.services._base_service import BaseService
@@ -122,8 +122,6 @@ class AlchemyExporter(BaseService):
"""Drops all data from the database"""
self.meta.reflect(bind=self.engine)
with self.session_maker() as session:
session: Session
is_postgres = self.settings.DB_ENGINE == "postgres"
try:

View File

@@ -23,7 +23,7 @@ class DefaultEmailSender(ABCEmailSender, BaseService):
mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL),
)
smtp_options = {"host": self.settings.SMTP_HOST, "port": self.settings.SMTP_PORT}
smtp_options: dict[str, str | bool] = {"host": self.settings.SMTP_HOST, "port": self.settings.SMTP_PORT}
if self.settings.SMTP_TLS:
smtp_options["tls"] = True
if self.settings.SMTP_USER:

View File

@@ -1,8 +1,9 @@
import zipfile
from abc import abstractmethod, abstractproperty
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterator, Optional
from typing import Callable, Optional
from uuid import UUID
from pydantic import BaseModel
@@ -27,7 +28,7 @@ class ExportedItem:
class ABCExporter(BaseService):
write_dir_to_zip: Callable[[Path, str, Optional[list[str]]], None]
write_dir_to_zip: Callable[[Path, str, Optional[set[str]]], None] | None
def __init__(self, db: AllRepositories, group_id: UUID) -> None:
self.logger = get_logger()
@@ -47,8 +48,7 @@ class ABCExporter(BaseService):
def _post_export_hook(self, _: BaseModel) -> None:
pass
@abstractmethod
def export(self, zip: zipfile.ZipFile) -> list[ReportEntryCreate]:
def export(self, zip: zipfile.ZipFile) -> list[ReportEntryCreate]: # type: ignore
"""
Export takes in a zip file and exports the recipes to it. Note that the zip
file open/close is NOT handled by this method. You must handle it yourself.
@@ -57,7 +57,7 @@ class ABCExporter(BaseService):
zip (zipfile.ZipFile): Zip file destination
Returns:
list[ReportEntryCreate]: [description] ???!?!
list[ReportEntryCreate]:
"""
self.write_dir_to_zip = self.write_dir_to_zip_func(zip)

View File

@@ -1,4 +1,4 @@
from typing import Iterator
from collections.abc import Iterator
from uuid import UUID
from mealie.repos.all_repositories import AllRepositories
@@ -37,5 +37,5 @@ class RecipeExporter(ABCExporter):
"""Copy recipe directory contents into the zip folder"""
recipe_dir = item.directory
if recipe_dir.exists():
if recipe_dir.exists() and self.write_dir_to_zip:
self.write_dir_to_zip(recipe_dir, f"{self.destination_dir}/{item.slug}", {".json"})

View File

@@ -168,7 +168,7 @@ class ShoppingListService:
found = False
for ref in item.recipe_references:
remove_qty = 0
remove_qty = 0.0
if ref.recipe_id == recipe_id:
self.list_item_refs.delete(ref.id)
@@ -199,4 +199,4 @@ class ShoppingListService:
break
# Save Changes
return self.shopping_lists.get(shopping_list.id)
return self.shopping_lists.get_one(shopping_list.id)

View File

@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Tuple
from uuid import UUID
from pydantic import UUID4
@@ -94,9 +93,10 @@ class BaseMigrator(BaseService):
self._create_report(report_name)
self._migrate()
self._save_all_entries()
return self.db.group_reports.get(self.report_id)
def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[Tuple[str, UUID4, bool]]:
return self.db.group_reports.get_one(self.report_id)
def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]:
"""
Used as a single access point to process a list of Recipe objects into the
database in a predictable way. If an error occurs the session is rolled back

View File

@@ -67,6 +67,6 @@ class NextcloudMigrator(BaseMigrator):
for slug, recipe_id, status in all_statuses:
if status:
nc_dir: NextcloudDir = nextcloud_dirs[slug]
nc_dir = nextcloud_dirs[slug]
if nc_dir.image:
import_image(nc_dir.image, recipe_id)

View File

@@ -1,3 +1,4 @@
from collections.abc import Iterable
from typing import TypeVar
from pydantic import UUID4, BaseModel
@@ -14,14 +15,14 @@ T = TypeVar("T", bound=BaseModel)
class DatabaseMigrationHelpers:
def __init__(self, db: AllRepositories, session: Session, group_id: int, user_id: UUID4) -> None:
def __init__(self, db: AllRepositories, session: Session, group_id: UUID4, user_id: UUID4) -> None:
self.group_id = group_id
self.user_id = user_id
self.session = session
self.db = db
def _get_or_set_generic(
self, accessor: RepositoryGeneric, items: list[str], create_model: T, out_model: T
self, accessor: RepositoryGeneric, items: Iterable[str], create_model: type[T], out_model: type[T]
) -> list[T]:
"""
Utility model for getting or setting categories or tags. This will only work for those two cases.
@@ -47,7 +48,7 @@ class DatabaseMigrationHelpers:
items_out.append(item_model.dict())
return items_out
def get_or_set_category(self, categories: list[str]) -> list[RecipeCategory]:
def get_or_set_category(self, categories: Iterable[str]) -> list[RecipeCategory]:
return self._get_or_set_generic(
self.db.categories.by_group(self.group_id),
categories,
@@ -55,7 +56,7 @@ class DatabaseMigrationHelpers:
CategoryOut,
)
def get_or_set_tags(self, tags: list[str]) -> list[RecipeTag]:
def get_or_set_tags(self, tags: Iterable[str]) -> list[RecipeTag]:
return self._get_or_set_generic(
self.db.tags.by_group(self.group_id),
tags,

View File

@@ -1,4 +1,5 @@
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional
from pydantic import BaseModel

View File

@@ -10,10 +10,10 @@ def move_parens_to_end(ing_str) -> str:
If no parentheses are found, the string is returned unchanged.
"""
if re.match(compiled_match, ing_str):
match = re.search(compiled_search, ing_str)
start = match.start()
end = match.end()
ing_str = ing_str[:start] + ing_str[end:] + " " + ing_str[start:end]
if match := re.search(compiled_search, ing_str):
start = match.start()
end = match.end()
ing_str = ing_str[:start] + ing_str[end:] + " " + ing_str[start:end]
return ing_str

View File

@@ -1,6 +1,5 @@
import string
import unicodedata
from typing import Tuple
from pydantic import BaseModel
@@ -10,7 +9,7 @@ from .._helpers import check_char, move_parens_to_end
class BruteParsedIngredient(BaseModel):
food: str = ""
note: str = ""
amount: float = ""
amount: float = 1.0
unit: str = ""
class Config:
@@ -31,7 +30,7 @@ def parse_fraction(x):
raise ValueError
def parse_amount(ing_str) -> Tuple[float, str, str]:
def parse_amount(ing_str) -> tuple[float, str, str]:
def keep_looping(ing_str, end) -> bool:
"""
Checks if:
@@ -48,7 +47,9 @@ def parse_amount(ing_str) -> Tuple[float, str, str]:
if check_char(ing_str[end], ".", ",", "/") and end + 1 < len(ing_str) and ing_str[end + 1] in string.digits:
return True
amount = 0
return False
amount = 0.0
unit = ""
note = ""
@@ -87,7 +88,7 @@ def parse_amount(ing_str) -> Tuple[float, str, str]:
return amount, unit, note
def parse_ingredient_with_comma(tokens) -> Tuple[str, str]:
def parse_ingredient_with_comma(tokens) -> tuple[str, str]:
ingredient = ""
note = ""
start = 0
@@ -105,7 +106,7 @@ def parse_ingredient_with_comma(tokens) -> Tuple[str, str]:
return ingredient, note
def parse_ingredient(tokens) -> Tuple[str, str]:
def parse_ingredient(tokens) -> tuple[str, str]:
ingredient = ""
note = ""
if tokens[-1].endswith(")"):
@@ -132,7 +133,7 @@ def parse_ingredient(tokens) -> Tuple[str, str]:
def parse(ing_str) -> BruteParsedIngredient:
amount = 0
amount = 0.0
unit = ""
ingredient = ""
note = ""

View File

@@ -5,6 +5,8 @@ from pathlib import Path
from pydantic import BaseModel, validator
from mealie.schema._mealie.types import NoneFloat
from . import utils
from .pre_processor import pre_process_string
@@ -14,10 +16,10 @@ MODEL_PATH = CWD / "model.crfmodel"
class CRFConfidence(BaseModel):
average: float = 0.0
comment: float = None
name: float = None
unit: float = None
qty: float = None
comment: NoneFloat = None
name: NoneFloat = None
unit: NoneFloat = None
qty: NoneFloat = None
class CRFIngredient(BaseModel):

View File

@@ -99,7 +99,7 @@ class NLPParser(ABCIngredientParser):
return [self._crf_to_ingredient(crf_model) for crf_model in crf_models]
def parse_one(self, ingredient: str) -> ParsedIngredient:
items = self.parse_one([ingredient])
items = self.parse([ingredient])
return items[0]

View File

@@ -38,7 +38,7 @@ class RecipeDataService(BaseService):
except Exception as e:
self.logger.exception(f"Failed to delete recipe data: {e}")
def write_image(self, file_data: bytes, extension: str) -> Path:
def write_image(self, file_data: bytes | Path, extension: str) -> Path:
extension = extension.replace(".", "")
image_path = self.dir_image.joinpath(f"original.{extension}")
image_path.unlink(missing_ok=True)
@@ -91,8 +91,8 @@ class RecipeDataService(BaseService):
if ext not in img.IMAGE_EXTENSIONS:
ext = "jpg" # Guess the extension
filename = str(self.recipe_id) + "." + ext
filename = Recipe.directory_from_id(self.recipe_id).joinpath("images", filename)
file_name = f"{str(self.recipe_id)}.{ext}"
file_path = Recipe.directory_from_id(self.recipe_id).joinpath("images", file_name)
try:
r = requests.get(image_url, stream=True, headers={"User-Agent": _FIREFOX_UA})
@@ -102,7 +102,7 @@ class RecipeDataService(BaseService):
if r.status_code == 200:
r.raw.decode_content = True
self.logger.info(f"File Name Suffix {filename.suffix}")
self.write_image(r.raw, filename.suffix)
self.logger.info(f"File Name Suffix {file_path.suffix}")
self.write_image(r.raw, file_path.suffix)
filename.unlink(missing_ok=True)
file_path.unlink(missing_ok=True)

View File

@@ -69,7 +69,6 @@ class RecipeService(BaseService):
all_asset_files = [x.file_name for x in recipe.assets]
for file in recipe.asset_dir.iterdir():
file: Path
if file.is_dir():
continue
if file.name not in all_asset_files:
@@ -102,13 +101,13 @@ class RecipeService(BaseService):
def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe:
create_data: Recipe = self._recipe_creation_factory(
data: Recipe = self._recipe_creation_factory(
self.user,
name=create_data.name,
additional_attrs=create_data.dict(),
)
create_data.settings = RecipeSettings(
data.settings = RecipeSettings(
public=self.group.preferences.recipe_public,
show_nutrition=self.group.preferences.recipe_show_nutrition,
show_assets=self.group.preferences.recipe_show_assets,
@@ -117,7 +116,7 @@ class RecipeService(BaseService):
disable_amount=self.group.preferences.recipe_disable_amount,
)
return self.repos.recipes.create(create_data)
return self.repos.recipes.create(data)
def create_from_zip(self, archive: UploadFile, temp_path: Path) -> Recipe:
"""

View File

@@ -27,7 +27,7 @@ class TemplateService(BaseService):
super().__init__()
@property
def templates(self) -> list:
def templates(self) -> dict[str, list[str]]:
"""
Returns a list of all templates available to render.
"""
@@ -78,6 +78,8 @@ class TemplateService(BaseService):
if t_type == TemplateType.zip:
return self._render_zip(recipe)
raise ValueError(f"Template Type '{t_type}' not found.")
def _render_json(self, recipe: Recipe) -> Path:
"""
Renders a JSON file in a temporary directory and returns
@@ -98,18 +100,18 @@ class TemplateService(BaseService):
"""
self.__check_temp(self._render_jinja2)
j2_template: Path = self.directories.TEMPLATE_DIR / j2_template
j2_path: Path = self.directories.TEMPLATE_DIR / j2_template
if not j2_template.is_file():
raise FileNotFoundError(f"Template '{j2_template}' not found.")
if not j2_path.is_file():
raise FileNotFoundError(f"Template '{j2_path}' not found.")
with open(j2_template, "r") as f:
with open(j2_path, "r") as f:
template_text = f.read()
template = Template(template_text)
rendered_text = template.render(recipe=recipe.dict(by_alias=True))
save_name = f"{recipe.slug}{j2_template.suffix}"
save_name = f"{recipe.slug}{j2_path.suffix}"
save_path = self.temp.joinpath(save_name)

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, Tuple
from pydantic import BaseModel
@@ -17,7 +17,7 @@ class Cron:
@dataclass
class ScheduledFunc(BaseModel):
id: Tuple[str, int]
id: tuple[str, int]
name: str
hour: int
minutes: int

View File

@@ -1,4 +1,4 @@
from typing import Callable, Iterable
from collections.abc import Callable, Iterable
from mealie.core import root_logger

View File

@@ -49,30 +49,26 @@ class SchedulerService:
@staticmethod
def add_cron_job(job_func: ScheduledFunc):
SchedulerService.scheduler.add_job(
SchedulerService.scheduler.add_job( # type: ignore
job_func.callback,
trigger="cron",
name=job_func.id,
hour=job_func.hour,
minute=job_func.minutes,
max_instances=job_func.max_instances,
max_instances=job_func.max_instances, # type: ignore
replace_existing=job_func.replace_existing,
args=job_func.args,
)
# SchedulerService._job_store[job_func.id] = job_func
@staticmethod
def update_cron_job(job_func: ScheduledFunc):
SchedulerService.scheduler.reschedule_job(
SchedulerService.scheduler.reschedule_job( # type: ignore
job_func.id,
trigger="cron",
hour=job_func.hour,
minute=job_func.minutes,
)
# SchedulerService._job_store[job_func.id] = job_func
def _scheduled_task_wrapper(callable):
try:

View File

@@ -39,7 +39,8 @@ def purge_excess_files() -> None:
limit = datetime.datetime.now() - datetime.timedelta(minutes=ONE_DAY_AS_MINUTES * 2)
for file in directories.GROUPS_DIR.glob("**/export/*.zip"):
if file.stat().st_mtime < limit:
# TODO: fix comparison types
if file.stat().st_mtime < limit: # type: ignore
file.unlink()
logger.info(f"excess group file removed '{file}'")

View File

@@ -28,7 +28,7 @@ def post_webhooks(webhook_id: int, session: Session = None):
if not todays_recipe:
return
payload = json.loads([x.json(by_alias=True) for x in todays_recipe])
payload = json.loads([x.json(by_alias=True) for x in todays_recipe]) # type: ignore
response = requests.post(webhook.url, json=payload)
if response.status_code != 200:

View File

@@ -2,7 +2,7 @@ import html
import json
import re
from datetime import datetime, timedelta
from typing import List, Optional
from typing import Optional
from slugify import slugify
@@ -33,7 +33,7 @@ def clean(recipe_data: dict, url=None) -> dict:
recipe_data["recipeIngredient"] = ingredient(recipe_data.get("recipeIngredient"))
recipe_data["recipeInstructions"] = instructions(recipe_data.get("recipeInstructions"))
recipe_data["image"] = image(recipe_data.get("image"))
recipe_data["slug"] = slugify(recipe_data.get("name"))
recipe_data["slug"] = slugify(recipe_data.get("name")) # type: ignore
recipe_data["orgURL"] = url
return recipe_data
@@ -127,7 +127,7 @@ def image(image=None) -> str:
raise Exception(f"Unrecognised image URL format: {image}")
def instructions(instructions) -> List[dict]:
def instructions(instructions) -> list[dict]:
try:
instructions = json.loads(instructions)
except Exception:
@@ -162,7 +162,8 @@ def instructions(instructions) -> List[dict]:
sectionSteps = []
for step in instructions:
if step["@type"] == "HowToSection":
[sectionSteps.append(item) for item in step["itemListElement"]]
for sectionStep in step["itemListElement"]:
sectionSteps.append(sectionStep)
if len(sectionSteps) > 0:
return [{"text": _instruction(step["text"])} for step in sectionSteps if step["@type"] == "HowToStep"]
@@ -183,6 +184,8 @@ def instructions(instructions) -> List[dict]:
else:
raise Exception(f"Unrecognised instruction format: {instructions}")
return []
def _instruction(line) -> str:
if isinstance(line, dict):
@@ -199,7 +202,7 @@ def _instruction(line) -> str:
return clean_line
def ingredient(ingredients: list) -> str:
def ingredient(ingredients: list | None) -> list[str]:
if ingredients:
return [clean_string(ing) for ing in ingredients]
else:

View File

@@ -1,5 +1,3 @@
from typing import Type
from mealie.schema.recipe.recipe import Recipe
from .scraper_strategies import ABCScraperStrategy, RecipeScraperOpenGraph, RecipeScraperPackage
@@ -11,9 +9,9 @@ class RecipeScraper:
"""
# List of recipe scrapers. Note that order matters
scrapers: list[Type[ABCScraperStrategy]]
scrapers: list[type[ABCScraperStrategy]]
def __init__(self, scrapers: list[Type[ABCScraperStrategy]] = None) -> None:
def __init__(self, scrapers: list[type[ABCScraperStrategy]] = None) -> None:
if scrapers is None:
scrapers = [
RecipeScraperPackage,
@@ -27,8 +25,8 @@ class RecipeScraper:
Scrapes a recipe from the web.
"""
for scraper in self.scrapers:
scraper = scraper(url)
for scraper_type in self.scrapers:
scraper = scraper_type(url)
recipe = scraper.parse()
if recipe is not None:

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Tuple
from typing import Any, Callable
import extruct
import requests
@@ -26,7 +26,7 @@ class ABCScraperStrategy(ABC):
self.url = url
@abstractmethod
def parse(self, recipe_url: str) -> Recipe | None:
def parse(self) -> Recipe | None:
"""Parse a recipe from a web URL.
Args:
@@ -40,7 +40,7 @@ class ABCScraperStrategy(ABC):
class RecipeScraperPackage(ABCScraperStrategy):
def clean_scraper(self, scraped_data: SchemaScraperFactory.SchemaScraper, url: str) -> Recipe:
def try_get_default(func_call: Callable, get_attr: str, default: Any, clean_func=None):
def try_get_default(func_call: Callable | None, get_attr: str, default: Any, clean_func=None):
value = default
try:
value = func_call()
@@ -143,7 +143,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
def get_html(self) -> str:
return requests.get(self.url).text
def get_recipe_fields(self, html) -> dict:
def get_recipe_fields(self, html) -> dict | None:
"""
Get the recipe fields from the Open Graph data.
"""
@@ -151,7 +151,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
def og_field(properties: dict, field_name: str) -> str:
return next((val for name, val in properties if name == field_name), None)
def og_fields(properties: list[Tuple[str, str]], field_name: str) -> list[str]:
def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]:
return list({val for name, val in properties if name == field_name})
base_url = get_base_url(html, self.url)
@@ -159,7 +159,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
try:
properties = data["opengraph"][0]["properties"]
except Exception:
return
return None
return {
"name": og_field(properties, "og:title"),

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from random import getrandbits
from time import sleep
from typing import Any, Callable
from typing import Any
from fastapi import BackgroundTasks
from pydantic import UUID4

View File

@@ -16,13 +16,13 @@ class PasswordResetService(BaseService):
self.db = get_repositories(session)
super().__init__()
def generate_reset_token(self, email: str) -> SavePasswordResetToken:
def generate_reset_token(self, email: str) -> SavePasswordResetToken | None:
user = self.db.users.get_one(email, "email")
if user is None:
logger.error(f"failed to create password reset for {email=}: user doesn't exists")
# Do not raise exception here as we don't want to confirm to the client that the Email doens't exists
return
return None
# Create Reset Token
token = url_safe_token()

View File

@@ -66,7 +66,7 @@ class RegistrationService:
token_entry = self.repos.group_invite_tokens.get_one(registration.group_token)
if not token_entry:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = self.repos.groups.get(token_entry.group_id)
group = self.repos.groups.get_one(token_entry.group_id)
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"})