mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-02-10 01:43:11 -05:00
feat: Add OIDC_CLIENT_SECRET and other changes for v2 (#4254)
Co-authored-by: boc-the-git <3479092+boc-the-git@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.core.root_logger import get_logger
|
||||
@@ -66,12 +67,14 @@ async def lifespan_fn(_: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"LDAP_QUERY_PASSWORD",
|
||||
"OPENAI_API_KEY",
|
||||
"SECRET",
|
||||
"SESSION_SECRET",
|
||||
"SFTP_PASSWORD",
|
||||
"SFTP_USERNAME",
|
||||
"DB_URL", # replace by DB_URL_PUBLIC for logs
|
||||
"DB_PROVIDER",
|
||||
"SMTP_USER",
|
||||
"SMTP_PASSWORD",
|
||||
"OIDC_CLIENT_SECRET",
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -91,6 +94,7 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET)
|
||||
|
||||
if not settings.PRODUCTION:
|
||||
allowed_origins = ["http://localhost:3000"]
|
||||
|
||||
@@ -49,7 +49,10 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
|
||||
|
||||
to_encode["exp"] = expire
|
||||
to_encode["iss"] = ISS
|
||||
return (jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM), expires_delta)
|
||||
return (
|
||||
jwt.encode(to_encode, settings.SECRET, algorithm=ALGORITHM),
|
||||
expires_delta,
|
||||
)
|
||||
|
||||
def try_get_user(self, username: str) -> PrivateUser | None:
|
||||
"""Try to get a user from the database, first trying username, then trying email"""
|
||||
@@ -66,6 +69,6 @@ class AuthProvider(Generic[T], metaclass=abc.ABCMeta):
|
||||
return user
|
||||
|
||||
@abc.abstractmethod
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -20,7 +20,7 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
|
||||
def __init__(self, session: Session, data: CredentialsRequest) -> None:
|
||||
super().__init__(session, data)
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
settings = get_app_settings()
|
||||
db = get_repositories(self.session, group_id=None, household_id=None)
|
||||
@@ -30,7 +30,8 @@ class CredentialsProvider(AuthProvider[CredentialsRequest]):
|
||||
# To prevent user enumeration we perform the verify_password computation to ensure
|
||||
# server side time is relatively constant and not vulnerable to timing attacks.
|
||||
CredentialsProvider.verify_password(
|
||||
"abc123cba321", "$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i"
|
||||
"abc123cba321",
|
||||
"$2b$12$JdHtJOlkPFwyxdjdygEzPOtYmdQF5/R5tHxw5Tq8pxjubyLqdIX5i",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class LDAPProvider(CredentialsProvider):
|
||||
super().__init__(session, data)
|
||||
self.conn = None
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
user = self.try_get_user(self.data.username)
|
||||
if not user or user.password == "LDAP" or user.auth_method == AuthMethod.LDAP:
|
||||
@@ -30,7 +30,7 @@ class LDAPProvider(CredentialsProvider):
|
||||
if user:
|
||||
return self.get_access_token(user, self.data.remember_me)
|
||||
|
||||
return await super().authenticate()
|
||||
return super().authenticate()
|
||||
|
||||
def search_user(self, conn: LDAPObject) -> list[tuple[str, dict[str, list[bytes]]]] | None:
|
||||
"""
|
||||
@@ -64,7 +64,11 @@ class LDAPProvider(CredentialsProvider):
|
||||
settings.LDAP_BASE_DN,
|
||||
ldap.SCOPE_SUBTREE,
|
||||
search_filter,
|
||||
[settings.LDAP_ID_ATTRIBUTE, settings.LDAP_NAME_ATTRIBUTE, settings.LDAP_MAIL_ATTRIBUTE],
|
||||
[
|
||||
settings.LDAP_ID_ATTRIBUTE,
|
||||
settings.LDAP_NAME_ATTRIBUTE,
|
||||
settings.LDAP_MAIL_ATTRIBUTE,
|
||||
],
|
||||
)
|
||||
except ldap.FILTER_ERROR:
|
||||
self._logger.error("[LDAP] Bad user search filter")
|
||||
|
||||
@@ -1,55 +1,58 @@
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from functools import lru_cache
|
||||
|
||||
import requests
|
||||
from authlib.jose import JsonWebKey, JsonWebToken, JWTClaims, KeySet
|
||||
from authlib.jose.errors import ExpiredTokenError, UnsupportedAlgorithmError
|
||||
from authlib.oidc.core import CodeIDToken
|
||||
from authlib.oidc.core import UserInfo
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from mealie.core import root_logger
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.core.security.providers.auth_provider import AuthProvider
|
||||
from mealie.core.settings.settings import AppSettings
|
||||
from mealie.db.models.users.users import AuthMethod
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.schema.user.auth import OIDCRequest
|
||||
|
||||
|
||||
class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
class OpenIDProvider(AuthProvider[UserInfo]):
|
||||
"""Authentication provider that authenticates a user using a token from OIDC ID token"""
|
||||
|
||||
_logger = root_logger.get_logger("openid_provider")
|
||||
|
||||
def __init__(self, session: Session, data: OIDCRequest) -> None:
|
||||
def __init__(self, session: Session, data: UserInfo) -> None:
|
||||
super().__init__(session, data)
|
||||
|
||||
async def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
def authenticate(self) -> tuple[str, timedelta] | None:
|
||||
"""Attempt to authenticate a user given a username and password"""
|
||||
|
||||
settings = get_app_settings()
|
||||
claims = self.get_claims(settings)
|
||||
claims = self.data
|
||||
if not claims:
|
||||
self._logger.error("[OIDC] No claims in the id_token")
|
||||
return None
|
||||
|
||||
if not self.required_claims.issubset(claims.keys()):
|
||||
self._logger.error(
|
||||
"[OIDC] Required claims not present. Expected: %s Actual: %s",
|
||||
self.required_claims,
|
||||
claims.keys(),
|
||||
)
|
||||
return None
|
||||
|
||||
repos = get_repositories(self.session, group_id=None, household_id=None)
|
||||
|
||||
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
|
||||
is_admin = False
|
||||
if settings.OIDC_USER_GROUP or settings.OIDC_ADMIN_GROUP:
|
||||
if settings.OIDC_REQUIRES_GROUP_CLAIM:
|
||||
group_claim = claims.get(settings.OIDC_GROUPS_CLAIM, []) or []
|
||||
is_admin = settings.OIDC_ADMIN_GROUP in group_claim if settings.OIDC_ADMIN_GROUP else False
|
||||
is_valid_user = settings.OIDC_USER_GROUP in group_claim if settings.OIDC_USER_GROUP else True
|
||||
|
||||
if not is_valid_user:
|
||||
self._logger.debug(
|
||||
"[OIDC] User does not have the required group. Found: %s - Required: %s",
|
||||
if not (is_valid_user or is_admin):
|
||||
self._logger.warning(
|
||||
"[OIDC] Successfully authenticated, but user does not have one of the required group(s). \
|
||||
Found: %s - Required (one of): %s",
|
||||
group_claim,
|
||||
settings.OIDC_USER_GROUP,
|
||||
[settings.OIDC_USER_GROUP, settings.OIDC_ADMIN_GROUP],
|
||||
)
|
||||
return None
|
||||
|
||||
user = self.try_get_user(claims.get(settings.OIDC_USER_CLAIM))
|
||||
if not user:
|
||||
if not settings.OIDC_SIGNUP_ENABLED:
|
||||
self._logger.debug("[OIDC] No user found. Not creating a new user - new user creation is disabled.")
|
||||
@@ -57,22 +60,31 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
|
||||
self._logger.debug("[OIDC] No user found. Creating new OIDC user.")
|
||||
|
||||
user = repos.users.create(
|
||||
{
|
||||
"username": claims.get("preferred_username"),
|
||||
"password": "OIDC",
|
||||
"full_name": claims.get("name"),
|
||||
"email": claims.get("email"),
|
||||
"admin": is_admin,
|
||||
"auth_method": AuthMethod.OIDC,
|
||||
}
|
||||
)
|
||||
self.session.commit()
|
||||
try:
|
||||
# some IdPs don't provide a username (looking at you Google), so if we don't have the claim,
|
||||
# we'll create the user with whatever the USER_CLAIM is (default email)
|
||||
username = claims.get("preferred_username", claims.get(settings.OIDC_USER_CLAIM))
|
||||
user = repos.users.create(
|
||||
{
|
||||
"username": username,
|
||||
"password": "OIDC",
|
||||
"full_name": claims.get("name"),
|
||||
"email": claims.get("email"),
|
||||
"admin": is_admin,
|
||||
"auth_method": AuthMethod.OIDC,
|
||||
}
|
||||
)
|
||||
self.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error("[OIDC] Exception while creating user: %s", e)
|
||||
return None
|
||||
|
||||
return self.get_access_token(user, settings.OIDC_REMEMBER_ME) # type: ignore
|
||||
|
||||
if user:
|
||||
if settings.OIDC_ADMIN_GROUP and user.admin != is_admin:
|
||||
self._logger.debug(f"[OIDC] {'Setting' if is_admin else 'Removing'} user as admin")
|
||||
self._logger.debug("[OIDC] %s user as admin", "Setting" if is_admin else "Removing")
|
||||
user.admin = is_admin
|
||||
repos.users.update(user.id, user)
|
||||
return self.get_access_token(user, settings.OIDC_REMEMBER_ME)
|
||||
@@ -80,78 +92,11 @@ class OpenIDProvider(AuthProvider[OIDCRequest]):
|
||||
self._logger.warning("[OIDC] Found user but their AuthMethod does not match OIDC")
|
||||
return None
|
||||
|
||||
def get_claims(self, settings: AppSettings) -> JWTClaims | None:
|
||||
"""Get the claims from the ID token and check if the required claims are present"""
|
||||
required_claims = {
|
||||
"preferred_username",
|
||||
"name",
|
||||
"email",
|
||||
settings.OIDC_USER_CLAIM,
|
||||
}
|
||||
jwks = OpenIDProvider.get_jwks(self.get_ttl_hash()) # cache the key set for 30 minutes
|
||||
if not jwks:
|
||||
return None
|
||||
|
||||
algorithm = settings.OIDC_SIGNING_ALGORITHM
|
||||
try:
|
||||
claims = JsonWebToken([algorithm]).decode(s=self.data.id_token, key=jwks, claims_cls=CodeIDToken)
|
||||
except UnsupportedAlgorithmError:
|
||||
self._logger.error(
|
||||
f"[OIDC] Unsupported algorithm '{algorithm}'. Unable to decode id token due to mismatched algorithm."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
claims.validate()
|
||||
except ExpiredTokenError as e:
|
||||
self._logger.error(f"[OIDC] {e.error}: {e.description}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self._logger.error("[OIDC] Exception while validating id_token claims", e)
|
||||
|
||||
if not claims:
|
||||
self._logger.error("[OIDC] Claims not found")
|
||||
return None
|
||||
if not required_claims.issubset(claims.keys()):
|
||||
self._logger.error(
|
||||
f"[OIDC] Required claims not present. Expected: {required_claims} Actual: {claims.keys()}"
|
||||
)
|
||||
return None
|
||||
return claims
|
||||
|
||||
@lru_cache
|
||||
@staticmethod
|
||||
def get_jwks(ttl_hash=None) -> KeySet | None:
|
||||
"""Get the key set from the openid configuration"""
|
||||
del ttl_hash # ttl_hash is used for caching only
|
||||
@property
|
||||
def required_claims(self):
|
||||
settings = get_app_settings()
|
||||
|
||||
if not (settings.OIDC_READY and settings.OIDC_CONFIGURATION_URL):
|
||||
return None
|
||||
|
||||
session = requests.Session()
|
||||
if settings.OIDC_TLS_CACERTFILE:
|
||||
session.verify = settings.OIDC_TLS_CACERTFILE
|
||||
|
||||
config_response = session.get(settings.OIDC_CONFIGURATION_URL, timeout=5)
|
||||
config_response.raise_for_status()
|
||||
configuration = config_response.json()
|
||||
|
||||
if not configuration:
|
||||
OpenIDProvider._logger.warning("[OIDC] Unable to fetch configuration from the OIDC_CONFIGURATION_URL")
|
||||
session.close()
|
||||
return None
|
||||
|
||||
jwks_uri = configuration.get("jwks_uri", None)
|
||||
if not jwks_uri:
|
||||
OpenIDProvider._logger.warning("[OIDC] Unable to find the jwks_uri from the OIDC_CONFIGURATION_URL")
|
||||
session.close()
|
||||
return None
|
||||
|
||||
response = session.get(jwks_uri, timeout=5)
|
||||
response.raise_for_status()
|
||||
session.close()
|
||||
return JsonWebKey.import_key_set(response.json())
|
||||
|
||||
def get_ttl_hash(self, seconds=1800):
|
||||
return time.time() // seconds
|
||||
claims = {"name", "email", settings.OIDC_USER_CLAIM}
|
||||
if settings.OIDC_REQUIRES_GROUP_CLAIM:
|
||||
claims.add(settings.OIDC_GROUPS_CLAIM)
|
||||
return claims
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from mealie.core import root_logger
|
||||
@@ -12,20 +11,16 @@ from mealie.core.security.hasher import get_hasher
|
||||
from mealie.core.security.providers.auth_provider import AuthProvider
|
||||
from mealie.core.security.providers.credentials_provider import CredentialsProvider
|
||||
from mealie.core.security.providers.ldap_provider import LDAPProvider
|
||||
from mealie.core.security.providers.openid_provider import OpenIDProvider
|
||||
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm, OIDCRequest
|
||||
from mealie.schema.user.auth import CredentialsRequest, CredentialsRequestForm
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
logger = root_logger.get_logger("security")
|
||||
|
||||
|
||||
def get_auth_provider(session: Session, request: Request, data: CredentialsRequestForm) -> AuthProvider:
|
||||
def get_auth_provider(session: Session, data: CredentialsRequestForm) -> AuthProvider:
|
||||
settings = get_app_settings()
|
||||
|
||||
if request.cookies.get("mealie.auth.strategy") == "oidc":
|
||||
return OpenIDProvider(session, OIDCRequest(id_token=request.cookies.get("mealie.auth._id_token.oidc")))
|
||||
|
||||
credentials_request = CredentialsRequest(**data.__dict__)
|
||||
if settings.LDAP_ENABLED:
|
||||
return LDAPProvider(session, credentials_request)
|
||||
|
||||
@@ -19,11 +19,11 @@ class ScheduleTime(NamedTuple):
|
||||
minute: int
|
||||
|
||||
|
||||
def determine_secrets(data_dir: Path, production: bool) -> str:
|
||||
def determine_secrets(data_dir: Path, secret: str, production: bool) -> str:
|
||||
if not production:
|
||||
return "shh-secret-test-key"
|
||||
|
||||
secrets_file = data_dir.joinpath(".secret")
|
||||
secrets_file = data_dir.joinpath(secret)
|
||||
if secrets_file.is_file():
|
||||
with open(secrets_file) as f:
|
||||
return f.read()
|
||||
@@ -100,6 +100,7 @@ class AppSettings(AppLoggingSettings):
|
||||
"""time in hours"""
|
||||
|
||||
SECRET: str
|
||||
SESSION_SECRET: str
|
||||
|
||||
GIT_COMMIT_HASH: str = "unknown"
|
||||
|
||||
@@ -268,6 +269,7 @@ class AppSettings(AppLoggingSettings):
|
||||
# OIDC Configuration
|
||||
OIDC_AUTH_ENABLED: bool = False
|
||||
OIDC_CLIENT_ID: str | None = None
|
||||
OIDC_CLIENT_SECRET: str | None = None
|
||||
OIDC_CONFIGURATION_URL: str | None = None
|
||||
OIDC_SIGNUP_ENABLED: bool = True
|
||||
OIDC_USER_GROUP: str | None = None
|
||||
@@ -275,23 +277,28 @@ class AppSettings(AppLoggingSettings):
|
||||
OIDC_AUTO_REDIRECT: bool = False
|
||||
OIDC_PROVIDER_NAME: str = "OAuth"
|
||||
OIDC_REMEMBER_ME: bool = False
|
||||
OIDC_SIGNING_ALGORITHM: str = "RS256"
|
||||
OIDC_USER_CLAIM: str = "email"
|
||||
OIDC_GROUPS_CLAIM: str | None = "groups"
|
||||
OIDC_TLS_CACERTFILE: str | None = None
|
||||
|
||||
@property
|
||||
def OIDC_REQUIRES_GROUP_CLAIM(self) -> bool:
|
||||
return self.OIDC_USER_GROUP is not None or self.OIDC_ADMIN_GROUP is not None
|
||||
|
||||
@property
|
||||
def OIDC_READY(self) -> bool:
|
||||
"""Validates OIDC settings are all set"""
|
||||
|
||||
required = {
|
||||
self.OIDC_CLIENT_ID,
|
||||
self.OIDC_CLIENT_SECRET,
|
||||
self.OIDC_CONFIGURATION_URL,
|
||||
self.OIDC_USER_CLAIM,
|
||||
}
|
||||
not_none = None not in required
|
||||
valid_group_claim = True
|
||||
if (not self.OIDC_USER_GROUP or not self.OIDC_ADMIN_GROUP) and not self.OIDC_GROUPS_CLAIM:
|
||||
|
||||
if self.OIDC_REQUIRES_GROUP_CLAIM and self.OIDC_GROUPS_CLAIM is None:
|
||||
valid_group_claim = False
|
||||
|
||||
return self.OIDC_AUTH_ENABLED and not_none and valid_group_claim
|
||||
@@ -353,13 +360,17 @@ def app_settings_constructor(data_dir: Path, production: bool, env_file: Path, e
|
||||
required dependencies into the AppSettings object and nested child objects. AppSettings should not be substantiated
|
||||
directly, but rather through this factory function.
|
||||
"""
|
||||
secret_settings = {
|
||||
"SECRET": determine_secrets(data_dir, ".secret", production),
|
||||
"SESSION_SECRET": determine_secrets(data_dir, ".session_secret", production),
|
||||
}
|
||||
app_settings = AppSettings(
|
||||
_env_file=env_file, # type: ignore
|
||||
_env_file_encoding=env_encoding, # type: ignore
|
||||
# `get_secrets_dir` must be called here rather than within `AppSettings`
|
||||
# to avoid a circular import.
|
||||
_secrets_dir=get_secrets_dir(), # type: ignore
|
||||
**{"SECRET": determine_secrets(data_dir, production)},
|
||||
**secret_settings,
|
||||
)
|
||||
|
||||
app_settings.DB_PROVIDER = db_provider_factory(
|
||||
|
||||
@@ -6,7 +6,7 @@ from mealie.core.settings.static import APP_VERSION
|
||||
from mealie.db.db_setup import generate_session
|
||||
from mealie.db.models.users.users import User
|
||||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.schema.admin.about import AppInfo, AppStartupInfo, AppTheme, OIDCInfo
|
||||
from mealie.schema.admin.about import AppInfo, AppStartupInfo, AppTheme
|
||||
|
||||
router = APIRouter(prefix="/about")
|
||||
|
||||
@@ -69,16 +69,3 @@ def get_app_theme(resp: Response):
|
||||
|
||||
resp.headers["Cache-Control"] = "public, max-age=604800"
|
||||
return AppTheme(**settings.theme.model_dump())
|
||||
|
||||
|
||||
@router.get("/oidc", response_model=OIDCInfo)
|
||||
def get_oidc_info(resp: Response):
|
||||
"""Get's the current OIDC configuration needed for the frontend"""
|
||||
settings = get_app_settings()
|
||||
|
||||
resp.headers["Cache-Control"] = "public, max-age=604800"
|
||||
return OIDCInfo(
|
||||
configuration_url=settings.OIDC_CONFIGURATION_URL,
|
||||
client_id=settings.OIDC_CLIENT_ID,
|
||||
groups_claim=settings.OIDC_GROUPS_CLAIM if settings.OIDC_USER_GROUP or settings.OIDC_ADMIN_GROUP else None,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
from fastapi import APIRouter, Depends, Request, Response, status
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm.session import Session
|
||||
from starlette.datastructures import URLPath
|
||||
|
||||
from mealie.core import root_logger, security
|
||||
from mealie.core.config import get_app_settings
|
||||
from mealie.core.dependencies import get_current_user
|
||||
from mealie.core.exceptions import UserLockedOut
|
||||
from mealie.core.security.providers.openid_provider import OpenIDProvider
|
||||
from mealie.core.security.security import get_auth_provider
|
||||
from mealie.db.db_setup import generate_session
|
||||
from mealie.routes._base.routers import UserAPIRouter
|
||||
@@ -20,6 +25,20 @@ logger = root_logger.get_logger("auth")
|
||||
|
||||
remember_me_duration = timedelta(days=14)
|
||||
|
||||
settings = get_app_settings()
|
||||
if settings.OIDC_READY:
|
||||
oauth = OAuth()
|
||||
groups_claim = settings.OIDC_GROUPS_CLAIM if settings.OIDC_REQUIRES_GROUP_CLAIM else ""
|
||||
scope = f"openid email profile {groups_claim}"
|
||||
oauth.register(
|
||||
"oidc",
|
||||
client_id=settings.OIDC_CLIENT_ID,
|
||||
client_secret=settings.OIDC_CLIENT_SECRET,
|
||||
server_metadata_url=settings.OIDC_CONFIGURATION_URL,
|
||||
client_kwargs={"scope": scope.rstrip()},
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
|
||||
class MealieAuthToken(BaseModel):
|
||||
access_token: str
|
||||
@@ -31,7 +50,7 @@ class MealieAuthToken(BaseModel):
|
||||
|
||||
|
||||
@public_router.post("/token")
|
||||
async def get_token(
|
||||
def get_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
data: CredentialsRequestForm = Depends(),
|
||||
@@ -46,8 +65,8 @@ async def get_token(
|
||||
ip = request.client.host if request.client else "unknown"
|
||||
|
||||
try:
|
||||
auth_provider = get_auth_provider(session, request, data)
|
||||
auth = await auth_provider.authenticate()
|
||||
auth_provider = get_auth_provider(session, data)
|
||||
auth = auth_provider.authenticate()
|
||||
except UserLockedOut as e:
|
||||
logger.error(f"User is locked out from {ip}")
|
||||
raise HTTPException(status_code=status.HTTP_423_LOCKED, detail="User is locked out") from e
|
||||
@@ -61,7 +80,61 @@ async def get_token(
|
||||
|
||||
expires_in = duration.total_seconds() if duration else None
|
||||
response.set_cookie(
|
||||
key="mealie.access_token", value=access_token, httponly=True, max_age=expires_in, expires=expires_in
|
||||
key="mealie.access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
max_age=expires_in,
|
||||
secure=settings.PRODUCTION,
|
||||
)
|
||||
|
||||
return MealieAuthToken.respond(access_token)
|
||||
|
||||
|
||||
@public_router.get("/oauth")
|
||||
async def oauth_login(request: Request):
|
||||
if not oauth:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not initialize OAuth client",
|
||||
)
|
||||
client = oauth.create_client("oidc")
|
||||
redirect_url = None
|
||||
if not settings.PRODUCTION:
|
||||
# in development, we want to redirect to the frontend
|
||||
redirect_url = "http://localhost:3000/login"
|
||||
else:
|
||||
redirect_url = URLPath("/login").make_absolute_url(request.base_url)
|
||||
|
||||
response: RedirectResponse = await client.authorize_redirect(request, redirect_url)
|
||||
return response
|
||||
|
||||
|
||||
@public_router.get("/oauth/callback")
|
||||
async def oauth_callback(request: Request, response: Response, session: Session = Depends(generate_session)):
|
||||
if not oauth:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not initialize OAuth client",
|
||||
)
|
||||
client = oauth.create_client("oidc")
|
||||
token = await client.authorize_access_token(request)
|
||||
auth_provider = OpenIDProvider(session, token["userinfo"])
|
||||
auth = auth_provider.authenticate()
|
||||
|
||||
if not auth:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
access_token, duration = auth
|
||||
|
||||
expires_in = duration.total_seconds() if duration else None
|
||||
|
||||
response.set_cookie(
|
||||
key="mealie.access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
max_age=expires_in,
|
||||
secure=settings.PRODUCTION,
|
||||
)
|
||||
|
||||
return MealieAuthToken.respond(access_token)
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
# This file is auto-generated by gen_schema_exports.py
|
||||
from .about import AdminAboutInfo, AppInfo, AppStartupInfo, AppStatistics, AppTheme, CheckAppConfig, OIDCInfo
|
||||
from .about import (
|
||||
AdminAboutInfo,
|
||||
AppInfo,
|
||||
AppStartupInfo,
|
||||
AppStatistics,
|
||||
AppTheme,
|
||||
CheckAppConfig,
|
||||
)
|
||||
from .backup import AllBackups, BackupFile, BackupOptions, CreateBackup, ImportJob
|
||||
from .debug import DebugResponse
|
||||
from .email import EmailReady, EmailSuccess, EmailTest
|
||||
@@ -46,7 +53,6 @@ __all__ = [
|
||||
"AppStatistics",
|
||||
"AppTheme",
|
||||
"CheckAppConfig",
|
||||
"OIDCInfo",
|
||||
"EmailReady",
|
||||
"EmailSuccess",
|
||||
"EmailTest",
|
||||
|
||||
@@ -72,9 +72,3 @@ class CheckAppConfig(MealieModel):
|
||||
enable_openai: bool
|
||||
base_url_set: bool
|
||||
is_up_to_date: bool
|
||||
|
||||
|
||||
class OIDCInfo(MealieModel):
|
||||
configuration_url: str | None
|
||||
client_id: str | None
|
||||
groups_claim: str | None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# This file is auto-generated by gen_schema_exports.py
|
||||
from .auth import CredentialsRequest, CredentialsRequestForm, OIDCRequest, Token, TokenData, UnlockResults
|
||||
from .auth import CredentialsRequest, CredentialsRequestForm, Token, TokenData, UnlockResults
|
||||
from .registration import CreateUserRegistration
|
||||
from .user import (
|
||||
ChangePassword,
|
||||
@@ -37,19 +37,18 @@ from .user_passwords import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CreateUserRegistration",
|
||||
"CredentialsRequest",
|
||||
"CredentialsRequestForm",
|
||||
"Token",
|
||||
"TokenData",
|
||||
"UnlockResults",
|
||||
"ForgotPassword",
|
||||
"PasswordResetToken",
|
||||
"PrivatePasswordResetToken",
|
||||
"ResetPassword",
|
||||
"SavePasswordResetToken",
|
||||
"ValidateResetToken",
|
||||
"CredentialsRequest",
|
||||
"CredentialsRequestForm",
|
||||
"OIDCRequest",
|
||||
"Token",
|
||||
"TokenData",
|
||||
"UnlockResults",
|
||||
"CreateUserRegistration",
|
||||
"ChangePassword",
|
||||
"CreateToken",
|
||||
"DeleteTokenResponse",
|
||||
|
||||
@@ -26,14 +26,15 @@ class CredentialsRequest(BaseModel):
|
||||
remember_me: bool = False
|
||||
|
||||
|
||||
class OIDCRequest(BaseModel):
|
||||
id_token: str
|
||||
|
||||
|
||||
class CredentialsRequestForm:
|
||||
"""Class that represents a user's credentials from the login form"""
|
||||
|
||||
def __init__(self, username: str = Form(""), password: str = Form(""), remember_me: bool = Form(False)):
|
||||
def __init__(
|
||||
self,
|
||||
username: str = Form(""),
|
||||
password: str = Form(""),
|
||||
remember_me: bool = Form(False),
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.remember_me = remember_me
|
||||
|
||||
Reference in New Issue
Block a user