feat: Add OIDC_CLIENT_SECRET and other changes for v2 (#4254)

Co-authored-by: boc-the-git <3479092+boc-the-git@users.noreply.github.com>
This commit is contained in:
Carter
2024-10-05 16:12:11 -05:00
committed by GitHub
parent 4f1abcf4a3
commit 5ed0ec029b
31 changed files with 530 additions and 349 deletions

View File

@@ -6,6 +6,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.routing import APIRoute
from starlette.middleware.sessions import SessionMiddleware
from mealie.core.config import get_app_settings
from mealie.core.root_logger import get_logger
@@ -66,12 +67,14 @@ async def lifespan_fn(_: FastAPI) -> AsyncGenerator[None, None]:
"LDAP_QUERY_PASSWORD",
"OPENAI_API_KEY",
"SECRET",
"SESSION_SECRET",
"SFTP_PASSWORD",
"SFTP_USERNAME",
"DB_URL", # replace by DB_URL_PUBLIC for logs
"DB_PROVIDER",
"SMTP_USER",
"SMTP_PASSWORD",
"OIDC_CLIENT_SECRET",
},
)
)
@@ -91,6 +94,7 @@ app = FastAPI(
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET)
if not settings.PRODUCTION:
allowed_origins = ["http://localhost:3000"]

View File

@@ -49,7 +49,10 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
to_encode["exp"] = expire
to_encode["iss"] = ISS
return (jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM), expires_delta)
return (
jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM),
expires_delta,
)
def try_get_user(self, username: str) -> PrivateUser | None:
"""Try to get a user from the database, first trying username, then trying email"""
@@ -66,6 +69,6 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
return user
@abc.abstractmethod
async def authenticate(self) -> tuple[str, timedelta] | None:
def authenticate(self) -> tuple[str, timedelta] | None:
"""Attempt to authenticate a user"""
raise NotImplementedError

View File

@@ -20,7 +20,7 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
def __init__(self, session: Session, data: CredentialsRequest) -> None:
super().__init__(session, data)
async def authenticate(self) -> tuple[str, timedelta] | None:
def authenticate(self) -> tuple[str, timedelta] | None:
"""Attempt to authenticate a user given a username and password"""
settings = get_app_settings()
db = get_repositories(self.session, group_id=None, household_id=None)
@@ -30,7 +30,8 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
# To prevent user enumeration we perform the verify_password computation to ensure
# server side time is relatively constant and not vulnerable to timing attacks.
CredentialsProvider.verify_password(
"abc123cba321", "$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i"
"abc123cba321",
"$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i",
)
return None

View File

@@ -22,7 +22,7 @@ class LDAPProvider(CredentialsProvider):
super().__init__(session, data)
self.conn = None
async def authenticate(self) -> tuple[str, timedelta] | None:
def authenticate(self) -> tuple[str, timedelta] | None:
"""Attempt to authenticate a user given a username and password"""
user = self.try_get_user(self.data.username)
if not user or user.password == "LDAP" or user.auth_method == AuthMethod.LDAP:
@@ -30,7 +30,7 @@ class LDAPProvider(CredentialsProvider):
if user:
return self.get_access_token(user, self.data.remember_me)
return await super().authenticate()
return super().authenticate()
def search_user(self, conn: LDAPObject) -> list[tuple[str, dict[str, list[bytes]]]] | None:
"""
@@ -64,7 +64,11 @@ class LDAPProvider(CredentialsProvider):
settings.LDAP_BASE_DN,
ldap.SCOPE_SUBTREE,
search_filter,
[settings.LDAP_ID_ATTRIBUTE, settings.LDAP_NAME_ATTRIBUTE, settings.LDAP_MAIL_ATTRIBUTE],
[
settings.LDAP_ID_ATTRIBUTE,
settings.LDAP_NAME_ATTRIBUTE,
settings.LDAP_MAIL_ATTRIBUTE,
],
)
except ldap.FILTER_ERROR:
self._logger.error("[LDAP] Bad user search filter")

View File

@@ -1,55 +1,58 @@
import time
from datetime import timedelta
from functools import lru_cache
import requests
from authlib.jose import JsonWebKey, JsonWebToken, JWTClaims, KeySet
from authlib.jose.errors import ExpiredTokenError, UnsupportedAlgorithmError
from authlib.oidc.core import CodeIDToken
from authlib.oidc.core import UserInfo
from sqlalchemy.orm.session import Session
from mealie.core import root_logger
from mealie.core.config import get_app_settings
from mealie.core.security.providers.auth_provider import AuthProvider
from mealie.core.settings.settings import AppSettings
from mealie.db.models.users.users import AuthMethod
from mealie.repos.all_repositories import get_repositories
from mealie.schema.user.auth import OIDCRequest
class OpenIDProvider(AuthProvider[OIDCRequest]):
class OpenIDProvider(AuthProvider[UserInfo]):
"""Authentication provider that authenticates a user using a token from OIDC ID token"""
_logger = root_logger.get_logger("openid_provider")
def __init__(self, session: Session, data: OIDCRequest) -> None:
def __init__(self, session: Session, data: UserInfo) -> None:
super().__init__(session, data)
async def authenticate(self) -> tuple[str, timedelta] | None:
def authenticate(self) -> tuple[str, timedelta] | None:
"""Attempt to authenticate a user given a username and password"""
settings = get_app_settings()
claims = self.get_claims(settings)
claims = self.data
if not claims:
self._logger.error("[OIDC] No claims in the id_token")
return None
if not self.required_claims.issubset(claims.keys()):
self._logger.error(
"[OIDC] Required claims not present. Expected: %s Actual: %s",
self.required_claims,
claims.keys(),
)
return None
repos = get_repositories(self.session, group_id=None, household_id=None)
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
is_admin = False
if settings.OIDC_USER_GROUP or settings.OIDC_ADMIN_GROUP:
if settings.OIDC_REQUIRES_GROUP_CLAIM:
group_claim = claims.get(settings.OIDC_GROUPS_CLAIM, []) or []
is_admin = settings.OIDC_ADMIN_GROUP in group_claim if settings.OIDC_ADMIN_GROUP else False
is_valid_user = settings.OIDC_USER_GROUP in group_claim if settings.OIDC_USER_GROUP else True
if not is_valid_user:
self._logger.debug(
"[OIDC] User does not have the required group. Found: %s - Required: %s",
if not (is_valid_user or is_admin):
self._logger.warning(
"[OIDC] Successfully authenticated, but user does not have one of the required group(s). \
Found: %s - Required (one of): %s",
group_claim,
settings.OIDC_USER_GROUP,
[settings.OIDC_USER_GROUP, settings.OIDC_ADMIN_GROUP],
)
return None
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
if not user:
if not settings.OIDC_SIGNUP_ENABLED:
self._logger.debug("[OIDC] No user found. Not creating a new user - new user creation is disabled.")
@@ -57,22 +60,31 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
self._logger.debug("[OIDC] No user found. Creating new OIDC user.")
user = repos.users.create(
{
"username": claims.get("preferred_username"),
"password": "OIDC",
"full_name": claims.get("name"),
"email": claims.get("email"),
"admin": is_admin,
"auth_method": AuthMethod.OIDC,
}
)
self.session.commit()
try:
# some IdPs don't provide a username (looking at you Google), so if we don't have the claim,
# we'll create the user with whatever the USER_CLAIM is (default email)
username = claims.get("preferred_username", claims.get(settings.OIDC_USER_CLAIM))
user = repos.users.create(
{
"username": username,
"password": "OIDC",
"full_name": claims.get("name"),
"email": claims.get("email"),
"admin": is_admin,
"auth_method": AuthMethod.OIDC,
}
)
self.session.commit()
except Exception as e:
self._logger.error("[OIDC] Exception while creating user: %s", e)
return None
return self.get_access_token(user, settings.OIDC_REMEMBER_ME) # type: ignore
if user:
if settings.OIDC_ADMIN_GROUP and user.admin != is_admin:
self._logger.debug(f"[OIDC] {'Setting' if is_admin else 'Removing'} user as admin")
self._logger.debug("[OIDC] %s user as admin", "Setting" if is_admin else "Removing")
user.admin = is_admin
repos.users.update(user.id, user)
return self.get_access_token(user, settings.OIDC_REMEMBER_ME)
@@ -80,78 +92,11 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
self._logger.warning("[OIDC] Found user but their AuthMethod does not match OIDC")
return None
def get_claims(self, settings: AppSettings) -> JWTClaims | None:
"""Get the claims from the ID token and check if the required claims are present"""
required_claims = {
"preferred_username",
"name",
"email",
settings.OIDC_USER_CLAIM,
}
jwks = OpenIDProvider.get_jwks(self.get_ttl_hash()) # cache the key set for 30 minutes
if not jwks:
return None
algorithm = settings.OIDC_SIGNING_ALGORITHM
try:
claims = JsonWebToken([algorithm]).decode(s=self.data.id_token, key=jwks, claims_cls=CodeIDToken)
except UnsupportedAlgorithmError:
self._logger.error(
f"[OIDC] Unsupported algorithm '{algorithm}'. Unable to decode id token due to mismatched algorithm."
)
return None
try:
claims.validate()
except ExpiredTokenError as e:
self._logger.error(f"[OIDC] {e.error}: {e.description}")
return None
except Exception as e:
self._logger.error("[OIDC] Exception while validating id_token claims", e)
if not claims:
self._logger.error("[OIDC] Claims not found")
return None
if not required_claims.issubset(claims.keys()):
self._logger.error(
f"[OIDC] Required claims not present. Expected: {required_claims} Actual: {claims.keys()}"
)
return None
return claims
@lru_cache
@staticmethod
def get_jwks(ttl_hash=None) -> KeySet | None:
"""Get the key set from the openid configuration"""
del ttl_hash # ttl_hash is used for caching only
@property
def required_claims(self):
settings = get_app_settings()
if not (settings.OIDC_READY and settings.OIDC_CONFIGURATION_URL):
return None
session = requests.Session()
if settings.OIDC_TLS_CACERTFILE:
session.verify = settings.OIDC_TLS_CACERTFILE
config_response = session.get(settings.OIDC_CONFIGURATION_URL, timeout=5)
config_response.raise_for_status()
configuration = config_response.json()
if not configuration:
OpenIDProvider._logger.warning("[OIDC] Unable to fetch configuration from the OIDC_CONFIGURATION_URL")
session.close()
return None
jwks_uri = configuration.get("jwks_uri", None)
if not jwks_uri:
OpenIDProvider._logger.warning("[OIDC] Unable to find the jwks_uri from the OIDC_CONFIGURATION_URL")
session.close()
return None
response = session.get(jwks_uri, timeout=5)
response.raise_for_status()
session.close()
return JsonWebKey.import_key_set(response.json())
def get_ttl_hash(self, seconds=1800):
return time.time() // seconds
claims = {"name", "email", settings.OIDC_USER_CLAIM}
if settings.OIDC_REQUIRES_GROUP_CLAIM:
claims.add(settings.OIDC_GROUPS_CLAIM)
return claims

View File

@@ -3,7 +3,6 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path
import jwt
from fastapi import Request
from sqlalchemy.orm.session import Session
from mealie.core import root_logger
@@ -12,20 +11,16 @@ from mealie.core.security.hasher import get_hasher
from mealie.core.security.providers.auth_provider import AuthProvider
from mealie.core.security.providers.credentials_provider import CredentialsProvider
from mealie.core.security.providers.ldap_provider import LDAPProvider
from mealie.core.security.providers.openid_provider import OpenIDProvider
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm, OIDCRequest
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm
ALGORITHM = "HS256"
logger = root_logger.get_logger("security")
def get_auth_provider(session: Session, request: Request, data: CredentialsRequestForm) -> AuthProvider:
def get_auth_provider(session: Session, data: CredentialsRequestForm) -> AuthProvider:
settings = get_app_settings()
if request.cookies.get("mealie.auth.strategy") == "oidc":
return OpenIDProvider(session, OIDCRequest(id_token=request.cookies.get("mealie.auth._id_token.oidc")))
credentials_request = CredentialsRequest(**data.__dict__)
if settings.LDAP_ENABLED:
return LDAPProvider(session, credentials_request)

View File

@@ -19,11 +19,11 @@ class ScheduleTime(NamedTuple):
minute: int
def determine_secrets(data_dir: Path, production: bool) -> str:
def determine_secrets(data_dir: Path, secret: str, production: bool) -> str:
if not production:
return "shh-secret-test-key"
secrets_file = data_dir.joinpath(".secret")
secrets_file = data_dir.joinpath(secret)
if secrets_file.is_file():
with open(secrets_file) as f:
return f.read()
@@ -100,6 +100,7 @@ class AppSettings(AppLoggingSettings):
"""time in hours"""
SECRET: str
SESSION_SECRET: str
GIT_COMMIT_HASH: str = "unknown"
@@ -268,6 +269,7 @@ class AppSettings(AppLoggingSettings):
# OIDC Configuration
OIDC_AUTH_ENABLED: bool = False
OIDC_CLIENT_ID: str | None = None
OIDC_CLIENT_SECRET: str | None = None
OIDC_CONFIGURATION_URL: str | None = None
OIDC_SIGNUP_ENABLED: bool = True
OIDC_USER_GROUP: str | None = None
@@ -275,23 +277,28 @@ class AppSettings(AppLoggingSettings):
OIDC_AUTO_REDIRECT: bool = False
OIDC_PROVIDER_NAME: str = "OAuth"
OIDC_REMEMBER_ME: bool = False
OIDC_SIGNING_ALGORITHM: str = "RS256"
OIDC_USER_CLAIM: str = "email"
OIDC_GROUPS_CLAIM: str | None = "groups"
OIDC_TLS_CACERTFILE: str | None = None
@property
def OIDC_REQUIRES_GROUP_CLAIM(self) -> bool:
return self.OIDC_USER_GROUP is not None or self.OIDC_ADMIN_GROUP is not None
@property
def OIDC_READY(self) -> bool:
"""Validates OIDC settings are all set"""
required = {
self.OIDC_CLIENT_ID,
self.OIDC_CLIENT_SECRET,
self.OIDC_CONFIGURATION_URL,
self.OIDC_USER_CLAIM,
}
not_none = None not in required
valid_group_claim = True
if (not self.OIDC_USER_GROUP or not self.OIDC_ADMIN_GROUP) and not self.OIDC_GROUPS_CLAIM:
if self.OIDC_REQUIRES_GROUP_CLAIM and self.OIDC_GROUPS_CLAIM is None:
valid_group_claim = False
return self.OIDC_AUTH_ENABLED and not_none and valid_group_claim
@@ -353,13 +360,17 @@ def app_settings_constructor(data_dir: Path, production: bool, env_file: Path, e
required dependencies into the AppSettings object and nested child objects. AppSettings should not be substantiated
directly, but rather through this factory function.
"""
secret_settings = {
"SECRET": determine_secrets(data_dir, ".secret", production),
"SESSION_SECRET": determine_secrets(data_dir, ".session_secret", production),
}
app_settings = AppSettings(
_env_file=env_file, # type: ignore
_env_file_encoding=env_encoding, # type: ignore
# `get_secrets_dir` must be called here rather than within `AppSettings`
# to avoid a circular import.
_secrets_dir=get_secrets_dir(), # type: ignore
**{"SECRET": determine_secrets(data_dir, production)},
**secret_settings,
)
app_settings.DB_PROVIDER = db_provider_factory(

View File

@@ -6,7 +6,7 @@ from mealie.core.settings.static import APP_VERSION
from mealie.db.db_setup import generate_session
from mealie.db.models.users.users import User
from mealie.repos.all_repositories import get_repositories
from mealie.schema.admin.about import AppInfo, AppStartupInfo, AppTheme, OIDCInfo
from mealie.schema.admin.about import AppInfo, AppStartupInfo, AppTheme
router = APIRouter(prefix="/about")
@@ -69,16 +69,3 @@ def get_app_theme(resp: Response):
resp.headers["Cache-Control"] = "public, max-age=604800"
return AppTheme(**settings.theme.model_dump())
@router.get("/oidc", response_model=OIDCInfo)
def get_oidc_info(resp: Response):
"""Get's the current OIDC configuration needed for the frontend"""
settings = get_app_settings()
resp.headers["Cache-Control"] = "public, max-age=604800"
return OIDCInfo(
configuration_url=settings.OIDC_CONFIGURATION_URL,
client_id=settings.OIDC_CLIENT_ID,
groups_claim=settings.OIDC_GROUPS_CLAIM if settings.OIDC_USER_GROUP or settings.OIDC_ADMIN_GROUP else None,
)

View File

@@ -1,13 +1,18 @@
from datetime import timedelta
from authlib.integrations.starlette_client import OAuth
from fastapi import APIRouter, Depends, Request, Response, status
from fastapi.exceptions import HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sqlalchemy.orm.session import Session
from starlette.datastructures import URLPath
from mealie.core import root_logger, security
from mealie.core.config import get_app_settings
from mealie.core.dependencies import get_current_user
from mealie.core.exceptions import UserLockedOut
from mealie.core.security.providers.openid_provider import OpenIDProvider
from mealie.core.security.security import get_auth_provider
from mealie.db.db_setup import generate_session
from mealie.routes._base.routers import UserAPIRouter
@@ -20,6 +25,20 @@ logger = root_logger.get_logger("auth")
remember_me_duration = timedelta(days=14)
settings = get_app_settings()
if settings.OIDC_READY:
oauth = OAuth()
groups_claim = settings.OIDC_GROUPS_CLAIM if settings.OIDC_REQUIRES_GROUP_CLAIM else ""
scope = f"openid email profile {groups_claim}"
oauth.register(
"oidc",
client_id=settings.OIDC_CLIENT_ID,
client_secret=settings.OIDC_CLIENT_SECRET,
server_metadata_url=settings.OIDC_CONFIGURATION_URL,
client_kwargs={"scope": scope.rstrip()},
code_challenge_method="S256",
)
class MealieAuthToken(BaseModel):
access_token: str
@@ -31,7 +50,7 @@ class MealieAuthToken(BaseModel):
@public_router.post("/token")
async def get_token(
def get_token(
request: Request,
response: Response,
data: CredentialsRequestForm = Depends(),
@@ -46,8 +65,8 @@ async def get_token(
ip = request.client.host if request.client else "unknown"
try:
auth_provider = get_auth_provider(session, request, data)
auth = await auth_provider.authenticate()
auth_provider = get_auth_provider(session, data)
auth = auth_provider.authenticate()
except UserLockedOut as e:
logger.error(f"User is locked out from {ip}")
raise HTTPException(status_code=status.HTTP_423_LOCKED, detail="User is locked out") from e
@@ -61,7 +80,61 @@ async def get_token(
expires_in = duration.total_seconds() if duration else None
response.set_cookie(
key="mealie.access_token", value=access_token, httponly=True, max_age=expires_in, expires=expires_in
key="mealie.access_token",
value=access_token,
httponly=True,
max_age=expires_in,
secure=settings.PRODUCTION,
)
return MealieAuthToken.respond(access_token)
@public_router.get("/oauth")
async def oauth_login(request: Request):
if not oauth:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not initialize OAuth client",
)
client = oauth.create_client("oidc")
redirect_url = None
if not settings.PRODUCTION:
# in development, we want to redirect to the frontend
redirect_url = "http://localhost:3000/login"
else:
redirect_url = URLPath("/login").make_absolute_url(request.base_url)
response: RedirectResponse = await client.authorize_redirect(request, redirect_url)
return response
@public_router.get("/oauth/callback")
async def oauth_callback(request: Request, response: Response, session: Session = Depends(generate_session)):
if not oauth:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not initialize OAuth client",
)
client = oauth.create_client("oidc")
token = await client.authorize_access_token(request)
auth_provider = OpenIDProvider(session, token["userinfo"])
auth = auth_provider.authenticate()
if not auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
)
access_token, duration = auth
expires_in = duration.total_seconds() if duration else None
response.set_cookie(
key="mealie.access_token",
value=access_token,
httponly=True,
max_age=expires_in,
secure=settings.PRODUCTION,
)
return MealieAuthToken.respond(access_token)

View File

@@ -1,5 +1,12 @@
# This file is auto-generated by gen_schema_exports.py
from .about import AdminAboutInfo, AppInfo, AppStartupInfo, AppStatistics, AppTheme, CheckAppConfig, OIDCInfo
from .about import (
AdminAboutInfo,
AppInfo,
AppStartupInfo,
AppStatistics,
AppTheme,
CheckAppConfig,
)
from .backup import AllBackups, BackupFile, BackupOptions, CreateBackup, ImportJob
from .debug import DebugResponse
from .email import EmailReady, EmailSuccess, EmailTest
@@ -46,7 +53,6 @@ __all__ = [
"AppStatistics",
"AppTheme",
"CheckAppConfig",
"OIDCInfo",
"EmailReady",
"EmailSuccess",
"EmailTest",

View File

@@ -72,9 +72,3 @@ class CheckAppConfig(MealieModel):
enable_openai: bool
base_url_set: bool
is_up_to_date: bool
class OIDCInfo(MealieModel):
configuration_url: str | None
client_id: str | None
groups_claim: str | None

View File

@@ -1,5 +1,5 @@
# This file is auto-generated by gen_schema_exports.py
from .auth import CredentialsRequest, CredentialsRequestForm, OIDCRequest, Token, TokenData, UnlockResults
from .auth import CredentialsRequest, CredentialsRequestForm, Token, TokenData, UnlockResults
from .registration import CreateUserRegistration
from .user import (
ChangePassword,
@@ -37,19 +37,18 @@ from .user_passwords import (
)
__all__ = [
"CreateUserRegistration",
"CredentialsRequest",
"CredentialsRequestForm",
"Token",
"TokenData",
"UnlockResults",
"ForgotPassword",
"PasswordResetToken",
"PrivatePasswordResetToken",
"ResetPassword",
"SavePasswordResetToken",
"ValidateResetToken",
"CredentialsRequest",
"CredentialsRequestForm",
"OIDCRequest",
"Token",
"TokenData",
"UnlockResults",
"CreateUserRegistration",
"ChangePassword",
"CreateToken",
"DeleteTokenResponse",

View File

@@ -26,14 +26,15 @@ class CredentialsRequest(BaseModel):
remember_me: bool = False
class OIDCRequest(BaseModel):
id_token: str
class CredentialsRequestForm:
"""Class that represents a user's credentials from the login form"""
def __init__(self, username: str = Form(""), password: str = Form(""), remember_me: bool = Form(False)):
def __init__(
self,
username: str = Form(""),
password: str = Form(""),
remember_me: bool = Form(False),
):
self.username = username
self.password = password
self.remember_me = remember_me