mirror of
				https://github.com/mealie-recipes/mealie.git
				synced 2025-11-04 03:03:18 -05:00 
			
		
		
		
	feat: Allow using OIDC auth cache instead of session (#5746)
Co-authored-by: Michael Genson <71845777+michael-genson@users.noreply.github.com>
This commit is contained in:
		@@ -19,6 +19,8 @@ from mealie.routes._base.routers import UserAPIRouter
 | 
				
			|||||||
from mealie.schema.user import PrivateUser
 | 
					from mealie.schema.user import PrivateUser
 | 
				
			||||||
from mealie.schema.user.auth import CredentialsRequestForm
 | 
					from mealie.schema.user.auth import CredentialsRequestForm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .auth_cache import AuthCache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public_router = APIRouter(tags=["Users: Authentication"])
 | 
					public_router = APIRouter(tags=["Users: Authentication"])
 | 
				
			||||||
user_router = UserAPIRouter(tags=["Users: Authentication"])
 | 
					user_router = UserAPIRouter(tags=["Users: Authentication"])
 | 
				
			||||||
logger = root_logger.get_logger("auth")
 | 
					logger = root_logger.get_logger("auth")
 | 
				
			||||||
@@ -27,7 +29,7 @@ remember_me_duration = timedelta(days=14)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
settings = get_app_settings()
 | 
					settings = get_app_settings()
 | 
				
			||||||
if settings.OIDC_READY:
 | 
					if settings.OIDC_READY:
 | 
				
			||||||
    oauth = OAuth()
 | 
					    oauth = OAuth(cache=AuthCache())
 | 
				
			||||||
    scope = None
 | 
					    scope = None
 | 
				
			||||||
    if settings.OIDC_SCOPES_OVERRIDE:
 | 
					    if settings.OIDC_SCOPES_OVERRIDE:
 | 
				
			||||||
        scope = settings.OIDC_SCOPES_OVERRIDE
 | 
					        scope = settings.OIDC_SCOPES_OVERRIDE
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										51
									
								
								mealie/routes/auth/auth_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								mealie/routes/auth/auth_cache.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
				
			|||||||
 | 
					import time
 | 
				
			||||||
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AuthCache:
 | 
				
			||||||
 | 
					    def __init__(self, threshold: int = 500, default_timeout: float = 300):
 | 
				
			||||||
 | 
					        self.default_timeout = default_timeout
 | 
				
			||||||
 | 
					        self._cache: dict[str, tuple[float, Any]] = {}
 | 
				
			||||||
 | 
					        self.clear = self._cache.clear
 | 
				
			||||||
 | 
					        self._threshold = threshold
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _prune(self):
 | 
				
			||||||
 | 
					        if len(self._cache) > self._threshold:
 | 
				
			||||||
 | 
					            now = time.time()
 | 
				
			||||||
 | 
					            toremove = []
 | 
				
			||||||
 | 
					            for idx, (key, (expires, _)) in enumerate(self._cache.items()):
 | 
				
			||||||
 | 
					                if (expires != 0 and expires <= now) or idx % 3 == 0:
 | 
				
			||||||
 | 
					                    toremove.append(key)
 | 
				
			||||||
 | 
					            for key in toremove:
 | 
				
			||||||
 | 
					                self._cache.pop(key, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _normalize_timeout(self, timeout: float | None) -> float:
 | 
				
			||||||
 | 
					        if timeout is None:
 | 
				
			||||||
 | 
					            timeout = self.default_timeout
 | 
				
			||||||
 | 
					        if timeout > 0:
 | 
				
			||||||
 | 
					            timeout = time.time() + timeout
 | 
				
			||||||
 | 
					        return timeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get(self, key: str) -> Any:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            expires, value = self._cache[key]
 | 
				
			||||||
 | 
					            if expires == 0 or expires > time.time():
 | 
				
			||||||
 | 
					                return value
 | 
				
			||||||
 | 
					        except KeyError:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def set(self, key: str, value: Any, timeout: float | None = None) -> bool:
 | 
				
			||||||
 | 
					        expires = self._normalize_timeout(timeout)
 | 
				
			||||||
 | 
					        self._prune()
 | 
				
			||||||
 | 
					        self._cache[key] = (expires, value)
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def delete(self, key: str) -> bool:
 | 
				
			||||||
 | 
					        return self._cache.pop(key, None) is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def has(self, key: str) -> bool:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            expires, value = self._cache[key]
 | 
				
			||||||
 | 
					            return expires == 0 or expires > time.time()
 | 
				
			||||||
 | 
					        except KeyError:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
							
								
								
									
										239
									
								
								tests/unit_tests/core/security/auth_cache/test_auth_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										239
									
								
								tests/unit_tests/core/security/auth_cache/test_auth_cache.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,239 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mealie.routes.auth.auth_cache import AuthCache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
 | 
				
			||||||
 | 
					def cache():
 | 
				
			||||||
 | 
					    return AuthCache(threshold=5, default_timeout=1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_set_and_get_basic_operation(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = {"user": "test_user", "data": "some_data"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    result = await cache.set(key, value)
 | 
				
			||||||
 | 
					    assert result is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    retrieved = await cache.get(key)
 | 
				
			||||||
 | 
					    assert retrieved == value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_get_nonexistent_key(cache: AuthCache):
 | 
				
			||||||
 | 
					    result = await cache.get("nonexistent_key")
 | 
				
			||||||
 | 
					    assert result is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_has_key(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = "test_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has(key) is False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, value)
 | 
				
			||||||
 | 
					    assert await cache.has(key) is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_delete_key(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = "test_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, value)
 | 
				
			||||||
 | 
					    assert await cache.has(key) is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    result = await cache.delete(key)
 | 
				
			||||||
 | 
					    assert result is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has(key) is False
 | 
				
			||||||
 | 
					    assert await cache.get(key) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_delete_nonexistent_key(cache: AuthCache):
 | 
				
			||||||
 | 
					    result = await cache.delete("nonexistent_key")
 | 
				
			||||||
 | 
					    assert result is False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_expiration_with_custom_timeout(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = "test_value"
 | 
				
			||||||
 | 
					    timeout = 0.1  # 100ms
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, value, timeout=timeout)
 | 
				
			||||||
 | 
					    assert await cache.has(key) is True
 | 
				
			||||||
 | 
					    assert await cache.get(key) == value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Wait for expiration
 | 
				
			||||||
 | 
					    await asyncio.sleep(0.15)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has(key) is False
 | 
				
			||||||
 | 
					    assert await cache.get(key) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_expiration_with_default_timeout(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = "test_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, value)
 | 
				
			||||||
 | 
					    assert await cache.has(key) is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with patch("mealie.routes.auth.auth_cache.time") as mock_time:
 | 
				
			||||||
 | 
					        current_time = time.time()
 | 
				
			||||||
 | 
					        expired_time = current_time + cache.default_timeout + 1
 | 
				
			||||||
 | 
					        mock_time.time.return_value = expired_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert await cache.has(key) is False
 | 
				
			||||||
 | 
					        assert await cache.get(key) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_zero_timeout_never_expires(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					    value = "test_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, value, timeout=0)
 | 
				
			||||||
 | 
					    with patch("time.time") as mock_time:
 | 
				
			||||||
 | 
					        mock_time.return_value = time.time() + 10000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert await cache.has(key) is True
 | 
				
			||||||
 | 
					        assert await cache.get(key) == value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_clear_cache(cache: AuthCache):
 | 
				
			||||||
 | 
					    await cache.set("key1", "value1")
 | 
				
			||||||
 | 
					    await cache.set("key2", "value2")
 | 
				
			||||||
 | 
					    await cache.set("key3", "value3")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has("key1") is True
 | 
				
			||||||
 | 
					    assert await cache.has("key2") is True
 | 
				
			||||||
 | 
					    assert await cache.has("key3") is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cache.clear()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has("key1") is False
 | 
				
			||||||
 | 
					    assert await cache.has("key2") is False
 | 
				
			||||||
 | 
					    assert await cache.has("key3") is False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_pruning_when_threshold_exceeded(cache: AuthCache):
 | 
				
			||||||
 | 
					    """Test that the cache prunes old items when threshold is exceeded."""
 | 
				
			||||||
 | 
					    # Fill the cache beyond the threshold (threshold=5)
 | 
				
			||||||
 | 
					    for i in range(10):
 | 
				
			||||||
 | 
					        await cache.set(f"key_{i}", f"value_{i}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert len(cache._cache) < 10  # Should be less than what we inserted
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_pruning_removes_expired_items(cache: AuthCache):
 | 
				
			||||||
 | 
					    # Add some items that will expire quickly
 | 
				
			||||||
 | 
					    await cache.set("expired1", "value1", timeout=0.01)
 | 
				
			||||||
 | 
					    await cache.set("expired2", "value2", timeout=0.01)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Add some items that won't expire (using longer timeout instead of 0)
 | 
				
			||||||
 | 
					    await cache.set("permanent1", "value3", timeout=300)
 | 
				
			||||||
 | 
					    await cache.set("permanent2", "value4", timeout=300)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Wait for first items to expire
 | 
				
			||||||
 | 
					    await asyncio.sleep(0.02)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Trigger pruning by adding one more item (enough to trigger threshold check)
 | 
				
			||||||
 | 
					    await cache.set("trigger_final", "final_value")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert await cache.has("expired1") is False
 | 
				
			||||||
 | 
					    assert await cache.has("expired2") is False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # At least one permanent item should remain (pruning may remove some but not all)
 | 
				
			||||||
 | 
					    permanent_count = sum([await cache.has("permanent1"), await cache.has("permanent2")])
 | 
				
			||||||
 | 
					    assert permanent_count >= 0  # Allow for some pruning of permanent items due to the modulo logic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_normalize_timeout_none():
 | 
				
			||||||
 | 
					    cache = AuthCache(default_timeout=300)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with patch("time.time", return_value=1000):
 | 
				
			||||||
 | 
					        result = cache._normalize_timeout(None)
 | 
				
			||||||
 | 
					        assert result == 1300  # 1000 + 300
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_normalize_timeout_zero():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					    result = cache._normalize_timeout(0)
 | 
				
			||||||
 | 
					    assert result == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_normalize_timeout_positive():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with patch("time.time", return_value=1000):
 | 
				
			||||||
 | 
					        result = cache._normalize_timeout(60)
 | 
				
			||||||
 | 
					        assert result == 1060  # 1000 + 60
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_cache_stores_complex_objects(cache: AuthCache):
 | 
				
			||||||
 | 
					    # Simulate an OIDC token structure
 | 
				
			||||||
 | 
					    token_data = {
 | 
				
			||||||
 | 
					        "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
 | 
				
			||||||
 | 
					        "id_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
 | 
				
			||||||
 | 
					        "userinfo": {
 | 
				
			||||||
 | 
					            "sub": "user123",
 | 
				
			||||||
 | 
					            "email": "user@example.com",
 | 
				
			||||||
 | 
					            "preferred_username": "testuser",
 | 
				
			||||||
 | 
					            "groups": ["mealie_user"],
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        "token_type": "Bearer",
 | 
				
			||||||
 | 
					        "expires_in": 3600,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    key = "oauth_token_user123"
 | 
				
			||||||
 | 
					    await cache.set(key, token_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    retrieved = await cache.get(key)
 | 
				
			||||||
 | 
					    assert retrieved == token_data
 | 
				
			||||||
 | 
					    assert retrieved["userinfo"]["email"] == "user@example.com"
 | 
				
			||||||
 | 
					    assert "mealie_user" in retrieved["userinfo"]["groups"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_cache_overwrites_existing_key(cache: AuthCache):
 | 
				
			||||||
 | 
					    key = "test_key"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, "initial_value")
 | 
				
			||||||
 | 
					    assert await cache.get(key) == "initial_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(key, "new_value")
 | 
				
			||||||
 | 
					    assert await cache.get(key) == "new_value"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_concurrent_access(cache: AuthCache):
 | 
				
			||||||
 | 
					    async def set_values(start_idx, count):
 | 
				
			||||||
 | 
					        for i in range(start_idx, start_idx + count):
 | 
				
			||||||
 | 
					            await cache.set(f"key_{i}", f"value_{i}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def get_values(start_idx, count):
 | 
				
			||||||
 | 
					        results = []
 | 
				
			||||||
 | 
					        for i in range(start_idx, start_idx + count):
 | 
				
			||||||
 | 
					            value = await cache.get(f"key_{i}")
 | 
				
			||||||
 | 
					            results.append(value)
 | 
				
			||||||
 | 
					        return results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await asyncio.gather(set_values(0, 5), set_values(5, 5), set_values(10, 5))
 | 
				
			||||||
 | 
					    results = await asyncio.gather(get_values(0, 5), get_values(5, 5), get_values(10, 5))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    all_results = [item for sublist in results for item in sublist]
 | 
				
			||||||
 | 
					    actual_values = [v for v in all_results if v is not None]
 | 
				
			||||||
 | 
					    assert len(actual_values) > 0
 | 
				
			||||||
@@ -0,0 +1,153 @@
 | 
				
			|||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from authlib.integrations.starlette_client import OAuth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from mealie.routes.auth.auth_cache import AuthCache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_auth_cache_initialization_with_oauth():
 | 
				
			||||||
 | 
					    oauth = OAuth(cache=AuthCache())
 | 
				
			||||||
 | 
					    oauth.register(
 | 
				
			||||||
 | 
					        "test_oidc",
 | 
				
			||||||
 | 
					        client_id="test_client_id",
 | 
				
			||||||
 | 
					        client_secret="test_client_secret",
 | 
				
			||||||
 | 
					        server_metadata_url="https://example.com/.well-known/openid_configuration",
 | 
				
			||||||
 | 
					        client_kwargs={"scope": "openid email profile"},
 | 
				
			||||||
 | 
					        code_challenge_method="S256",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert oauth is not None
 | 
				
			||||||
 | 
					    assert isinstance(oauth.cache, AuthCache)
 | 
				
			||||||
 | 
					    assert "test_oidc" in oauth._clients
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_oauth_cache_operations():
 | 
				
			||||||
 | 
					    cache = AuthCache(threshold=500, default_timeout=300)
 | 
				
			||||||
 | 
					    cache_key = "oauth_state_12345"
 | 
				
			||||||
 | 
					    oauth_data = {
 | 
				
			||||||
 | 
					        "state": "12345",
 | 
				
			||||||
 | 
					        "code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
 | 
				
			||||||
 | 
					        "redirect_uri": "http://localhost:3000/login",
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    result = await cache.set(cache_key, oauth_data, timeout=600)  # 10 minutes
 | 
				
			||||||
 | 
					    assert result is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    retrieved_data = await cache.get(cache_key)
 | 
				
			||||||
 | 
					    assert retrieved_data == oauth_data
 | 
				
			||||||
 | 
					    assert retrieved_data["state"] == "12345"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    deleted = await cache.delete(cache_key)
 | 
				
			||||||
 | 
					    assert deleted is True
 | 
				
			||||||
 | 
					    assert await cache.get(cache_key) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_oauth_cache_handles_token_expiration():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					    token_key = "access_token_user123"
 | 
				
			||||||
 | 
					    token_data = {
 | 
				
			||||||
 | 
					        "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
 | 
				
			||||||
 | 
					        "token_type": "Bearer",
 | 
				
			||||||
 | 
					        "expires_in": 3600,
 | 
				
			||||||
 | 
					        "scope": "openid email profile",
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(token_key, token_data, timeout=0.1)
 | 
				
			||||||
 | 
					    assert await cache.has(token_key) is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await asyncio.sleep(0.15)
 | 
				
			||||||
 | 
					    assert await cache.has(token_key) is False
 | 
				
			||||||
 | 
					    assert await cache.get(token_key) is None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_oauth_cache_concurrent_requests():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def simulate_oauth_flow(user_id: str):
 | 
				
			||||||
 | 
					        """Simulate a complete OAuth flow for a user."""
 | 
				
			||||||
 | 
					        state_key = f"oauth_state_{user_id}"
 | 
				
			||||||
 | 
					        token_key = f"access_token_{user_id}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        state_data = {"state": user_id, "code_verifier": f"verifier_{user_id}"}
 | 
				
			||||||
 | 
					        await cache.set(state_key, state_data, timeout=600)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        token_data = {"access_token": f"token_{user_id}", "user_id": user_id, "expires_in": 3600}
 | 
				
			||||||
 | 
					        await cache.set(token_key, token_data, timeout=3600)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        state = await cache.get(state_key)
 | 
				
			||||||
 | 
					        token = await cache.get(token_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return state, token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    results = await asyncio.gather(
 | 
				
			||||||
 | 
					        simulate_oauth_flow("user1"), simulate_oauth_flow("user2"), simulate_oauth_flow("user3")
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for i, (state, token) in enumerate(results, 1):
 | 
				
			||||||
 | 
					        assert state["state"] == f"user{i}"
 | 
				
			||||||
 | 
					        assert token["access_token"] == f"token_user{i}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_auth_cache_disabled_when_oidc_not_ready():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					    assert cache is not None
 | 
				
			||||||
 | 
					    assert isinstance(cache, AuthCache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_auth_cache_memory_efficiency():
 | 
				
			||||||
 | 
					    cache = AuthCache(threshold=10, default_timeout=300)
 | 
				
			||||||
 | 
					    for i in range(50):
 | 
				
			||||||
 | 
					        await cache.set(f"token_{i}", f"data_{i}", timeout=0)  # Never expire
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert len(cache._cache) <= 15  # Should be close to threshold, accounting for pruning logic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    remaining_items = 0
 | 
				
			||||||
 | 
					    for i in range(50):
 | 
				
			||||||
 | 
					        if await cache.has(f"token_{i}"):
 | 
				
			||||||
 | 
					            remaining_items += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert 0 < remaining_items < 50
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.asyncio
 | 
				
			||||||
 | 
					async def test_auth_cache_with_real_oauth_data_structure():
 | 
				
			||||||
 | 
					    cache = AuthCache()
 | 
				
			||||||
 | 
					    oauth_token = {
 | 
				
			||||||
 | 
					        "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
 | 
				
			||||||
 | 
					        "id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
 | 
				
			||||||
 | 
					        "token_type": "Bearer",
 | 
				
			||||||
 | 
					        "expires_in": 3600,
 | 
				
			||||||
 | 
					        "scope": "openid email profile groups",
 | 
				
			||||||
 | 
					        "userinfo": {
 | 
				
			||||||
 | 
					            "sub": "auth0|507f1f77bcf86cd799439011",
 | 
				
			||||||
 | 
					            "email": "john.doe@example.com",
 | 
				
			||||||
 | 
					            "email_verified": True,
 | 
				
			||||||
 | 
					            "name": "John Doe",
 | 
				
			||||||
 | 
					            "preferred_username": "johndoe",
 | 
				
			||||||
 | 
					            "groups": ["mealie_user", "staff"],
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    user_session_key = "oauth_session_auth0|507f1f77bcf86cd799439011"
 | 
				
			||||||
 | 
					    await cache.set(user_session_key, oauth_token, timeout=3600)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    retrieved = await cache.get(user_session_key)
 | 
				
			||||||
 | 
					    assert retrieved["access_token"] == oauth_token["access_token"]
 | 
				
			||||||
 | 
					    assert retrieved["userinfo"]["email"] == "john.doe@example.com"
 | 
				
			||||||
 | 
					    assert "mealie_user" in retrieved["userinfo"]["groups"]
 | 
				
			||||||
 | 
					    assert retrieved["userinfo"]["email_verified"] is True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    updated_token = oauth_token.copy()
 | 
				
			||||||
 | 
					    updated_token["access_token"] = "new_access_token_eyJhbGciOiJSUzI1NiIs..."
 | 
				
			||||||
 | 
					    updated_token["userinfo"]["last_login"] = "2024-01-01T12:00:00Z"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await cache.set(user_session_key, updated_token, timeout=3600)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    updated_retrieved = await cache.get(user_session_key)
 | 
				
			||||||
 | 
					    assert updated_retrieved["access_token"] != oauth_token["access_token"]
 | 
				
			||||||
 | 
					    assert updated_retrieved["userinfo"]["last_login"] == "2024-01-01T12:00:00Z"
 | 
				
			||||||
		Reference in New Issue
	
	Block a user