mirror of
				https://github.com/mealie-recipes/mealie.git
				synced 2025-10-31 02:03:35 -04:00 
			
		
		
		
	fix: unclosed sessions (#1734)
* resolve session leak * cleanup session management functions
This commit is contained in:
		| @@ -31,22 +31,31 @@ SessionLocal, engine = sql_global_init(settings.DB_URL)  # type: ignore | ||||
|  | ||||
|  | ||||
| @contextmanager | ||||
| def with_session() -> Session: | ||||
| def session_context() -> Session: | ||||
|     """ | ||||
|     session_context() provides a managed session to the database that is automatically | ||||
|     closed when the context is exited. This is the preferred method of accessing the | ||||
|     database. | ||||
|  | ||||
|     Note: use `generate_session` when using the `Depends` function from FastAPI | ||||
|     """ | ||||
|     global SessionLocal | ||||
|     sess = SessionLocal() | ||||
|  | ||||
|     try: | ||||
|         yield sess | ||||
|     finally: | ||||
|         sess.close() | ||||
|  | ||||
|  | ||||
| def create_session() -> Session: | ||||
|     global SessionLocal | ||||
|     return SessionLocal() | ||||
|  | ||||
|  | ||||
| def generate_session() -> Generator[Session, None, None]: | ||||
|     """ | ||||
|     WARNING: This function should _only_ be called when used with | ||||
|     using the `Depends` function from FastAPI. This function will leak | ||||
|     sessions if used outside of the context of a request. | ||||
|  | ||||
|     Use `with_session` instead. That function will allow you to use the | ||||
|     session within a context manager | ||||
|     """ | ||||
|     global SessionLocal | ||||
|     db = SessionLocal() | ||||
|     try: | ||||
|   | ||||
| @@ -9,7 +9,7 @@ from alembic.config import Config | ||||
| from alembic.runtime import migration | ||||
| from mealie.core import root_logger | ||||
| from mealie.core.config import get_app_settings | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.db.fixes.fix_slug_foods import fix_slug_food_names | ||||
| from mealie.repos.all_repositories import get_repositories | ||||
| from mealie.repos.repository_factory import AllRepositories | ||||
| @@ -67,41 +67,40 @@ def connect(session: orm.Session) -> bool: | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     session = create_session() | ||||
|  | ||||
|     # Wait for database to connect | ||||
|     max_retry = 10 | ||||
|     wait_seconds = 1 | ||||
|  | ||||
|     while True: | ||||
|         if connect(session): | ||||
|             logger.info("Database connection established.") | ||||
|             break | ||||
|     with session_context() as session: | ||||
|         while True: | ||||
|             if connect(session): | ||||
|                 logger.info("Database connection established.") | ||||
|                 break | ||||
|  | ||||
|         logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...") | ||||
|         max_retry -= 1 | ||||
|             logger.error(f"Database connection failed. Retrying in {wait_seconds} seconds...") | ||||
|             max_retry -= 1 | ||||
|  | ||||
|         sleep(wait_seconds) | ||||
|             sleep(wait_seconds) | ||||
|  | ||||
|         if max_retry == 0: | ||||
|             raise ConnectionError("Database connection failed - exiting application.") | ||||
|             if max_retry == 0: | ||||
|                 raise ConnectionError("Database connection failed - exiting application.") | ||||
|  | ||||
|     alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini")) | ||||
|     if db_is_at_head(alembic_cfg): | ||||
|         logger.info("Migration not needed.") | ||||
|     else: | ||||
|         logger.info("Migration needed. Performing migration...") | ||||
|         command.upgrade(alembic_cfg, "head") | ||||
|         alembic_cfg = Config(str(PROJECT_DIR / "alembic.ini")) | ||||
|         if db_is_at_head(alembic_cfg): | ||||
|             logger.info("Migration not needed.") | ||||
|         else: | ||||
|             logger.info("Migration needed. Performing migration...") | ||||
|             command.upgrade(alembic_cfg, "head") | ||||
|  | ||||
|     db = get_repositories(session) | ||||
|         db = get_repositories(session) | ||||
|  | ||||
|     if db.users.get_all(): | ||||
|         logger.info("Database exists") | ||||
|     else: | ||||
|         logger.info("Database contains no users, initializing...") | ||||
|         init_db(db) | ||||
|         if db.users.get_all(): | ||||
|             logger.info("Database exists") | ||||
|         else: | ||||
|             logger.info("Database contains no users, initializing...") | ||||
|             init_db(db) | ||||
|  | ||||
|     safe_try(lambda: fix_slug_food_names(db)) | ||||
|         safe_try(lambda: fix_slug_food_names(db)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| from mealie.core import root_logger | ||||
| from mealie.db.db_setup import with_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.repos.repository_factory import AllRepositories | ||||
| from mealie.services.user_services.user_service import UserService | ||||
|  | ||||
| @@ -13,7 +13,7 @@ def main(): | ||||
|  | ||||
|     logger = root_logger.get_logger() | ||||
|  | ||||
|     with with_session() as session: | ||||
|     with session_context() as session: | ||||
|         repos = AllRepositories(session) | ||||
|         user_service = UserService(repos) | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| import contextlib | ||||
| import json | ||||
| from abc import ABC, abstractmethod | ||||
| from collections.abc import Generator | ||||
| from datetime import datetime, timezone | ||||
| from typing import cast | ||||
| from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit | ||||
| @@ -7,6 +10,7 @@ from fastapi.encoders import jsonable_encoder | ||||
| from pydantic import UUID4 | ||||
| from sqlalchemy.orm.session import Session | ||||
|  | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.db.models.group.webhooks import GroupWebhooksModel | ||||
| from mealie.repos.all_repositories import get_repositories | ||||
| from mealie.repos.repository_factory import AllRepositories | ||||
| @@ -18,34 +22,58 @@ from .event_types import Event, EventDocumentType, EventTypes, EventWebhookData | ||||
| from .publisher import ApprisePublisher, PublisherLike, WebhookPublisher | ||||
|  | ||||
|  | ||||
| class EventListenerBase: | ||||
| class EventListenerBase(ABC): | ||||
|     session: Session | None | ||||
|  | ||||
|     def __init__(self, session: Session, group_id: UUID4, publisher: PublisherLike) -> None: | ||||
|         self.session = session | ||||
|         self.group_id = group_id | ||||
|         self.publisher = publisher | ||||
|  | ||||
|     @abstractmethod | ||||
|     def get_subscribers(self, event: Event) -> list: | ||||
|         """Get a list of all subscribers to this event""" | ||||
|         ... | ||||
|  | ||||
|     @abstractmethod | ||||
|     def publish_to_subscribers(self, event: Event, subscribers: list) -> None: | ||||
|         """Publishes the event to all subscribers""" | ||||
|         ... | ||||
|  | ||||
|     @contextlib.contextmanager | ||||
|     def ensure_session(self) -> Generator[None, None, None]: | ||||
|         """ | ||||
|         ensure_session ensures that a session is available for the caller by checking if a session | ||||
|         was provided during construction, and if not, creating a new session with the `with_session` | ||||
|         function and closing it when the context manager exits. | ||||
|  | ||||
|         This is _required_ when working with sessions inside an event bus listener where the listener | ||||
|         may be constructed during a request where the session is provided by the request, but the when | ||||
|         run as a scheduled task, the session is not provided and must be created. | ||||
|         """ | ||||
|         if self.session is None: | ||||
|             with session_context() as session: | ||||
|                 self.session = session | ||||
|                 yield | ||||
|  | ||||
|         else: | ||||
|             yield | ||||
|  | ||||
|  | ||||
| class AppriseEventListener(EventListenerBase): | ||||
|     def __init__(self, session: Session, group_id: UUID4) -> None: | ||||
|         super().__init__(session, group_id, ApprisePublisher()) | ||||
|  | ||||
|     def get_subscribers(self, event: Event) -> list[str]: | ||||
|         repos = AllRepositories(self.session) | ||||
|         with self.ensure_session(): | ||||
|             repos = AllRepositories(self.session) | ||||
|  | ||||
|         notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group(  # type: ignore | ||||
|             self.group_id | ||||
|         ).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate) | ||||
|             notifiers: list[GroupEventNotifierPrivate] = repos.group_event_notifier.by_group(  # type: ignore | ||||
|                 self.group_id | ||||
|             ).multi_query({"enabled": True}, override_schema=GroupEventNotifierPrivate) | ||||
|  | ||||
|         urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)] | ||||
|         urls = AppriseEventListener.update_urls_with_event_data(urls, event) | ||||
|             urls = [notifier.apprise_url for notifier in notifiers if getattr(notifier.options, event.event_type.name)] | ||||
|             urls = AppriseEventListener.update_urls_with_event_data(urls, event) | ||||
|  | ||||
|         return urls | ||||
|  | ||||
| @@ -120,12 +148,13 @@ class WebhookEventListener(EventListenerBase): | ||||
|  | ||||
|     def get_scheduled_webhooks(self, start_dt: datetime, end_dt: datetime) -> list[ReadWebhook]: | ||||
|         """Fetches all scheduled webhooks from the database""" | ||||
|         return ( | ||||
|             self.session.query(GroupWebhooksModel) | ||||
|             .where( | ||||
|                 GroupWebhooksModel.enabled == True,  # noqa: E712 - required for SQLAlchemy comparison | ||||
|                 GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), | ||||
|                 GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(), | ||||
|         with self.ensure_session(): | ||||
|             return ( | ||||
|                 self.session.query(GroupWebhooksModel) | ||||
|                 .where( | ||||
|                     GroupWebhooksModel.enabled == True,  # noqa: E712 - required for SQLAlchemy comparison | ||||
|                     GroupWebhooksModel.scheduled_time > start_dt.astimezone(timezone.utc).time(), | ||||
|                     GroupWebhooksModel.scheduled_time <= end_dt.astimezone(timezone.utc).time(), | ||||
|                 ) | ||||
|                 .all() | ||||
|             ) | ||||
|             .all() | ||||
|         ) | ||||
|   | ||||
| @@ -40,12 +40,13 @@ class EventSource: | ||||
|  | ||||
|  | ||||
| class EventBusService: | ||||
|     bg: BackgroundTasks | None | ||||
|     session: Session | None | ||||
|     group_id: UUID4 | None | ||||
|  | ||||
|     def __init__( | ||||
|         self, bg: Optional[BackgroundTasks] = None, session: Optional[Session] = None, group_id: UUID4 | None = None | ||||
|     ) -> None: | ||||
|         if not session: | ||||
|             session = next(generate_session()) | ||||
|  | ||||
|         self.bg = bg | ||||
|         self.session = session | ||||
|         self.group_id = group_id | ||||
|   | ||||
| @@ -3,7 +3,7 @@ from typing import Optional | ||||
|  | ||||
| from pydantic import UUID4 | ||||
|  | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.repos.all_repositories import get_repositories | ||||
| from mealie.schema.response.pagination import PaginationQuery | ||||
| from mealie.services.event_bus_service.event_bus_service import EventBusService | ||||
| @@ -31,10 +31,11 @@ def post_group_webhooks(start_dt: Optional[datetime] = None, group_id: Optional[ | ||||
|  | ||||
|     if group_id is None: | ||||
|         # publish the webhook event to each group's event bus | ||||
|         session = create_session() | ||||
|         repos = get_repositories(session) | ||||
|         groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1)) | ||||
|         group_ids = [group.id for group in groups_data.items] | ||||
|  | ||||
|         with session_context() as session: | ||||
|             repos = get_repositories(session) | ||||
|             groups_data = repos.groups.page_all(PaginationQuery(page=1, per_page=-1)) | ||||
|             group_ids = [group.id for group in groups_data.items] | ||||
|  | ||||
|     else: | ||||
|         group_ids = [group_id] | ||||
|   | ||||
| @@ -3,7 +3,7 @@ from pathlib import Path | ||||
|  | ||||
| from mealie.core import root_logger | ||||
| from mealie.core.config import get_app_dirs | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.db.models.group.exports import GroupDataExportsModel | ||||
|  | ||||
| ONE_DAY_AS_MINUTES = 1440 | ||||
| @@ -15,20 +15,19 @@ def purge_group_data_exports(max_minutes_old=ONE_DAY_AS_MINUTES): | ||||
|  | ||||
|     logger.info("purging group data exports") | ||||
|     limit = datetime.datetime.now() - datetime.timedelta(minutes=max_minutes_old) | ||||
|     session = create_session() | ||||
|  | ||||
|     results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit) | ||||
|     with session_context() as session: | ||||
|         results = session.query(GroupDataExportsModel).filter(GroupDataExportsModel.expires <= limit) | ||||
|  | ||||
|     total_removed = 0 | ||||
|     for result in results: | ||||
|         session.delete(result) | ||||
|         Path(result.path).unlink(missing_ok=True) | ||||
|         total_removed += 1 | ||||
|         total_removed = 0 | ||||
|         for result in results: | ||||
|             session.delete(result) | ||||
|             Path(result.path).unlink(missing_ok=True) | ||||
|             total_removed += 1 | ||||
|  | ||||
|     session.commit() | ||||
|     session.close() | ||||
|         session.commit() | ||||
|  | ||||
|     logger.info(f"finished purging group data exports. {total_removed} exports removed from group data") | ||||
|         logger.info(f"finished purging group data exports. {total_removed} exports removed from group data") | ||||
|  | ||||
|  | ||||
| def purge_excess_files() -> None: | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import datetime | ||||
|  | ||||
| from mealie.core import root_logger | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.db.models.users.password_reset import PasswordResetModel | ||||
|  | ||||
| logger = root_logger.get_logger() | ||||
| @@ -13,8 +13,9 @@ def purge_password_reset_tokens(): | ||||
|     """Purges all events after x days""" | ||||
|     logger.info("purging password reset tokens") | ||||
|     limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) | ||||
|     session = create_session() | ||||
|     session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete() | ||||
|     session.commit() | ||||
|     session.close() | ||||
|     logger.info("password reset tokens purges") | ||||
|  | ||||
|     with session_context() as session: | ||||
|         session.query(PasswordResetModel).filter(PasswordResetModel.created_at <= limit).delete() | ||||
|         session.commit() | ||||
|         session.close() | ||||
|         logger.info("password reset tokens purges") | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import datetime | ||||
|  | ||||
| from mealie.core import root_logger | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.db.models.group import GroupInviteToken | ||||
|  | ||||
| logger = root_logger.get_logger() | ||||
| @@ -13,8 +13,10 @@ def purge_group_registration(): | ||||
|     """Purges all events after x days""" | ||||
|     logger.info("purging expired registration tokens") | ||||
|     limit = datetime.datetime.now() - datetime.timedelta(days=MAX_DAYS_OLD) | ||||
|     session = create_session() | ||||
|     session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete() | ||||
|     session.commit() | ||||
|     session.close() | ||||
|  | ||||
|     with session_context() as session: | ||||
|         session.query(GroupInviteToken).filter(GroupInviteToken.created_at <= limit).delete() | ||||
|         session.commit() | ||||
|         session.close() | ||||
|  | ||||
|     logger.info("registration token purged") | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| from mealie.core import root_logger | ||||
| from mealie.db.db_setup import with_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.repos.repository_factory import AllRepositories | ||||
| from mealie.services.user_services.user_service import UserService | ||||
|  | ||||
| @@ -8,7 +8,7 @@ def locked_user_reset(): | ||||
|     logger = root_logger.get_logger() | ||||
|     logger.info("resetting locked users") | ||||
|  | ||||
|     with with_session() as session: | ||||
|     with session_context() as session: | ||||
|         repos = AllRepositories(session) | ||||
|         user_service = UserService(repos) | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,7 @@ import json | ||||
| import pytest | ||||
| from fastapi.testclient import TestClient | ||||
|  | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from mealie.services.user_services.password_reset_service import PasswordResetService | ||||
| from tests.utils.factories import random_string | ||||
| from tests.utils.fixture_schemas import TestUser | ||||
| @@ -31,10 +31,10 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s | ||||
|                 cased_email += l.lower() | ||||
|         cased_email | ||||
|  | ||||
|     session = create_session() | ||||
|     service = PasswordResetService(session) | ||||
|     token = service.generate_reset_token(cased_email) | ||||
|     assert token is not None | ||||
|     with session_context() as session: | ||||
|         service = PasswordResetService(session) | ||||
|         token = service.generate_reset_token(cased_email) | ||||
|         assert token is not None | ||||
|  | ||||
|     new_password = random_string(15) | ||||
|  | ||||
| @@ -59,8 +59,6 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s | ||||
|     response = api_client.get(Routes.self, headers={"Authorization": f"Bearer {new_token}"}) | ||||
|     assert response.status_code == 200 | ||||
|  | ||||
|     session.close() | ||||
|  | ||||
|     # Test successful password reset | ||||
|     response = api_client.post(Routes.base, json=payload) | ||||
|     assert response.status_code == 400 | ||||
|   | ||||
| @@ -5,7 +5,7 @@ from pytest import MonkeyPatch | ||||
| from mealie.core import security | ||||
| from mealie.core.config import get_app_settings | ||||
| from mealie.core.dependencies import validate_file_token | ||||
| from mealie.db.db_setup import create_session | ||||
| from mealie.db.db_setup import session_context | ||||
| from tests.utils.factories import random_string | ||||
|  | ||||
|  | ||||
| @@ -47,5 +47,8 @@ def test_ldap_authentication_mocked(monkeypatch: MonkeyPatch): | ||||
|     monkeypatch.setattr(ldap, "initialize", ldap_initialize_mock) | ||||
|  | ||||
|     get_app_settings.cache_clear() | ||||
|     result = security.authenticate_user(create_session(), user, password) | ||||
|  | ||||
|     with session_context() as session: | ||||
|         result = security.authenticate_user(session, user, password) | ||||
|  | ||||
|     assert result is False | ||||
|   | ||||
		Reference in New Issue
	
	Block a user