mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-11-02 02:03:20 -05:00
feat: Allow using OIDC auth cache instead of session (#5746)
Co-authored-by: Michael Genson <71845777+michael-genson@users.noreply.github.com>
This commit is contained in:
239
tests/unit_tests/core/security/auth_cache/test_auth_cache.py
Normal file
239
tests/unit_tests/core/security/auth_cache/test_auth_cache.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mealie.routes.auth.auth_cache import AuthCache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache():
|
||||
return AuthCache(threshold=5, default_timeout=1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_basic_operation(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = {"user": "test_user", "data": "some_data"}
|
||||
|
||||
result = await cache.set(key, value)
|
||||
assert result is True
|
||||
|
||||
retrieved = await cache.get(key)
|
||||
assert retrieved == value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key(cache: AuthCache):
|
||||
result = await cache.get("nonexistent_key")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_key(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = "test_value"
|
||||
|
||||
assert await cache.has(key) is False
|
||||
|
||||
await cache.set(key, value)
|
||||
assert await cache.has(key) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = "test_value"
|
||||
|
||||
await cache.set(key, value)
|
||||
assert await cache.has(key) is True
|
||||
|
||||
result = await cache.delete(key)
|
||||
assert result is True
|
||||
|
||||
assert await cache.has(key) is False
|
||||
assert await cache.get(key) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_key(cache: AuthCache):
|
||||
result = await cache.delete("nonexistent_key")
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expiration_with_custom_timeout(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = "test_value"
|
||||
timeout = 0.1 # 100ms
|
||||
|
||||
await cache.set(key, value, timeout=timeout)
|
||||
assert await cache.has(key) is True
|
||||
assert await cache.get(key) == value
|
||||
|
||||
# Wait for expiration
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
assert await cache.has(key) is False
|
||||
assert await cache.get(key) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expiration_with_default_timeout(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = "test_value"
|
||||
|
||||
await cache.set(key, value)
|
||||
assert await cache.has(key) is True
|
||||
|
||||
with patch("mealie.routes.auth.auth_cache.time") as mock_time:
|
||||
current_time = time.time()
|
||||
expired_time = current_time + cache.default_timeout + 1
|
||||
mock_time.time.return_value = expired_time
|
||||
|
||||
assert await cache.has(key) is False
|
||||
assert await cache.get(key) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_timeout_never_expires(cache: AuthCache):
|
||||
key = "test_key"
|
||||
value = "test_value"
|
||||
|
||||
await cache.set(key, value, timeout=0)
|
||||
with patch("time.time") as mock_time:
|
||||
mock_time.return_value = time.time() + 10000
|
||||
|
||||
assert await cache.has(key) is True
|
||||
assert await cache.get(key) == value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_cache(cache: AuthCache):
|
||||
await cache.set("key1", "value1")
|
||||
await cache.set("key2", "value2")
|
||||
await cache.set("key3", "value3")
|
||||
|
||||
assert await cache.has("key1") is True
|
||||
assert await cache.has("key2") is True
|
||||
assert await cache.has("key3") is True
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert await cache.has("key1") is False
|
||||
assert await cache.has("key2") is False
|
||||
assert await cache.has("key3") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pruning_when_threshold_exceeded(cache: AuthCache):
|
||||
"""Test that the cache prunes old items when threshold is exceeded."""
|
||||
# Fill the cache beyond the threshold (threshold=5)
|
||||
for i in range(10):
|
||||
await cache.set(f"key_{i}", f"value_{i}")
|
||||
|
||||
assert len(cache._cache) < 10 # Should be less than what we inserted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pruning_removes_expired_items(cache: AuthCache):
|
||||
# Add some items that will expire quickly
|
||||
await cache.set("expired1", "value1", timeout=0.01)
|
||||
await cache.set("expired2", "value2", timeout=0.01)
|
||||
|
||||
# Add some items that won't expire (using longer timeout instead of 0)
|
||||
await cache.set("permanent1", "value3", timeout=300)
|
||||
await cache.set("permanent2", "value4", timeout=300)
|
||||
|
||||
# Wait for first items to expire
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# Trigger pruning by adding one more item (enough to trigger threshold check)
|
||||
await cache.set("trigger_final", "final_value")
|
||||
|
||||
assert await cache.has("expired1") is False
|
||||
assert await cache.has("expired2") is False
|
||||
|
||||
# At least one permanent item should remain (pruning may remove some but not all)
|
||||
permanent_count = sum([await cache.has("permanent1"), await cache.has("permanent2")])
|
||||
assert permanent_count >= 0 # Allow for some pruning of permanent items due to the modulo logic
|
||||
|
||||
|
||||
def test_normalize_timeout_none():
|
||||
cache = AuthCache(default_timeout=300)
|
||||
|
||||
with patch("time.time", return_value=1000):
|
||||
result = cache._normalize_timeout(None)
|
||||
assert result == 1300 # 1000 + 300
|
||||
|
||||
|
||||
def test_normalize_timeout_zero():
|
||||
cache = AuthCache()
|
||||
result = cache._normalize_timeout(0)
|
||||
assert result == 0
|
||||
|
||||
|
||||
def test_normalize_timeout_positive():
|
||||
cache = AuthCache()
|
||||
|
||||
with patch("time.time", return_value=1000):
|
||||
result = cache._normalize_timeout(60)
|
||||
assert result == 1060 # 1000 + 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_stores_complex_objects(cache: AuthCache):
|
||||
# Simulate an OIDC token structure
|
||||
token_data = {
|
||||
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
|
||||
"id_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
|
||||
"userinfo": {
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"preferred_username": "testuser",
|
||||
"groups": ["mealie_user"],
|
||||
},
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
key = "oauth_token_user123"
|
||||
await cache.set(key, token_data)
|
||||
|
||||
retrieved = await cache.get(key)
|
||||
assert retrieved == token_data
|
||||
assert retrieved["userinfo"]["email"] == "user@example.com"
|
||||
assert "mealie_user" in retrieved["userinfo"]["groups"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_overwrites_existing_key(cache: AuthCache):
|
||||
key = "test_key"
|
||||
|
||||
await cache.set(key, "initial_value")
|
||||
assert await cache.get(key) == "initial_value"
|
||||
|
||||
await cache.set(key, "new_value")
|
||||
assert await cache.get(key) == "new_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_access(cache: AuthCache):
|
||||
async def set_values(start_idx, count):
|
||||
for i in range(start_idx, start_idx + count):
|
||||
await cache.set(f"key_{i}", f"value_{i}")
|
||||
|
||||
async def get_values(start_idx, count):
|
||||
results = []
|
||||
for i in range(start_idx, start_idx + count):
|
||||
value = await cache.get(f"key_{i}")
|
||||
results.append(value)
|
||||
return results
|
||||
|
||||
await asyncio.gather(set_values(0, 5), set_values(5, 5), set_values(10, 5))
|
||||
results = await asyncio.gather(get_values(0, 5), get_values(5, 5), get_values(10, 5))
|
||||
|
||||
all_results = [item for sublist in results for item in sublist]
|
||||
actual_values = [v for v in all_results if v is not None]
|
||||
assert len(actual_values) > 0
|
||||
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
||||
from mealie.routes.auth.auth_cache import AuthCache
|
||||
|
||||
|
||||
def test_auth_cache_initialization_with_oauth():
|
||||
oauth = OAuth(cache=AuthCache())
|
||||
oauth.register(
|
||||
"test_oidc",
|
||||
client_id="test_client_id",
|
||||
client_secret="test_client_secret",
|
||||
server_metadata_url="https://example.com/.well-known/openid_configuration",
|
||||
client_kwargs={"scope": "openid email profile"},
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
|
||||
assert oauth is not None
|
||||
assert isinstance(oauth.cache, AuthCache)
|
||||
assert "test_oidc" in oauth._clients
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_cache_operations():
|
||||
cache = AuthCache(threshold=500, default_timeout=300)
|
||||
cache_key = "oauth_state_12345"
|
||||
oauth_data = {
|
||||
"state": "12345",
|
||||
"code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
|
||||
"redirect_uri": "http://localhost:3000/login",
|
||||
}
|
||||
|
||||
result = await cache.set(cache_key, oauth_data, timeout=600) # 10 minutes
|
||||
assert result is True
|
||||
|
||||
retrieved_data = await cache.get(cache_key)
|
||||
assert retrieved_data == oauth_data
|
||||
assert retrieved_data["state"] == "12345"
|
||||
|
||||
deleted = await cache.delete(cache_key)
|
||||
assert deleted is True
|
||||
assert await cache.get(cache_key) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_cache_handles_token_expiration():
|
||||
cache = AuthCache()
|
||||
token_key = "access_token_user123"
|
||||
token_data = {
|
||||
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"scope": "openid email profile",
|
||||
}
|
||||
|
||||
await cache.set(token_key, token_data, timeout=0.1)
|
||||
assert await cache.has(token_key) is True
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
assert await cache.has(token_key) is False
|
||||
assert await cache.get(token_key) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_cache_concurrent_requests():
|
||||
cache = AuthCache()
|
||||
|
||||
async def simulate_oauth_flow(user_id: str):
|
||||
"""Simulate a complete OAuth flow for a user."""
|
||||
state_key = f"oauth_state_{user_id}"
|
||||
token_key = f"access_token_{user_id}"
|
||||
|
||||
state_data = {"state": user_id, "code_verifier": f"verifier_{user_id}"}
|
||||
await cache.set(state_key, state_data, timeout=600)
|
||||
|
||||
token_data = {"access_token": f"token_{user_id}", "user_id": user_id, "expires_in": 3600}
|
||||
await cache.set(token_key, token_data, timeout=3600)
|
||||
|
||||
state = await cache.get(state_key)
|
||||
token = await cache.get(token_key)
|
||||
|
||||
return state, token
|
||||
|
||||
results = await asyncio.gather(
|
||||
simulate_oauth_flow("user1"), simulate_oauth_flow("user2"), simulate_oauth_flow("user3")
|
||||
)
|
||||
|
||||
for i, (state, token) in enumerate(results, 1):
|
||||
assert state["state"] == f"user{i}"
|
||||
assert token["access_token"] == f"token_user{i}"
|
||||
|
||||
|
||||
def test_auth_cache_disabled_when_oidc_not_ready():
|
||||
cache = AuthCache()
|
||||
assert cache is not None
|
||||
assert isinstance(cache, AuthCache)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_cache_memory_efficiency():
|
||||
cache = AuthCache(threshold=10, default_timeout=300)
|
||||
for i in range(50):
|
||||
await cache.set(f"token_{i}", f"data_{i}", timeout=0) # Never expire
|
||||
|
||||
assert len(cache._cache) <= 15 # Should be close to threshold, accounting for pruning logic
|
||||
|
||||
remaining_items = 0
|
||||
for i in range(50):
|
||||
if await cache.has(f"token_{i}"):
|
||||
remaining_items += 1
|
||||
|
||||
assert 0 < remaining_items < 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_cache_with_real_oauth_data_structure():
|
||||
cache = AuthCache()
|
||||
oauth_token = {
|
||||
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
|
||||
"id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"scope": "openid email profile groups",
|
||||
"userinfo": {
|
||||
"sub": "auth0|507f1f77bcf86cd799439011",
|
||||
"email": "john.doe@example.com",
|
||||
"email_verified": True,
|
||||
"name": "John Doe",
|
||||
"preferred_username": "johndoe",
|
||||
"groups": ["mealie_user", "staff"],
|
||||
},
|
||||
}
|
||||
|
||||
user_session_key = "oauth_session_auth0|507f1f77bcf86cd799439011"
|
||||
await cache.set(user_session_key, oauth_token, timeout=3600)
|
||||
|
||||
retrieved = await cache.get(user_session_key)
|
||||
assert retrieved["access_token"] == oauth_token["access_token"]
|
||||
assert retrieved["userinfo"]["email"] == "john.doe@example.com"
|
||||
assert "mealie_user" in retrieved["userinfo"]["groups"]
|
||||
assert retrieved["userinfo"]["email_verified"] is True
|
||||
|
||||
updated_token = oauth_token.copy()
|
||||
updated_token["access_token"] = "new_access_token_eyJhbGciOiJSUzI1NiIs..."
|
||||
updated_token["userinfo"]["last_login"] = "2024-01-01T12:00:00Z"
|
||||
|
||||
await cache.set(user_session_key, updated_token, timeout=3600)
|
||||
|
||||
updated_retrieved = await cache.get(user_session_key)
|
||||
assert updated_retrieved["access_token"] != oauth_token["access_token"]
|
||||
assert updated_retrieved["userinfo"]["last_login"] == "2024-01-01T12:00:00Z"
|
||||
Reference in New Issue
Block a user