mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-02-09 09:23:12 -05:00
feat: Add OIDC_CLIENT_SECRET and other changes for v2 (#4254)
Co-authored-by: boc-the-git <3479092+boc-the-git@users.noreply.github.com>
This commit is contained in:
@@ -49,7 +49,10 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
|
||||
|
||||
to_encode["exp"] = expire
|
||||
to_encode["iss"] = ISS
|
||||
return (jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM), expires_delta)
|
||||
return (
|
||||
jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM),
|
||||
expires_delta,
|
||||
)
|
||||
|
||||
def try_get_user(self, username: str) -> PrivateUser | None:
|
||||
"""Try to get a user from the database, first trying username, then trying email"""
|
||||
@@ -66,6 +69,6 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
|
||||
return user
|
||||
|
||||
@abc.abstractmethod
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -20,7 +20,7 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
|
||||
def __init__(self, session: Session, data: CredentialsRequest) -> None:
|
||||
super().__init__(session, data)
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
settings = get_app_settings()
|
||||
db = get_repositories(self.session, group_id=None, household_id=None)
|
||||
@@ -30,7 +30,8 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
|
||||
# To prevent user enumeration we perform the verify_password computation to ensure
|
||||
# server side time is relatively constant and not vulnerable to timing attacks.
|
||||
CredentialsProvider.verify_password(
|
||||
"abc123cba321", "$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i"
|
||||
"abc123cba321",
|
||||
"$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class LDAPProvider(CredentialsProvider):
|
||||
super().__init__(session, data)
|
||||
self.conn = None
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
user = self.try_get_user(self.data.username)
|
||||
if not user or user.password == "LDAP" or user.auth_method == AuthMethod.LDAP:
|
||||
@@ -30,7 +30,7 @@ class LDAPProvider(CredentialsProvider):
|
||||
if user:
|
||||
return self.get_access_token(user, self.data.remember_me)
|
||||
|
||||
return await super().authenticate()
|
||||
return super().authenticate()
|
||||
|
||||
def search_user(self, conn: LDAPObject) -> list[tuple[str, dict[str, list[bytes]]]] | None:
|
||||
"""
|
||||
@@ -64,7 +64,11 @@ class LDAPProvider(CredentialsProvider):
|
||||
settings.LDAP_BASE_DN,
|
||||
ldap.SCOPE_SUBTREE,
|
||||
search_filter,
|
||||
[settings.LDAP_ID_ATTRIBUTE, settings.LDAP_NAME_ATTRIBUTE, settings.LDAP_MAIL_ATTRIBUTE],
|
||||
[
|
||||
settings.LDAP_ID_ATTRIBUTE,
|
||||
settings.LDAP_NAME_ATTRIBUTE,
|
||||
settings.LDAP_MAIL_ATTRIBUTE,
|
||||
],
|
||||
)
|
||||
except ldap.FILTER_ERROR:
|
||||
self._logger.error("[LDAP] Bad user search filter")
|
||||
|
||||
@@ -1,55 +1,58 @@
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from authlib.jose import JsonWebKey, JsonWebToken, JWTClaims, KeySet
|
||||
from authlib.jose.errors import ExpiredTokenError, UnsupportedAlgorithmError
|
||||
from authlib.oidc.core import CodeIDToken
|
||||
from authlib.oidc.core import UserInfo
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.core.security.providers.auth_provider import AuthProvider
|
||||
from mealie.core.settings.settings import AppSettings
|
||||
from mealie.db.models.users.users import AuthMethod
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.schema.user.auth import OIDCRequest
|
||||
|
||||
|
||||
class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
class OpenIDProvider(AuthProvider[UserInfo]):
|
||||
"""Authentication provider that authenticates a user using a token from OIDC ID token"""
|
||||
|
||||
_logger = root_logger.get_logger("openid_provider")
|
||||
|
||||
def __init__(self, session: Session, data: OIDCRequest) -> None:
|
||||
def __init__(self, session: Session, data: UserInfo) -> None:
|
||||
super().__init__(session, data)
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
|
||||
settings = get_app_settings()
|
||||
claims = self.get_claims(settings)
|
||||
claims = self.data
|
||||
if not claims:
|
||||
self._logger.error("[OIDC] No claims in the id_token")
|
||||
return None
|
||||
|
||||
if not self.required_claims.issubset(claims.keys()):
|
||||
self._logger.error(
|
||||
"[OIDC] Required claims not present. Expected: %s Actual: %s",
|
||||
self.required_claims,
|
||||
claims.keys(),
|
||||
)
|
||||
return None
|
||||
|
||||
repos = get_repositories(self.session, group_id=None, household_id=None)
|
||||
|
||||
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
|
||||
is_admin = False
|
||||
if settings.OIDC_USER_GROUP or settings.OIDC_ADMIN_GROUP:
|
||||
if settings.OIDC_REQUIRES_GROUP_CLAIM:
|
||||
group_claim = claims.get(settings.OIDC_GROUPS_CLAIM, []) or []
|
||||
is_admin = settings.OIDC_ADMIN_GROUP in group_claim if settings.OIDC_ADMIN_GROUP else False
|
||||
is_valid_user = settings.OIDC_USER_GROUP in group_claim if settings.OIDC_USER_GROUP else True
|
||||
|
||||
if not is_valid_user:
|
||||
self._logger.debug(
|
||||
"[OIDC] User does not have the required group. Found: %s - Required: %s",
|
||||
if not (is_valid_user or is_admin):
|
||||
self._logger.warning(
|
||||
"[OIDC] Successfully authenticated, but user does not have one of the required group(s). \
|
||||
Found: %s - Required (one of): %s",
|
||||
group_claim,
|
||||
settings.OIDC_USER_GROUP,
|
||||
[settings.OIDC_USER_GROUP, settings.OIDC_ADMIN_GROUP],
|
||||
)
|
||||
return None
|
||||
|
||||
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
|
||||
if not user:
|
||||
if not settings.OIDC_SIGNUP_ENABLED:
|
||||
self._logger.debug("[OIDC] No user found. Not creating a new user - new user creation is disabled.")
|
||||
@@ -57,22 +60,31 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
|
||||
self._logger.debug("[OIDC] No user found. Creating new OIDC user.")
|
||||
|
||||
user = repos.users.create(
|
||||
{
|
||||
"username": claims.get("preferred_username"),
|
||||
"password": "OIDC",
|
||||
"full_name": claims.get("name"),
|
||||
"email": claims.get("email"),
|
||||
"admin": is_admin,
|
||||
"auth_method": AuthMethod.OIDC,
|
||||
}
|
||||
)
|
||||
self.session.commit()
|
||||
try:
|
||||
# some IdPs don't provide a username (looking at you Google), so if we don't have the claim,
|
||||
# we'll create the user with whatever the USER_CLAIM is (default email)
|
||||
username = claims.get("preferred_username", claims.get(settings.OIDC_USER_CLAIM))
|
||||
user = repos.users.create(
|
||||
{
|
||||
"username": username,
|
||||
"password": "OIDC",
|
||||
"full_name": claims.get("name"),
|
||||
"email": claims.get("email"),
|
||||
"admin": is_admin,
|
||||
"auth_method": AuthMethod.OIDC,
|
||||
}
|
||||
)
|
||||
self.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error("[OIDC] Exception while creating user: %s", e)
|
||||
return None
|
||||
|
||||
return self.get_access_token(user, settings.OIDC_REMEMBER_ME) # type: ignore
|
||||
|
||||
if user:
|
||||
if settings.OIDC_ADMIN_GROUP and user.admin != is_admin:
|
||||
self._logger.debug(f"[OIDC] {'Setting' if is_admin else 'Removing'} user as admin")
|
||||
self._logger.debug("[OIDC] %s user as admin", "Setting" if is_admin else "Removing")
|
||||
user.admin = is_admin
|
||||
repos.users.update(user.id, user)
|
||||
return self.get_access_token(user, settings.OIDC_REMEMBER_ME)
|
||||
@@ -80,78 +92,11 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
self._logger.warning("[OIDC] Found user but their AuthMethod does not match OIDC")
|
||||
return None
|
||||
|
||||
def get_claims(self, settings: AppSettings) -> JWTClaims | None:
|
||||
"""Get the claims from the ID token and check if the required claims are present"""
|
||||
required_claims = {
|
||||
"preferred_username",
|
||||
"name",
|
||||
"email",
|
||||
settings.OIDC_USER_CLAIM,
|
||||
}
|
||||
jwks = OpenIDProvider.get_jwks(self.get_ttl_hash()) # cache the key set for 30 minutes
|
||||
if not jwks:
|
||||
return None
|
||||
|
||||
algorithm = settings.OIDC_SIGNING_ALGORITHM
|
||||
try:
|
||||
claims = JsonWebToken([algorithm]).decode(s=self.data.id_token, key=jwks, claims_cls=CodeIDToken)
|
||||
except UnsupportedAlgorithmError:
|
||||
self._logger.error(
|
||||
f"[OIDC] Unsupported algorithm '{algorithm}'. Unable to decode id token due to mismatched algorithm."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
claims.validate()
|
||||
except ExpiredTokenError as e:
|
||||
self._logger.error(f"[OIDC] {e.error}: {e.description}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self._logger.error("[OIDC] Exception while validating id_token claims", e)
|
||||
|
||||
if not claims:
|
||||
self._logger.error("[OIDC] Claims not found")
|
||||
return None
|
||||
if not required_claims.issubset(claims.keys()):
|
||||
self._logger.error(
|
||||
f"[OIDC] Required claims not present. Expected: {required_claims} Actual: {claims.keys()}"
|
||||
)
|
||||
return None
|
||||
return claims
|
||||
|
||||
@lru_cache
|
||||
@staticmethod
|
||||
def get_jwks(ttl_hash=None) -> KeySet | None:
|
||||
"""Get the key set from the openid configuration"""
|
||||
del ttl_hash # ttl_hash is used for caching only
|
||||
@property
|
||||
def required_claims(self):
|
||||
settings = get_app_settings()
|
||||
|
||||
if not (settings.OIDC_READY and settings.OIDC_CONFIGURATION_URL):
|
||||
return None
|
||||
|
||||
session = requests.Session()
|
||||
if settings.OIDC_TLS_CACERTFILE:
|
||||
session.verify = settings.OIDC_TLS_CACERTFILE
|
||||
|
||||
config_response = session.get(settings.OIDC_CONFIGURATION_URL, timeout=5)
|
||||
config_response.raise_for_status()
|
||||
configuration = config_response.json()
|
||||
|
||||
if not configuration:
|
||||
OpenIDProvider._logger.warning("[OIDC] Unable to fetch configuration from the OIDC_CONFIGURATION_URL")
|
||||
session.close()
|
||||
return None
|
||||
|
||||
jwks_uri = configuration.get("jwks_uri", None)
|
||||
if not jwks_uri:
|
||||
OpenIDProvider._logger.warning("[OIDC] Unable to find the jwks_uri from the OIDC_CONFIGURATION_URL")
|
||||
session.close()
|
||||
return None
|
||||
|
||||
response = session.get(jwks_uri, timeout=5)
|
||||
response.raise_for_status()
|
||||
session.close()
|
||||
return JsonWebKey.import_key_set(response.json())
|
||||
|
||||
def get_ttl_hash(self, seconds=1800):
|
||||
return time.time() // seconds
|
||||
claims = {"name", "email", settings.OIDC_USER_CLAIM}
|
||||
if settings.OIDC_REQUIRES_GROUP_CLAIM:
|
||||
claims.add(settings.OIDC_GROUPS_CLAIM)
|
||||
return claims
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from mealie.core import root_logger
|
||||
@@ -12,20 +11,16 @@ from mealie.core.security.hasher import get_hasher
|
||||
from mealie.core.security.providers.auth_provider import AuthProvider
|
||||
from mealie.core.security.providers.credentials_provider import CredentialsProvider
|
||||
from mealie.core.security.providers.ldap_provider import LDAPProvider
|
||||
from mealie.core.security.providers.openid_provider import OpenIDProvider
|
||||
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm, OIDCRequest
|
||||
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
logger = root_logger.get_logger("security")
|
||||
|
||||
|
||||
def get_auth_provider(session: Session, request: Request, data: CredentialsRequestForm) -> AuthProvider:
|
||||
def get_auth_provider(session: Session, data: CredentialsRequestForm) -> AuthProvider:
|
||||
settings = get_app_settings()
|
||||
|
||||
if request.cookies.get("mealie.auth.strategy") == "oidc":
|
||||
return OpenIDProvider(session, OIDCRequest(id_token=request.cookies.get("mealie.auth._id_token.oidc")))
|
||||
|
||||
credentials_request = CredentialsRequest(**data.__dict__)
|
||||
if settings.LDAP_ENABLED:
|
||||
return LDAPProvider(session, credentials_request)
|
||||
|
||||
@@ -19,11 +19,11 @@ class ScheduleTime(NamedTuple):
|
||||
minute: int
|
||||
|
||||
|
||||
def determine_secrets(data_dir: Path, production: bool) -> str:
|
||||
def determine_secrets(data_dir: Path, secret: str, production: bool) -> str:
|
||||
if not production:
|
||||
return "shh-secret-test-key"
|
||||
|
||||
secrets_file = data_dir.joinpath(".secret")
|
||||
secrets_file = data_dir.joinpath(secret)
|
||||
if secrets_file.is_file():
|
||||
with open(secrets_file) as f:
|
||||
return f.read()
|
||||
@@ -100,6 +100,7 @@ class AppSettings(AppLoggingSettings):
|
||||
"""time in hours"""
|
||||
|
||||
SECRET: str
|
||||
SESSION_SECRET: str
|
||||
|
||||
GIT_COMMIT_HASH: str = "unknown"
|
||||
|
||||
@@ -268,6 +269,7 @@ class AppSettings(AppLoggingSettings):
|
||||
# OIDC Configuration
|
||||
OIDC_AUTH_ENABLED: bool = False
|
||||
OIDC_CLIENT_ID: str | None = None
|
||||
OIDC_CLIENT_SECRET: str | None = None
|
||||
OIDC_CONFIGURATION_URL: str | None = None
|
||||
OIDC_SIGNUP_ENABLED: bool = True
|
||||
OIDC_USER_GROUP: str | None = None
|
||||
@@ -275,23 +277,28 @@ class AppSettings(AppLoggingSettings):
|
||||
OIDC_AUTO_REDIRECT: bool = False
|
||||
OIDC_PROVIDER_NAME: str = "OAuth"
|
||||
OIDC_REMEMBER_ME: bool = False
|
||||
OIDC_SIGNING_ALGORITHM: str = "RS256"
|
||||
OIDC_USER_CLAIM: str = "email"
|
||||
OIDC_GROUPS_CLAIM: str | None = "groups"
|
||||
OIDC_TLS_CACERTFILE: str | None = None
|
||||
|
||||
@property
|
||||
def OIDC_REQUIRES_GROUP_CLAIM(self) -> bool:
|
||||
return self.OIDC_USER_GROUP is not None or self.OIDC_ADMIN_GROUP is not None
|
||||
|
||||
@property
|
||||
def OIDC_READY(self) -> bool:
|
||||
"""Validates OIDC settings are all set"""
|
||||
|
||||
required = {
|
||||
self.OIDC_CLIENT_ID,
|
||||
self.OIDC_CLIENT_SECRET,
|
||||
self.OIDC_CONFIGURATION_URL,
|
||||
self.OIDC_USER_CLAIM,
|
||||
}
|
||||
not_none = None not in required
|
||||
valid_group_claim = True
|
||||
if (not self.OIDC_USER_GROUP or not self.OIDC_ADMIN_GROUP) and not self.OIDC_GROUPS_CLAIM:
|
||||
|
||||
if self.OIDC_REQUIRES_GROUP_CLAIM and self.OIDC_GROUPS_CLAIM is None:
|
||||
valid_group_claim = False
|
||||
|
||||
return self.OIDC_AUTH_ENABLED and not_none and valid_group_claim
|
||||
@@ -353,13 +360,17 @@ def app_settings_constructor(data_dir: Path, production: bool, env_file: Path, e
|
||||
required dependencies into the AppSettings object and nested child objects. AppSettings should not be substantiated
|
||||
directly, but rather through this factory function.
|
||||
"""
|
||||
secret_settings = {
|
||||
"SECRET": determine_secrets(data_dir, ".secret", production),
|
||||
"SESSION_SECRET": determine_secrets(data_dir, ".session_secret", production),
|
||||
}
|
||||
app_settings = AppSettings(
|
||||
_env_file=env_file, # type: ignore
|
||||
_env_file_encoding=env_encoding, # type: ignore
|
||||
# `get_secrets_dir` must be called here rather than within `AppSettings`
|
||||
# to avoid a circular import.
|
||||
_secrets_dir=get_secrets_dir(), # type: ignore
|
||||
**{"SECRET": determine_secrets(data_dir, production)},
|
||||
**secret_settings,
|
||||
)
|
||||
|
||||
app_settings.DB_PROVIDER = db_provider_factory(
|
||||
|
||||
Reference in New Issue
Block a user