fix: strict optional errors (#1759)

* fix strict optional errors

* fix typing in repository

* fix backup db files location

* update workspace settings
This commit is contained in:
Hayden
2022-10-23 13:04:04 -08:00
committed by GitHub
parent 97d9e2a109
commit 84c23765cd
31 changed files with 253 additions and 139 deletions

View File

@@ -21,9 +21,9 @@ class AlchemyExporter(BaseService):
look_for_time = {"scheduled_time"}
class DateTimeParser(BaseModel):
date: datetime.date = None
dt: datetime.datetime = None
time: datetime.time = None
date: datetime.date | None = None
dt: datetime.datetime | None = None
time: datetime.time | None = None
def __init__(self, connection_str: str) -> None:
super().__init__()

View File

@@ -5,7 +5,7 @@ from pathlib import Path
class BackupContents:
_tables: dict = None
_tables: dict | None = None
def __init__(self, file: Path) -> None:
self.base = file

View File

@@ -17,15 +17,17 @@ class BackupV2(BaseService):
def __init__(self, db_url: str = None) -> None:
super().__init__()
self.db_url = db_url or self.settings.DB_URL
# type - one of these has to be a string
self.db_url: str = db_url or self.settings.DB_URL # type: ignore
self.db_exporter = AlchemyExporter(self.db_url)
def _sqlite(self) -> None:
db_file = self.settings.DB_URL.removeprefix("sqlite:///")
db_file = self.settings.DB_URL.removeprefix("sqlite:///") # type: ignore
# Create a backup of the SQLite database
timestamp = datetime.datetime.now().strftime("%Y.%m.%d")
shutil.copy(db_file, f"mealie_{timestamp}.bak.db")
shutil.copy(db_file, self.directories.DATA_DIR.joinpath(f"mealie_{timestamp}.bak.db"))
def _postgres(self) -> None:
pass

View File

@@ -11,8 +11,8 @@ from mealie.services._base_service import BaseService
class EmailOptions:
host: str
port: int
username: str = None
password: str = None
username: str | None = None
password: str | None = None
tls: bool = False
ssl: bool = False
@@ -39,7 +39,9 @@ class Message:
if smtp.ssl:
with smtplib.SMTP_SSL(smtp.host, smtp.port) as server:
server.login(smtp.username, smtp.password)
if smtp.username and smtp.password:
server.login(smtp.username, smtp.password)
errors = server.send_message(msg)
else:
with smtplib.SMTP(smtp.host, smtp.port) as server:
@@ -66,17 +68,24 @@ class DefaultEmailSender(ABCEmailSender, BaseService):
"""
def send(self, email_to: str, subject: str, html: str) -> bool:
if self.settings.SMTP_FROM_EMAIL is None or self.settings.SMTP_FROM_NAME is None:
raise ValueError("SMTP_FROM_EMAIL and SMTP_FROM_NAME must be set in the config file.")
message = Message(
subject=subject,
html=html,
mail_from=(self.settings.SMTP_FROM_NAME, self.settings.SMTP_FROM_EMAIL),
)
if self.settings.SMTP_HOST is None or self.settings.SMTP_PORT is None:
raise ValueError("SMTP_HOST, SMTP_PORT must be set in the config file.")
smtp_options = EmailOptions(
self.settings.SMTP_HOST,
int(self.settings.SMTP_PORT),
tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS",
ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL",
tls=self.settings.SMTP_AUTH_STRATEGY.upper() == "TLS" if self.settings.SMTP_AUTH_STRATEGY else False,
ssl=self.settings.SMTP_AUTH_STRATEGY.upper() == "SSL" if self.settings.SMTP_AUTH_STRATEGY else False,
)
if self.settings.SMTP_USER:

View File

@@ -41,7 +41,7 @@ class EventListenerBase(ABC):
...
@contextlib.contextmanager
def ensure_session(self) -> Generator[None, None, None]:
def ensure_session(self) -> Generator[Session, None, None]:
"""
ensure_session ensures that a session is available for the caller by checking if a session
was provided during construction, and if not, creating a new session with the `with_session`
@@ -54,10 +54,9 @@ class EventListenerBase(ABC):
if self.session is None:
with session_context() as session:
self.session = session
yield
yield self.session
else:
yield
yield self.session
class AppriseEventListener(EventListenerBase):
@@ -87,7 +86,7 @@ class AppriseEventListener(EventListenerBase):
"integration_id": event.integration_id,
"document_data": json.dumps(jsonable_encoder(event.document_data)),
"event_id": str(event.event_id),
"timestamp": event.timestamp.isoformat(),
"timestamp": event.timestamp.isoformat() if event.timestamp else None,
}
return [
@@ -148,9 +147,9 @@ class WebhookEventListener(EventListenerBase):
def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]:
"""Fetches all scheduled webhooks from the database"""
with self.ensure_session():
with self.ensure_session() as session:
return (
self.session.query(GroupWebhooksModel)
session.query(GroupWebhooksModel)
.where(
GroupWebhooksModel.enabled == True, # noqa: E712 - required for SQLAlchemy comparison
GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(),

View File

@@ -1,29 +0,0 @@
import random
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.recipe.recipe import Recipe, RecipeCategory
from mealie.services._base_service import BaseService
class MealPlanService(BaseService):
def __init__(self, group_id: UUID4, repos: AllRepositories):
self.group_id = group_id
self.repos = repos
def get_random_recipe(self, categories: list[RecipeCategory] = None) -> Recipe:
"""get_random_recipe returns a single recipe matching a specific criteria of
categories. if no categories are provided, a single recipe is returned from the
entire recipe database.
Note that the recipe must contain ALL categories in the list provided.
Args:
categories (list[RecipeCategory], optional): [description]. Defaults to None.
Returns:
Recipe: [description]
"""
recipes = self.repos.recipes.by_group(self.group_id).get_by_categories(categories)
return random.choice(recipes)

View File

@@ -1,5 +1,6 @@
from pydantic import UUID4
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group import ShoppingListItemCreate, ShoppingListOut
from mealie.schema.group.group_shopping_list import (
@@ -120,8 +121,10 @@ class ShoppingListService:
- deleted_shopping_list_items
"""
recipe = self.repos.recipes.get_one(recipe_id, "id")
to_create = []
if not recipe:
raise UnexpectedNone("Recipe not found")
to_create = []
for ingredient in recipe.recipe_ingredient:
food_id = None
try:
@@ -144,7 +147,7 @@ class ShoppingListService:
to_create.append(
ShoppingListItemCreate(
shopping_list_id=list_id,
is_food=not recipe.settings.disable_amount,
is_food=not recipe.settings.disable_amount if recipe.settings else False,
food_id=food_id,
unit_id=unit_id,
quantity=ingredient.quantity,
@@ -163,6 +166,9 @@ class ShoppingListService:
new_shopping_list_items = [self.repos.group_shopping_list_item.create(item) for item in to_create]
updated_shopping_list = self.shopping_lists.get_one(list_id)
if not updated_shopping_list:
raise UnexpectedNone("Shopping List not found")
updated_shopping_list_items, deleted_shopping_list_items = self.consolidate_and_save(updated_shopping_list.list_items) # type: ignore
updated_shopping_list.list_items = updated_shopping_list_items
@@ -219,13 +225,16 @@ class ShoppingListService:
"""
shopping_list = self.shopping_lists.get_one(list_id)
if shopping_list is None:
raise UnexpectedNone("Shopping list not found, cannot remove recipe ingredients")
updated_shopping_list_items = []
deleted_shopping_list_items = []
for item in shopping_list.list_items:
found = False
for ref in item.recipe_references:
remove_qty = 0.0
remove_qty: None | float = 0.0
if ref.recipe_id == recipe_id:
self.list_item_refs.delete(ref.id) # type: ignore
@@ -236,7 +245,9 @@ class ShoppingListService:
# If the item was found decrement the quantity by the remove_qty
if found:
item.quantity = item.quantity - remove_qty
if remove_qty is not None:
item.quantity = item.quantity - remove_qty
if item.quantity <= 0:
self.list_items.delete(item.id)
@@ -246,16 +257,16 @@ class ShoppingListService:
updated_shopping_list_items.append(item)
# Decrement the list recipe reference count
for ref in shopping_list.recipe_references: # type: ignore
if ref.recipe_id == recipe_id:
ref.recipe_quantity -= 1
for recipe_ref in shopping_list.recipe_references:
if recipe_ref.recipe_id == recipe_id and recipe_ref.recipe_quantity is not None:
recipe_ref.recipe_quantity -= 1.0
if ref.recipe_quantity <= 0:
self.list_refs.delete(ref.id) # type: ignore
if recipe_ref.recipe_quantity <= 0.0:
self.list_refs.delete(recipe_ref.id)
else:
self.list_refs.update(ref.id, ref) # type: ignore
self.list_refs.update(recipe_ref.id, ref)
break
# Save Changes
return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items
return self.shopping_lists.get_one(shopping_list.id), updated_shopping_list_items, deleted_shopping_list_items # type: ignore

View File

@@ -96,7 +96,12 @@ class BaseMigrator(BaseService):
self._migrate()
self._save_all_entries()
return self.db.group_reports.get_one(self.report_id)
result = self.db.group_reports.get_one(self.report_id)
if not result:
raise ValueError("Report not found")
return result
def import_recipes_to_database(self, validated_recipes: list[Recipe]) -> list[tuple[str, UUID4, bool]]:
"""
@@ -111,10 +116,13 @@ class BaseMigrator(BaseService):
if self.add_migration_tag:
migration_tag = self.helpers.get_or_set_tags([self.name])[0]
return_vars = []
return_vars: list[tuple[str, UUID4, bool]] = []
group = self.db.groups.get_one(self.group_id)
if not group or not group.preferences:
raise ValueError("Group preferences not found")
default_settings = RecipeSettings(
public=group.preferences.recipe_public,
show_nutrition=group.preferences.recipe_show_nutrition,
@@ -132,6 +140,8 @@ class BaseMigrator(BaseService):
if recipe.tags:
recipe.tags = self.helpers.get_or_set_tags(x.name for x in recipe.tags)
else:
recipe.tags = []
if recipe.recipe_category:
recipe.recipe_category = self.helpers.get_or_set_category(x.name for x in recipe.recipe_category)
@@ -155,7 +165,7 @@ class BaseMigrator(BaseService):
else:
message = f"Failed to import {recipe.name}"
return_vars.append((recipe.slug, recipe.id, status))
return_vars.append((recipe.slug, recipe.id, status)) # type: ignore
self.report_entries.append(
ReportEntryCreate(

View File

@@ -42,8 +42,13 @@ class ChowdownMigrator(BaseMigrator):
for slug, recipe_id, status in results:
if status:
try:
original_image = recipe_lookup.get(slug).image
cd_image = image_dir.joinpath(original_image)
r = recipe_lookup.get(slug)
if not r:
continue
if r.image:
cd_image = image_dir.joinpath(r.image)
except StopIteration:
continue
if cd_image:

View File

@@ -1,5 +1,6 @@
from pathlib import Path
from mealie.core.exceptions import UnexpectedNone
from mealie.repos.repository_factory import AllRepositories
from mealie.schema.group.group_exports import GroupDataExport
from mealie.schema.recipe import CategoryBase
@@ -41,6 +42,9 @@ class RecipeBulkActionsService(BaseService):
group = self.repos.groups.get_one(self.group.id)
if group is None:
raise UnexpectedNone("Failed to purge exports for group, no group found")
for match in group.directory.glob("**/export/*zip"):
if match.is_file():
match.unlink()
@@ -52,8 +56,8 @@ class RecipeBulkActionsService(BaseService):
for slug in recipes:
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to set settings for recipe {slug}, no recipe found")
if recipe is None or recipe.settings is None:
raise UnexpectedNone(f"Failed to set settings for recipe {slug}, no recipe found")
settings.locked = recipe.settings.locked
recipe.settings = settings
@@ -69,9 +73,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to tag recipe {slug}, no recipe found")
raise UnexpectedNone(f"Failed to tag recipe {slug}, no recipe found")
recipe.tags += tags
if recipe.tags is None:
recipe.tags = []
recipe.tags += tags # type: ignore
try:
self.repos.recipes.update(slug, recipe)
@@ -84,9 +91,12 @@ class RecipeBulkActionsService(BaseService):
recipe = self.repos.recipes.get_one(slug)
if recipe is None:
self.logger.error(f"Failed to categorize recipe {slug}, no recipe found")
raise UnexpectedNone(f"Failed to categorize recipe {slug}, no recipe found")
recipe.recipe_category += categories
if recipe.recipe_category is None:
recipe.recipe_category = []
recipe.recipe_category += categories # type: ignore
try:
self.repos.recipes.update(slug, recipe)

View File

@@ -50,6 +50,8 @@ class RecipeService(BaseService):
return recipe
def can_update(self, recipe: Recipe) -> bool:
if recipe.settings is None:
raise exceptions.UnexpectedNone("Recipe Settings is None")
return recipe.settings.locked is False or self.user.id == recipe.user_id
def can_lock_unlock(self, recipe: Recipe) -> bool:
@@ -66,6 +68,9 @@ class RecipeService(BaseService):
except FileNotFoundError:
self.logger.error(f"Recipe Directory not Found: {original_slug}")
if recipe.assets is None:
recipe.assets = []
all_asset_files = [x.file_name for x in recipe.assets]
for file in recipe.asset_dir.iterdir():
@@ -92,7 +97,7 @@ class RecipeService(BaseService):
additional_attrs["group_id"] = user.group_id
if additional_attrs.get("tags"):
for i in range(len(additional_attrs.get("tags"))):
for i in range(len(additional_attrs.get("tags", []))):
additional_attrs["tags"][i]["group_id"] = user.group_id
if not additional_attrs.get("recipe_ingredient"):
@@ -105,6 +110,9 @@ class RecipeService(BaseService):
def create_one(self, create_data: Union[Recipe, CreateRecipe]) -> Recipe:
if create_data.name is None:
create_data.name = "New Recipe"
data: Recipe = self._recipe_creation_factory(
self.user,
name=create_data.name,
@@ -134,8 +142,8 @@ class RecipeService(BaseService):
with temp_path.open("wb") as buffer:
shutil.copyfileobj(archive.file, buffer)
recipe_dict = None
recipe_image = None
recipe_dict: dict | None = None
recipe_image: bytes | None = None
with ZipFile(temp_path) as myzip:
for file in myzip.namelist():
@@ -146,10 +154,15 @@ class RecipeService(BaseService):
with myzip.open(file) as myfile:
recipe_image = myfile.read()
if recipe_dict is None:
raise exceptions.UnexpectedNone("No json data found in Zip")
recipe = self.create_one(Recipe(**recipe_dict))
if recipe:
if recipe and recipe.id:
data_service = RecipeDataService(recipe.id)
if recipe_image:
data_service.write_image(recipe_image, "webp")
return recipe
@@ -172,6 +185,10 @@ class RecipeService(BaseService):
"""
recipe = self._get_recipe(slug)
if recipe is None or recipe.settings is None:
raise exceptions.NoEntryFound("Recipe not found.")
if not self.can_update(recipe):
raise exceptions.PermissionDenied("You do not have permission to edit this recipe.")
@@ -189,9 +206,12 @@ class RecipeService(BaseService):
return new_data
def patch_one(self, slug: str, patch_data: Recipe) -> Recipe:
recipe = self._pre_update_check(slug, patch_data)
recipe: Recipe | None = self._pre_update_check(slug, patch_data)
recipe = self.repos.recipes.by_group(self.group.id).get_one(slug)
if recipe is None:
raise exceptions.NoEntryFound("Recipe not found.")
new_data = self.repos.recipes.by_group(self.group.id).patch(recipe.slug, patch_data.dict(exclude_unset=True))
self.check_assets(new_data, recipe.slug)
@@ -210,6 +230,6 @@ class RecipeService(BaseService):
# =================================================================
# Recipe Template Methods
def render_template(self, recipe: Recipe, temp_dir: Path, template: str = None) -> Path:
def render_template(self, recipe: Recipe, temp_dir: Path, template: str) -> Path:
t_service = TemplateService(temp_dir)
return t_service.render(recipe, template)

View File

@@ -16,7 +16,7 @@ class TemplateType(str, enum.Enum):
class TemplateService(BaseService):
def __init__(self, temp: Path = None) -> None:
def __init__(self, temp: Path | None = None) -> None:
"""Creates a template service that can be used for multiple template generations
A temporary directory must be provided as a place holder for where to render all templates
Args:
@@ -58,7 +58,7 @@ class TemplateService(BaseService):
return TemplateType(t_type)
def render(self, recipe: Recipe, template: str = None) -> Path:
def render(self, recipe: Recipe, template: str) -> Path:
"""
Renders a TemplateType in a temporary directory and returns the path to the file.
@@ -87,6 +87,9 @@ class TemplateService(BaseService):
"""
self.__check_temp(self._render_json)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_json")
save_path = self.temp.joinpath(f"{recipe.slug}.json")
with open(save_path, "w") as f:
f.write(recipe.json(indent=4, by_alias=True))
@@ -100,6 +103,9 @@ class TemplateService(BaseService):
"""
self.__check_temp(self._render_jinja2)
if j2_template is None:
raise ValueError("Template must be provided for method _render_jinja2")
j2_path: Path = self.directories.TEMPLATE_DIR / j2_template
if not j2_path.is_file():
@@ -113,6 +119,9 @@ class TemplateService(BaseService):
save_name = f"{recipe.slug}{j2_path.suffix}"
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_jinja2")
save_path = self.temp.joinpath(save_name)
with open(save_path, "w") as f:
@@ -124,6 +133,10 @@ class TemplateService(BaseService):
self.__check_temp(self._render_jinja2)
image_asset = recipe.image_dir.joinpath(RecipeImageTypes.original.value)
if self.temp is None:
raise ValueError("Temporary directory must be provided for method _render_zip")
zip_temp = self.temp.joinpath(f"{recipe.slug}.zip")
with ZipFile(zip_temp, "w") as myzip:

View File

@@ -19,7 +19,7 @@ class ParserErrors(str, Enum):
CONNECTION_ERROR = "CONNECTION_ERROR"
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras | None]:
"""Main entry point for generating a recipe from a URL. Pass in a URL and
a Recipe object will be returned if successful.
@@ -43,6 +43,10 @@ def create_from_url(url: str) -> tuple[Recipe, ScrapedExtras]:
try:
recipe_data_service.scrape_image(new_recipe.image)
if new_recipe.name is None:
new_recipe.name = "Untitled"
new_recipe.slug = slugify(new_recipe.name)
new_recipe.image = cache.new_key(4)
except Exception as e:

View File

@@ -176,7 +176,7 @@ class RecipeScraperPackage(ABCScraperStrategy):
ingredients = []
try:
instruct = scraped_schema.instructions()
instruct: list | str = scraped_schema.instructions()
except Exception:
instruct = []
@@ -212,7 +212,7 @@ class RecipeScraperOpenGraph(ABCScraperStrategy):
"""
def og_field(properties: dict, field_name: str) -> str:
return next((val for name, val in properties if name == field_name), None)
return next((val for name, val in properties if name == field_name), "")
def og_fields(properties: list[tuple[str, str]], field_name: str) -> list[str]:
return list({val for name, val in properties if name == field_name})

View File

@@ -31,6 +31,9 @@ class PasswordResetService(BaseService):
def send_reset_email(self, email: str):
token_entry = self.generate_reset_token(email)
if token_entry is None:
return None
# Send Email
email_servive = EmailService()
reset_url = f"{self.settings.BASE_URL}/reset-password?token={token_entry.token}"

View File

@@ -35,7 +35,8 @@ class RegistrationService:
can_organize=new_group,
)
return self.repos.users.create(new_user)
# TODO: problem with repository type, not type here
return self.repos.users.create(new_user) # type: ignore
def _register_new_group(self) -> GroupInDB:
group_data = GroupBase(name=self.registration.group)
@@ -74,7 +75,13 @@ class RegistrationService:
token_entry = self.repos.group_invite_tokens.get_one(registration.group_token)
if not token_entry:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = self.repos.groups.get_one(token_entry.group_id)
maybe_none_group = self.repos.groups.get_one(token_entry.group_id)
if maybe_none_group is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Invalid group token"})
group = maybe_none_group
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, {"message": "Missing group"})