simplify contextvar handling

This commit is contained in:
Michael Genson
2026-05-14 19:33:57 +00:00
parent 7de7fc3177
commit 000fec4681
3 changed files with 28 additions and 18 deletions

View File

@@ -26,7 +26,6 @@ from mealie.schema.response.pagination import (
) )
from mealie.schema.response.query_search import SearchFilter from mealie.schema.response.query_search import SearchFilter
from mealie.services.query_filter.builder import QueryFilterBuilder from mealie.services.query_filter.builder import QueryFilterBuilder
from mealie.services.query_filter.context import allow_filter_restricted
from ._utils import NOT_SET, NotSet from ._utils import NOT_SET, NotSet
@@ -461,7 +460,7 @@ class RepositoryGeneric[Schema: MealieModel, Model: SqlAlchemyBase]:
order_dir = request_query.order_direction order_dir = request_query.order_direction
_, order_attr, query = QueryFilterBuilder.get_model_and_model_attr_from_attr_string( _, order_attr, query = QueryFilterBuilder.get_model_and_model_attr_from_attr_string(
order_by, self.model, query=query, allow_restricted=allow_filter_restricted.get() order_by, self.model, query=query
) )
query = self.add_order_attr_to_query( query = self.add_order_attr_to_query(

View File

@@ -210,7 +210,7 @@ class QueryFilterBuilder:
@classmethod @classmethod
def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase]( def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase](
cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None, allow_restricted: bool = True cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None
) -> tuple[type[SqlAlchemyBase], InstrumentedAttribute, sa.Select | None]: ) -> tuple[type[SqlAlchemyBase], InstrumentedAttribute, sa.Select | None]:
""" """
Take an attribute string and traverse a database model and its relationships to get the desired Take an attribute string and traverse a database model and its relationships to get the desired
@@ -232,6 +232,7 @@ class QueryFilterBuilder:
raise ValueError("invalid query string: attribute name cannot be empty") raise ValueError("invalid query string: attribute name cannot be empty")
current_model: type[SqlAlchemyBase] = model current_model: type[SqlAlchemyBase] = model
allow_restricted = allow_filter_restricted.get()
for i, attribute_link in enumerate(attribute_chain): for i, attribute_link in enumerate(attribute_chain):
try: try:
model_attr = cls._get_model_attr(current_model, attribute_link) model_attr = cls._get_model_attr(current_model, attribute_link)
@@ -357,7 +358,7 @@ class QueryFilterBuilder:
continue continue
nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string( nested_model, model_attr, query = self.get_model_and_model_attr_from_attr_string(
component.attribute_name, model, query=query, allow_restricted=allow_filter_restricted.get() component.attribute_name, model, query=query
) )
attr_model_map[i] = nested_model attr_model_map[i] = nested_model

View File

@@ -112,22 +112,28 @@ def test_non_private_field_does_not_raise():
def test_restricted_traversal_blocked_when_disallowed(): def test_restricted_traversal_blocked_when_disallowed():
"""Traversing into User (restricted) via RecipeModel.user should raise when allow_restricted=False.""" """Traversing into User (restricted) via RecipeModel.user should raise when the ContextVar is False."""
with pytest.raises(ValueError, match="restricted model"): allow_filter_restricted.set(False)
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel, allow_restricted=False) try:
with pytest.raises(ValueError, match="restricted model"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
finally:
allow_filter_restricted.set(True)
def test_association_proxy_through_restricted_model_allowed(): def test_association_proxy_through_restricted_model_allowed():
"""Association proxies (e.g. household_id) traverse through User but are intentional """Association proxies (e.g. household_id) traverse through User but are intentional
exposures on the source model and must NOT be blocked even when allow_restricted=False.""" exposures on the source model and must NOT be blocked even when the ContextVar is False."""
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string( allow_filter_restricted.set(False)
"household_id", RecipeModel, allow_restricted=False try:
) model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("household_id", RecipeModel)
assert model is User assert model is User
finally:
allow_filter_restricted.set(True)
def test_restricted_traversal_allowed_by_default(): def test_restricted_traversal_allowed_by_default():
"""Traversing into User via RecipeModel.user should succeed when allow_restricted=True (default).""" """Traversing into User via RecipeModel.user should succeed when the ContextVar is True (default)."""
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel) model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
assert model is User assert model is User
assert attr is User.email assert attr is User.email
@@ -170,12 +176,16 @@ def test_filter_query_respects_context_var_true():
def test_order_by_restricted_traversal_blocked(): def test_order_by_restricted_traversal_blocked():
"""get_model_and_model_attr_from_attr_string with allow_restricted=False blocks orderBy into User.""" """orderBy into a restricted model is blocked when the ContextVar is False."""
with pytest.raises(ValueError, match="restricted model"): allow_filter_restricted.set(False)
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel, allow_restricted=False) try:
with pytest.raises(ValueError, match="restricted model"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
finally:
allow_filter_restricted.set(True)
def test_order_by_private_field_blocked(): def test_order_by_private_field_blocked():
"""Ordering by a PrivateColumn field should always raise, regardless of allow_restricted.""" """Ordering by a PrivateColumn field should always raise regardless of the ContextVar."""
with pytest.raises(ValueError, match="private field"): with pytest.raises(ValueError, match="private field"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User, allow_restricted=True) QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)