From bdbef1ab9e5f3a4d4b25512a0c41d89643fb246a Mon Sep 17 00:00:00 2001 From: Michael Genson <71845777+michael-genson@users.noreply.github.com> Date: Sat, 13 Dec 2025 14:21:54 -0600 Subject: [PATCH] fix: More lenient postgres override parsing (#6712) --- mealie/core/settings/db_providers.py | 26 ++++++------ tests/unit_tests/test_config.py | 63 ++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 16 deletions(-) diff --git a/mealie/core/settings/db_providers.py b/mealie/core/settings/db_providers.py index c1fd72613..2bfc398e1 100644 --- a/mealie/core/settings/db_providers.py +++ b/mealie/core/settings/db_providers.py @@ -43,22 +43,22 @@ class PostgresProvider(AbstractDBProvider, BaseSettings): model_config = SettingsConfigDict(arbitrary_types_allowed=True, extra="allow") + def _parse_override_url(self, url: str) -> str: + if not url.startswith("postgresql://"): + raise ValueError("POSTGRES_URL_OVERRIDE scheme must be postgresql") + + scheme, remainder = url.split("://", 1) + if "@" in remainder and ":" in remainder.split("@")[0]: + credentials, host_part = remainder.rsplit("@", 1) + user, password = credentials.split(":", 1) + return f"{scheme}://{user}:{urlparse.quote(password, safe='')}@{host_part}" + + return url + @property def db_url(self) -> str: if self.POSTGRES_URL_OVERRIDE: - url = self.POSTGRES_URL_OVERRIDE - - scheme, remainder = url.split("://", 1) - if scheme != "postgresql": - raise ValueError("POSTGRES_URL_OVERRIDE scheme must be postgresql") - - remainder = remainder.split(":", 1)[1] - password = remainder[: remainder.rfind("@")] - quoted_password = urlparse.quote(password) - - safe_url = url.replace(password, quoted_password) - - return safe_url + return self._parse_override_url(self.POSTGRES_URL_OVERRIDE) return str( PostgresDsn.build( diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 06d80fde5..15014542a 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -1,6 +1,7 @@ import json import re from dataclasses import dataclass +from typing import Any import pytest @@ -58,6 +59,14 @@ psql_validation_cases = [ "postgresql://mealie:P%40ssword%21%40%23%24%25%25%5E%5E%26%26%2A%2A%28%29%2B%3B%27%22%27%3C%3E%3F%7B%7D%5B%5D@postgres:5432/mealie", ], ), + ( + "unencoded_to_encoded_no_port_url", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://mealie:P@ssword!@#$%%^^&&**()+;'\"'<>?{}[]@postgres/mealie", + "postgresql://mealie:P%40ssword%21%40%23%24%25%25%5E%5E%26%26%2A%2A%28%29%2B%3B%27%22%27%3C%3E%3F%7B%7D%5B%5D@postgres/mealie", + ], + ), ( "no_encode_needed_password", [ @@ -74,6 +83,54 @@ psql_validation_cases = [ "postgresql://mealie:MyPassword@postgres:5432/mealie", ], ), + ( + "no_password_url", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://mealie@postgres:5432/mealie", + "postgresql://mealie@postgres:5432/mealie", + ], + ), + ( + "no_password_no_port_url", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://mealie@postgres/mealie", + "postgresql://mealie@postgres/mealie", + ], + ), + ( + "unix_socket_with_empty_password", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://mealie:@/mealie?host=/run/postgresql", + "postgresql://mealie:@/mealie?host=/run/postgresql", + ], + ), + ( + "unix_socket_no_password", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://mealie@/mealie?host=/run/postgresql", + "postgresql://mealie@/mealie?host=/run/postgresql", + ], + ), + ( + "no_credentials_at_all", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql:///mealie?host=/run/postgresql", + "postgresql:///mealie?host=/run/postgresql", + ], + ), + ( + "query_params_with_colon", + [ + "POSTGRES_URL_OVERRIDE", + "postgresql://user@host/db?sslmode=require&connect_timeout=10", + "postgresql://user@host/db?sslmode=require&connect_timeout=10", + ], + ), ] psql_cases = [x[1] for x in psql_validation_cases] @@ -174,11 +231,11 @@ def test_smtp_enable_with_bad_data_tls(data: SMTPValidationCase): @dataclass(slots=True) class EnvVar: name: str - value: any + value: Any class LDAPValidationCase: - settings = list[EnvVar] + settings: list[EnvVar] is_valid: bool def __init__( @@ -222,7 +279,7 @@ def test_ldap_settings_validation(data: LDAPValidationCase, monkeypatch: pytest. class OIDCValidationCase: - settings = list[EnvVar] + settings: list[EnvVar] is_valid: bool def __init__(