mirror of
				https://github.com/mealie-recipes/mealie.git
				synced 2025-10-31 10:13:32 -04:00 
			
		
		
		
	fix: group creation (#1126)
* fix: unify group creation - closes #1100 * tests: disable password hashing during testing * tests: fix email config tests
This commit is contained in:
		
							
								
								
									
										1
									
								
								mealie/core/security/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								mealie/core/security/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | from .security import * | ||||||
							
								
								
									
										43
									
								
								mealie/core/security/hasher.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								mealie/core/security/hasher.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | |||||||
|  | from functools import lru_cache | ||||||
|  | from typing import Protocol | ||||||
|  |  | ||||||
|  | from passlib.context import CryptContext | ||||||
|  |  | ||||||
|  | from mealie.core.config import get_app_settings | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Hasher(Protocol): | ||||||
|  |     def hash(self, password: str) -> str: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |     def verify(self, password: str, hashed: str) -> bool: | ||||||
|  |         ... | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class FakeHasher: | ||||||
|  |     def hash(self, password: str) -> str: | ||||||
|  |         return password | ||||||
|  |  | ||||||
|  |     def verify(self, password: str, hashed: str) -> bool: | ||||||
|  |         return password == hashed | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PasslibHasher: | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         self.ctx = CryptContext(schemes=["bcrypt"], deprecated="auto") | ||||||
|  |  | ||||||
|  |     def hash(self, password: str) -> str: | ||||||
|  |         return self.ctx.hash(password) | ||||||
|  |  | ||||||
|  |     def verify(self, password: str, hashed: str) -> bool: | ||||||
|  |         return self.ctx.verify(password, hashed) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @lru_cache(maxsize=1) | ||||||
|  | def get_hasher() -> Hasher: | ||||||
|  |     settings = get_app_settings() | ||||||
|  |  | ||||||
|  |     if settings.TESTING: | ||||||
|  |         return FakeHasher() | ||||||
|  |  | ||||||
|  |     return PasslibHasher() | ||||||
| @@ -1,16 +1,15 @@ | |||||||
| import secrets | import secrets | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta, timezone | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from jose import jwt | from jose import jwt | ||||||
| from passlib.context import CryptContext |  | ||||||
| 
 | 
 | ||||||
| from mealie.core.config import get_app_settings | from mealie.core.config import get_app_settings | ||||||
|  | from mealie.core.security.hasher import get_hasher | ||||||
| from mealie.repos.all_repositories import get_repositories | from mealie.repos.all_repositories import get_repositories | ||||||
| from mealie.repos.repository_factory import AllRepositories | from mealie.repos.repository_factory import AllRepositories | ||||||
| from mealie.schema.user import PrivateUser | from mealie.schema.user import PrivateUser | ||||||
| 
 | 
 | ||||||
| pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |  | ||||||
| ALGORITHM = "HS256" | ALGORITHM = "HS256" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @@ -20,7 +19,7 @@ def create_access_token(data: dict, expires_delta: timedelta = None) -> str: | |||||||
|     to_encode = data.copy() |     to_encode = data.copy() | ||||||
|     expires_delta = expires_delta or timedelta(hours=settings.TOKEN_TIME) |     expires_delta = expires_delta or timedelta(hours=settings.TOKEN_TIME) | ||||||
| 
 | 
 | ||||||
|     expire = datetime.utcnow() + expires_delta |     expire = datetime.now(timezone.utc) + expires_delta | ||||||
| 
 | 
 | ||||||
|     to_encode["exp"] = expire |     to_encode["exp"] = expire | ||||||
|     return jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM) |     return jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM) | ||||||
| @@ -31,7 +30,7 @@ def create_file_token(file_path: Path) -> str: | |||||||
|     return create_access_token(token_data, expires_delta=timedelta(minutes=30)) |     return create_access_token(token_data, expires_delta=timedelta(minutes=30)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_recipe_slug_token(file_path: str) -> str: | def create_recipe_slug_token(file_path: str | Path) -> str: | ||||||
|     token_data = {"slug": str(file_path)} |     token_data = {"slug": str(file_path)} | ||||||
|     return create_access_token(token_data, expires_delta=timedelta(minutes=30)) |     return create_access_token(token_data, expires_delta=timedelta(minutes=30)) | ||||||
| 
 | 
 | ||||||
| @@ -96,12 +95,12 @@ def authenticate_user(session, email: str, password: str) -> PrivateUser | bool: | |||||||
| 
 | 
 | ||||||
| def verify_password(plain_password: str, hashed_password: str) -> bool: | def verify_password(plain_password: str, hashed_password: str) -> bool: | ||||||
|     """Compares a plain string to a hashed password""" |     """Compares a plain string to a hashed password""" | ||||||
|     return pwd_context.verify(plain_password, hashed_password) |     return get_hasher().verify(plain_password, hashed_password) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def hash_password(password: str) -> str: | def hash_password(password: str) -> str: | ||||||
|     """Takes in a raw password and hashes it. Used prior to saving a new password to the database.""" |     """Takes in a raw password and hashes it. Used prior to saving a new password to the database.""" | ||||||
|     return pwd_context.hash(password) |     return get_hasher().hash(password) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def url_safe_token() -> str: | def url_safe_token() -> str: | ||||||
| @@ -16,7 +16,7 @@ from mealie.repos.repository_factory import AllRepositories | |||||||
| from mealie.repos.seed.init_users import default_user_init | from mealie.repos.seed.init_users import default_user_init | ||||||
| from mealie.repos.seed.seeders import IngredientFoodsSeeder, IngredientUnitsSeeder, MultiPurposeLabelSeeder | from mealie.repos.seed.seeders import IngredientFoodsSeeder, IngredientUnitsSeeder, MultiPurposeLabelSeeder | ||||||
| from mealie.schema.user.user import GroupBase | from mealie.schema.user.user import GroupBase | ||||||
| from mealie.services.group_services.group_utils import create_new_group | from mealie.services.group_services.group_service import GroupService | ||||||
|  |  | ||||||
| PROJECT_DIR = Path(__file__).parent.parent.parent | PROJECT_DIR = Path(__file__).parent.parent.parent | ||||||
|  |  | ||||||
| @@ -44,7 +44,8 @@ def default_group_init(db: AllRepositories): | |||||||
|     settings = get_app_settings() |     settings = get_app_settings() | ||||||
|  |  | ||||||
|     logger.info("Generating Default Group") |     logger.info("Generating Default Group") | ||||||
|     create_new_group(db, GroupBase(name=settings.DEFAULT_GROUP)) |  | ||||||
|  |     GroupService.create_group(db, GroupBase(name=settings.DEFAULT_GROUP)) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Adapted from https://alembic.sqlalchemy.org/en/latest/cookbook.html#test-current-database-revision-is-at-head-s | # Adapted from https://alembic.sqlalchemy.org/en/latest/cookbook.html#test-current-database-revision-is-at-head-s | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ from mealie.schema.mapper import mapper | |||||||
| from mealie.schema.query import GetAll | from mealie.schema.query import GetAll | ||||||
| from mealie.schema.response.responses import ErrorResponse | from mealie.schema.response.responses import ErrorResponse | ||||||
| from mealie.schema.user.user import GroupBase, GroupInDB | from mealie.schema.user.user import GroupBase, GroupInDB | ||||||
|  | from mealie.services.group_services.group_service import GroupService | ||||||
|  |  | ||||||
| from .._base import BaseAdminController, controller | from .._base import BaseAdminController, controller | ||||||
| from .._base.dependencies import SharedDependencies | from .._base.dependencies import SharedDependencies | ||||||
| @@ -44,7 +45,7 @@ class AdminUserManagementRoutes(BaseAdminController): | |||||||
|  |  | ||||||
|     @router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED) |     @router.post("", response_model=GroupInDB, status_code=status.HTTP_201_CREATED) | ||||||
|     def create_one(self, data: GroupBase): |     def create_one(self, data: GroupBase): | ||||||
|         return self.mixins.create_one(data) |         return GroupService.create_group(self.deps.repos, data) | ||||||
|  |  | ||||||
|     @router.get("/{item_id}", response_model=GroupInDB) |     @router.get("/{item_id}", response_model=GroupInDB) | ||||||
|     def get_one(self, item_id: UUID4): |     def get_one(self, item_id: UUID4): | ||||||
| @@ -69,7 +70,7 @@ class AdminUserManagementRoutes(BaseAdminController): | |||||||
|     def delete_one(self, item_id: UUID4): |     def delete_one(self, item_id: UUID4): | ||||||
|         item = self.repo.get_one(item_id) |         item = self.repo.get_one(item_id) | ||||||
|  |  | ||||||
|         if len(item.users) > 0: |         if item and len(item.users) > 0: | ||||||
|             raise HTTPException( |             raise HTTPException( | ||||||
|                 status_code=status.HTTP_400_BAD_REQUEST, |                 status_code=status.HTTP_400_BAD_REQUEST, | ||||||
|                 detail=ErrorResponse.respond(message="Cannot delete group with users"), |                 detail=ErrorResponse.respond(message="Cannot delete group with users"), | ||||||
|   | |||||||
| @@ -2,7 +2,9 @@ from pydantic import UUID4 | |||||||
|  |  | ||||||
| from mealie.pkgs.stats import fs_stats | from mealie.pkgs.stats import fs_stats | ||||||
| from mealie.repos.repository_factory import AllRepositories | from mealie.repos.repository_factory import AllRepositories | ||||||
|  | from mealie.schema.group.group_preferences import CreateGroupPreferences | ||||||
| from mealie.schema.group.group_statistics import GroupStatistics, GroupStorage | from mealie.schema.group.group_statistics import GroupStatistics, GroupStorage | ||||||
|  | from mealie.schema.user.user import GroupBase | ||||||
| from mealie.services._base_service import BaseService | from mealie.services._base_service import BaseService | ||||||
|  |  | ||||||
| ALLOWED_SIZE = 500 * fs_stats.megabyte | ALLOWED_SIZE = 500 * fs_stats.megabyte | ||||||
| @@ -14,6 +16,23 @@ class GroupService(BaseService): | |||||||
|         self.repos = repos |         self.repos = repos | ||||||
|         super().__init__() |         super().__init__() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def create_group(repos: AllRepositories, g_base: GroupBase, prefs: CreateGroupPreferences | None = None): | ||||||
|  |         """ | ||||||
|  |         Creates a new group in the database with the required associated table references to ensure | ||||||
|  |         the group includes required preferences. | ||||||
|  |         """ | ||||||
|  |         new_group = repos.groups.create(g_base) | ||||||
|  |  | ||||||
|  |         if prefs is None: | ||||||
|  |             prefs = CreateGroupPreferences(group_id=new_group.id) | ||||||
|  |         else: | ||||||
|  |             prefs.group_id = new_group.id | ||||||
|  |  | ||||||
|  |         repos.group_preferences.create(prefs) | ||||||
|  |  | ||||||
|  |         return new_group | ||||||
|  |  | ||||||
|     def calculate_statistics(self, group_id: None | UUID4 = None) -> GroupStatistics: |     def calculate_statistics(self, group_id: None | UUID4 = None) -> GroupStatistics: | ||||||
|         """ |         """ | ||||||
|         calculate_statistics calculates the statistics for the group and returns |         calculate_statistics calculates the statistics for the group and returns | ||||||
|   | |||||||
| @@ -1,18 +0,0 @@ | |||||||
| from uuid import uuid4 |  | ||||||
|  |  | ||||||
| from mealie.repos.repository_factory import AllRepositories |  | ||||||
| from mealie.schema.group.group_preferences import CreateGroupPreferences |  | ||||||
| from mealie.schema.user.user import GroupBase, GroupInDB |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_new_group(db: AllRepositories, g_base: GroupBase, g_preferences: CreateGroupPreferences = None) -> GroupInDB: |  | ||||||
|     created_group = db.groups.create(g_base) |  | ||||||
|  |  | ||||||
|     # Assign Temporary ID before group is created |  | ||||||
|     g_preferences = g_preferences or CreateGroupPreferences(group_id=uuid4()) |  | ||||||
|  |  | ||||||
|     g_preferences.group_id = created_group.id |  | ||||||
|  |  | ||||||
|     db.group_preferences.create(g_preferences) |  | ||||||
|  |  | ||||||
|     return created_group |  | ||||||
| @@ -8,7 +8,7 @@ from mealie.repos.repository_factory import AllRepositories | |||||||
| from mealie.schema.group.group_preferences import CreateGroupPreferences | from mealie.schema.group.group_preferences import CreateGroupPreferences | ||||||
| from mealie.schema.user.registration import CreateUserRegistration | from mealie.schema.user.registration import CreateUserRegistration | ||||||
| from mealie.schema.user.user import GroupBase, GroupInDB, PrivateUser, UserIn | from mealie.schema.user.user import GroupBase, GroupInDB, PrivateUser, UserIn | ||||||
| from mealie.services.group_services.group_utils import create_new_group | from mealie.services.group_services.group_service import GroupService | ||||||
|  |  | ||||||
|  |  | ||||||
| class RegistrationService: | class RegistrationService: | ||||||
| @@ -19,7 +19,7 @@ class RegistrationService: | |||||||
|         self.logger = logger |         self.logger = logger | ||||||
|         self.repos = db |         self.repos = db | ||||||
|  |  | ||||||
|     def _create_new_user(self, group: GroupInDB, new_group=bool) -> PrivateUser: |     def _create_new_user(self, group: GroupInDB, new_group: bool) -> PrivateUser: | ||||||
|         new_user = UserIn( |         new_user = UserIn( | ||||||
|             email=self.registration.email, |             email=self.registration.email, | ||||||
|             username=self.registration.username, |             username=self.registration.username, | ||||||
| @@ -49,7 +49,7 @@ class RegistrationService: | |||||||
|             recipe_disable_amount=self.registration.advanced, |             recipe_disable_amount=self.registration.advanced, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         return create_new_group(self.repos, group_data, group_preferences) |         return GroupService.create_group(self.repos, group_data, group_preferences) | ||||||
|  |  | ||||||
|     def register_user(self, registration: CreateUserRegistration) -> PrivateUser: |     def register_user(self, registration: CreateUserRegistration) -> PrivateUser: | ||||||
|         self.registration = registration |         self.registration = registration | ||||||
|   | |||||||
							
								
								
									
										22
									
								
								tests/unit_tests/core/test_security.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								tests/unit_tests/core/test_security.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | from pytest import MonkeyPatch | ||||||
|  |  | ||||||
|  | from mealie.core.config import get_app_settings | ||||||
|  | from mealie.core.security.hasher import FakeHasher, PasslibHasher, get_hasher | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_get_hasher(monkeypatch: MonkeyPatch): | ||||||
|  |     hasher = get_hasher() | ||||||
|  |  | ||||||
|  |     assert isinstance(hasher, FakeHasher) | ||||||
|  |  | ||||||
|  |     monkeypatch.setenv("TESTING", "0") | ||||||
|  |  | ||||||
|  |     get_hasher.cache_clear() | ||||||
|  |     get_app_settings.cache_clear() | ||||||
|  |  | ||||||
|  |     hasher = get_hasher() | ||||||
|  |  | ||||||
|  |     assert isinstance(hasher, PasslibHasher) | ||||||
|  |  | ||||||
|  |     get_app_settings.cache_clear() | ||||||
|  |     get_hasher.cache_clear() | ||||||
| @@ -44,8 +44,11 @@ def email_service(monkeypatch) -> EmailService: | |||||||
|     return email_service |     return email_service | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_email_disabled(): | def test_email_disabled(monkeypatch): | ||||||
|     email_service = EmailService(TestEmailSender()) |     email_service = EmailService(TestEmailSender()) | ||||||
|  |  | ||||||
|  |     monkeypatch.setenv("SMTP_HOST", "")  # disable email | ||||||
|  |  | ||||||
|     get_app_settings.cache_clear() |     get_app_settings.cache_clear() | ||||||
|     email_service.settings = get_app_settings() |     email_service.settings = get_app_settings() | ||||||
|     success = email_service.send_test_email(FAKE_ADDRESS) |     success = email_service.send_test_email(FAKE_ADDRESS) | ||||||
|   | |||||||
| @@ -60,8 +60,17 @@ def test_pg_connection_args(monkeypatch): | |||||||
|  |  | ||||||
|  |  | ||||||
| def test_smtp_enable(monkeypatch): | def test_smtp_enable(monkeypatch): | ||||||
|  |     monkeypatch.setenv("SMTP_HOST", "") | ||||||
|  |     monkeypatch.setenv("SMTP_PORT", "") | ||||||
|  |     monkeypatch.setenv("SMTP_TLS", "true") | ||||||
|  |     monkeypatch.setenv("SMTP_FROM_NAME", "") | ||||||
|  |     monkeypatch.setenv("SMTP_FROM_EMAIL", "") | ||||||
|  |     monkeypatch.setenv("SMTP_USER", "") | ||||||
|  |     monkeypatch.setenv("SMTP_PASSWORD", "") | ||||||
|  |  | ||||||
|     get_app_settings.cache_clear() |     get_app_settings.cache_clear() | ||||||
|     app_settings = get_app_settings() |     app_settings = get_app_settings() | ||||||
|  |  | ||||||
|     assert app_settings.SMTP_ENABLE is False |     assert app_settings.SMTP_ENABLE is False | ||||||
|  |  | ||||||
|     monkeypatch.setenv("SMTP_HOST", "email.mealie.io") |     monkeypatch.setenv("SMTP_HOST", "email.mealie.io") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user