mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-05-15 22:37:32 -04:00
Compare commits
14 Commits
mealie-nex
...
fix/preven
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b436ff0dbf | ||
|
|
891508a199 | ||
|
|
000fec4681 | ||
|
|
7de7fc3177 | ||
|
|
3e172dccef | ||
|
|
3e2a60ad14 | ||
|
|
140bd75412 | ||
|
|
b2497295c3 | ||
|
|
37fd0c8510 | ||
|
|
f4dced3623 | ||
|
|
85695c2529 | ||
|
|
7eb8836c14 | ||
|
|
b0acd415af | ||
|
|
81ec849cc6 |
@@ -1,5 +1,6 @@
|
|||||||
import string
|
import string
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Annotated, ClassVar, get_origin
|
||||||
|
|
||||||
from sqlalchemy import Integer
|
from sqlalchemy import Integer
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column, synonym
|
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column, synonym
|
||||||
@@ -14,7 +15,28 @@ NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
|
|||||||
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
|
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
|
||||||
|
|
||||||
|
|
||||||
|
class PrivateColumn[T]:
|
||||||
|
"""
|
||||||
|
Drop-in replacement for `Mapped[]` that marks a column as private.
|
||||||
|
Private columns cannot be used in query filter expressions.
|
||||||
|
|
||||||
|
Only valid on scalar column fields. Using it on a relationship type (e.g. `list[Model]`)
|
||||||
|
will raise a `TypeError` at class definition time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __class_getitem__(cls, item: type) -> type:
|
||||||
|
if get_origin(item) is list or item is list:
|
||||||
|
raise TypeError(
|
||||||
|
f"PrivateColumn cannot be used on relationship fields (got {item!r}). "
|
||||||
|
"Annotate the related model's scalar column directly instead."
|
||||||
|
)
|
||||||
|
return Mapped[Annotated[item, mapped_column(info={"private": True})]]
|
||||||
|
|
||||||
|
|
||||||
class SqlAlchemyBase(DeclarativeBase):
|
class SqlAlchemyBase(DeclarativeBase):
|
||||||
|
__filter_restricted__: ClassVar[bool] = False
|
||||||
|
"""When True, the query filter API will block traversal into this model unless explicitly allowed."""
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
created_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, index=True)
|
created_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, index=True)
|
||||||
update_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, onupdate=get_utc_now)
|
update_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, onupdate=get_utc_now)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from sqlalchemy import ForeignKey, Integer, String, orm
|
from sqlalchemy import ForeignKey, Integer, String, orm
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||||
from .._model_utils import guid
|
from .._model_utils import guid
|
||||||
from .._model_utils.auto_init import auto_init
|
from .._model_utils.auto_init import auto_init
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class GroupInviteToken(SqlAlchemyBase, BaseMixins):
|
class GroupInviteToken(SqlAlchemyBase, BaseMixins):
|
||||||
__tablename__ = "invite_tokens"
|
__tablename__ = "invite_tokens"
|
||||||
token: Mapped[str] = mapped_column(String, index=True, nullable=False, unique=True)
|
token: PrivateColumn[str] = mapped_column(String, index=True, nullable=False, unique=True)
|
||||||
uses_left: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
uses_left: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||||
|
|
||||||
group_id: Mapped[guid.GUID | None] = mapped_column(guid.GUID, ForeignKey("groups.id"), index=True)
|
group_id: Mapped[guid.GUID | None] = mapped_column(guid.GUID, ForeignKey("groups.id"), index=True)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
|||||||
from sqlalchemy import ForeignKey, String, orm
|
from sqlalchemy import ForeignKey, String, orm
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||||
from .._model_utils.guid import GUID
|
from .._model_utils.guid import GUID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -15,7 +15,7 @@ class PasswordResetModel(SqlAlchemyBase, BaseMixins):
|
|||||||
|
|
||||||
user_id: Mapped[GUID] = mapped_column(GUID, ForeignKey("users.id"), nullable=False, index=True)
|
user_id: Mapped[GUID] = mapped_column(GUID, ForeignKey("users.id"), nullable=False, index=True)
|
||||||
user: Mapped["User"] = orm.relationship("User", back_populates="password_reset_tokens", uselist=False)
|
user: Mapped["User"] = orm.relationship("User", back_populates="password_reset_tokens", uselist=False)
|
||||||
token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
token: PrivateColumn[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||||
|
|
||||||
def __init__(self, user_id, token, **_):
|
def __init__(self, user_id, token, **_):
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from mealie.db.models._model_utils.auto_init import auto_init
|
|||||||
from mealie.db.models._model_utils.datetime import NaiveDateTime
|
from mealie.db.models._model_utils.datetime import NaiveDateTime
|
||||||
from mealie.db.models._model_utils.guid import GUID
|
from mealie.db.models._model_utils.guid import GUID
|
||||||
|
|
||||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||||
from .user_to_recipe import UserToRecipe
|
from .user_to_recipe import UserToRecipe
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
|||||||
class LongLiveToken(SqlAlchemyBase, BaseMixins):
|
class LongLiveToken(SqlAlchemyBase, BaseMixins):
|
||||||
__tablename__ = "long_live_tokens"
|
__tablename__ = "long_live_tokens"
|
||||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||||
token: Mapped[str] = mapped_column(String, nullable=False, index=True)
|
token: PrivateColumn[str] = mapped_column(String, nullable=False, index=True)
|
||||||
|
|
||||||
user_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("users.id"), index=True)
|
user_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("users.id"), index=True)
|
||||||
user: Mapped[Optional["User"]] = orm.relationship("User")
|
user: Mapped[Optional["User"]] = orm.relationship("User")
|
||||||
@@ -50,11 +50,13 @@ class AuthMethod(enum.Enum):
|
|||||||
|
|
||||||
class User(SqlAlchemyBase, BaseMixins):
|
class User(SqlAlchemyBase, BaseMixins):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
__filter_restricted__ = True
|
||||||
|
|
||||||
id: Mapped[GUID] = mapped_column(GUID, primary_key=True, default=GUID.generate)
|
id: Mapped[GUID] = mapped_column(GUID, primary_key=True, default=GUID.generate)
|
||||||
full_name: Mapped[str | None] = mapped_column(String, index=True)
|
full_name: Mapped[str | None] = mapped_column(String, index=True)
|
||||||
username: Mapped[str | None] = mapped_column(String, index=True, unique=True)
|
username: Mapped[str | None] = mapped_column(String, index=True, unique=True)
|
||||||
email: Mapped[str | None] = mapped_column(String, unique=True, index=True)
|
email: Mapped[str | None] = mapped_column(String, unique=True, index=True)
|
||||||
password: Mapped[str | None] = mapped_column(String)
|
password: PrivateColumn[str | None] = mapped_column(String)
|
||||||
auth_method: Mapped[Enum[AuthMethod]] = mapped_column(Enum(AuthMethod), default=AuthMethod.MEALIE)
|
auth_method: Mapped[Enum[AuthMethod]] = mapped_column(Enum(AuthMethod), default=AuthMethod.MEALIE)
|
||||||
admin: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
admin: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
||||||
advanced: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
advanced: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
||||||
|
|||||||
@@ -27,6 +27,12 @@ from mealie.schema.household.household import HouseholdInDB
|
|||||||
from mealie.schema.user.user import GroupInDB, PrivateUser
|
from mealie.schema.user.user import GroupInDB, PrivateUser
|
||||||
from mealie.services.event_bus_service.event_bus_service import EventBusService
|
from mealie.services.event_bus_service.event_bus_service import EventBusService
|
||||||
from mealie.services.event_bus_service.event_types import EventDocumentDataBase, EventTypes
|
from mealie.services.event_bus_service.event_types import EventDocumentDataBase, EventTypes
|
||||||
|
from mealie.services.query_filter.context import allow_filter_restricted
|
||||||
|
|
||||||
|
|
||||||
|
def _set_no_restricted_filter() -> None:
|
||||||
|
"""FastAPI dependency that disables restricted model traversal for the current request."""
|
||||||
|
allow_filter_restricted.set(False)
|
||||||
|
|
||||||
|
|
||||||
class _BaseController(ABC): # noqa: B024
|
class _BaseController(ABC): # noqa: B024
|
||||||
@@ -94,6 +100,7 @@ class BasePublicGroupExploreController(BasePublicController):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
group: GroupInDB = Depends(get_public_group)
|
group: GroupInDB = Depends(get_public_group)
|
||||||
|
_no_restricted_filter: None = Depends(_set_no_restricted_filter)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group_id(self) -> UUID4 | None | NotSet:
|
def group_id(self) -> UUID4 | None | NotSet:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from mealie.db.models._model_utils.datetime import NaiveDateTime
|
|||||||
from mealie.db.models._model_utils.guid import GUID
|
from mealie.db.models._model_utils.guid import GUID
|
||||||
from mealie.schema._mealie.mealie_model import MealieModel
|
from mealie.schema._mealie.mealie_model import MealieModel
|
||||||
|
|
||||||
|
from .context import allow_filter_restricted
|
||||||
from .keywords import PlaceholderKeyword, RelationalKeyword
|
from .keywords import PlaceholderKeyword, RelationalKeyword
|
||||||
from .operators import LogicalOperator, RelationalOperator
|
from .operators import LogicalOperator, RelationalOperator
|
||||||
|
|
||||||
@@ -199,10 +200,18 @@ class QueryFilterBuilder:
|
|||||||
if i == len(group) - 1:
|
if i == len(group) - 1:
|
||||||
return consolidated_group_builder.self_group()
|
return consolidated_group_builder.self_group()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_model_attr(cls, model: type[SqlAlchemyBase], attr_name: str) -> InstrumentedAttribute:
|
||||||
|
model_attr: InstrumentedAttribute = getattr(model, attr_name)
|
||||||
|
if getattr(model_attr, "info", {}).get("private"):
|
||||||
|
raise ValueError(f"cannot filter on private field '{model.__name__}.{attr_name}'")
|
||||||
|
|
||||||
|
return model_attr
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase](
|
def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase](
|
||||||
cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None
|
cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None
|
||||||
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, sa.Select | None]:
|
) -> tuple[type[SqlAlchemyBase], InstrumentedAttribute, sa.Select | None]:
|
||||||
"""
|
"""
|
||||||
Take an attribute string and traverse a database model and its relationships to get the desired
|
Take an attribute string and traverse a database model and its relationships to get the desired
|
||||||
model and model attribute. Optionally provide a query to apply the necessary table joins.
|
model and model attribute. Optionally provide a query to apply the necessary table joins.
|
||||||
@@ -222,17 +231,18 @@ class QueryFilterBuilder:
|
|||||||
if not attribute_chain:
|
if not attribute_chain:
|
||||||
raise ValueError("invalid query string: attribute name cannot be empty")
|
raise ValueError("invalid query string: attribute name cannot be empty")
|
||||||
|
|
||||||
current_model: SqlAlchemyBase = model # type: ignore
|
current_model: type[SqlAlchemyBase] = model
|
||||||
|
allow_restricted = allow_filter_restricted.get()
|
||||||
for i, attribute_link in enumerate(attribute_chain):
|
for i, attribute_link in enumerate(attribute_chain):
|
||||||
try:
|
try:
|
||||||
model_attr = getattr(current_model, attribute_link)
|
model_attr = cls._get_model_attr(current_model, attribute_link)
|
||||||
|
|
||||||
# proxied attributes can't be joined to the query directly, so we need to inspect the proxy
|
# proxied attributes can't be joined to the query directly, so we need to inspect the proxy
|
||||||
# and get the actual model and its attribute
|
# and get the actual model and its attribute
|
||||||
if isinstance(model_attr, AssociationProxyInstance):
|
if isinstance(model_attr, AssociationProxyInstance):
|
||||||
proxied_attribute_link = model_attr.target_collection
|
proxied_attribute_link = model_attr.target_collection
|
||||||
next_attribute_link = model_attr.value_attr
|
next_attribute_link = model_attr.value_attr
|
||||||
model_attr = getattr(current_model, proxied_attribute_link)
|
model_attr = cls._get_model_attr(current_model, proxied_attribute_link)
|
||||||
|
|
||||||
if query is not None:
|
if query is not None:
|
||||||
query = query.join(model_attr, isouter=True)
|
query = query.join(model_attr, isouter=True)
|
||||||
@@ -240,7 +250,10 @@ class QueryFilterBuilder:
|
|||||||
mapper = sa.inspect(current_model)
|
mapper = sa.inspect(current_model)
|
||||||
relationship = mapper.relationships[proxied_attribute_link]
|
relationship = mapper.relationships[proxied_attribute_link]
|
||||||
current_model = relationship.mapper.class_
|
current_model = relationship.mapper.class_
|
||||||
model_attr = getattr(current_model, next_attribute_link)
|
|
||||||
|
# Association proxies are intentional field exposures defined on the source model,
|
||||||
|
# so we do not apply the __filter_restricted__ check here.
|
||||||
|
model_attr = cls._get_model_attr(current_model, next_attribute_link)
|
||||||
|
|
||||||
# at the end of the chain there are no more relationships to inspect
|
# at the end of the chain there are no more relationships to inspect
|
||||||
if i == len(attribute_chain) - 1:
|
if i == len(attribute_chain) - 1:
|
||||||
@@ -252,6 +265,8 @@ class QueryFilterBuilder:
|
|||||||
mapper = sa.inspect(current_model)
|
mapper = sa.inspect(current_model)
|
||||||
relationship = mapper.relationships[attribute_link]
|
relationship = mapper.relationships[attribute_link]
|
||||||
current_model = relationship.mapper.class_
|
current_model = relationship.mapper.class_
|
||||||
|
if not allow_restricted and current_model.__filter_restricted__:
|
||||||
|
raise ValueError(f"cannot traverse into restricted model '{current_model.__name__}'")
|
||||||
|
|
||||||
except (AttributeError, KeyError) as e:
|
except (AttributeError, KeyError) as e:
|
||||||
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
|
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
|
||||||
@@ -299,7 +314,9 @@ class QueryFilterBuilder:
|
|||||||
if len(value) == 1:
|
if len(value) == 1:
|
||||||
element = model_attr.in_(value)
|
element = model_attr.in_(value)
|
||||||
else:
|
else:
|
||||||
primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0])
|
primary_model_attr: InstrumentedAttribute = cls._get_model_attr(
|
||||||
|
model, component.attribute_name.split(".")[0]
|
||||||
|
)
|
||||||
element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value))
|
element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value))
|
||||||
elif component.relationship is RelationalKeyword.LIKE:
|
elif component.relationship is RelationalKeyword.LIKE:
|
||||||
element = model_attr.ilike(value)
|
element = model_attr.ilike(value)
|
||||||
@@ -368,7 +385,7 @@ class QueryFilterBuilder:
|
|||||||
else:
|
else:
|
||||||
component = cast(QueryFilterBuilderComponent, component)
|
component = cast(QueryFilterBuilderComponent, component)
|
||||||
base_attribute_name = component.attribute_name.split(".")[-1]
|
base_attribute_name = component.attribute_name.split(".")[-1]
|
||||||
model_attr = getattr(attr_model_map[i], base_attribute_name)
|
model_attr = self._get_model_attr(attr_model_map[i], base_attribute_name)
|
||||||
|
|
||||||
if (column_alias := column_aliases.get(base_attribute_name)) is not None:
|
if (column_alias := column_aliases.get(base_attribute_name)) is not None:
|
||||||
model_attr = column_alias
|
model_attr = column_alias
|
||||||
|
|||||||
3
mealie/services/query_filter/context.py
Normal file
3
mealie/services/query_filter/context.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
allow_filter_restricted: ContextVar[bool] = ContextVar("allow_filter_restricted", default=True)
|
||||||
@@ -1,3 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from mealie.db.models._model_base import PrivateColumn
|
||||||
|
from mealie.db.models.recipe.recipe import RecipeModel
|
||||||
|
from mealie.db.models.users.users import LongLiveToken, User
|
||||||
from mealie.services.query_filter.builder import (
|
from mealie.services.query_filter.builder import (
|
||||||
LogicalOperator,
|
LogicalOperator,
|
||||||
QueryFilterBuilder,
|
QueryFilterBuilder,
|
||||||
@@ -6,6 +12,7 @@ from mealie.services.query_filter.builder import (
|
|||||||
RelationalKeyword,
|
RelationalKeyword,
|
||||||
RelationalOperator,
|
RelationalOperator,
|
||||||
)
|
)
|
||||||
|
from mealie.services.query_filter.context import allow_filter_restricted
|
||||||
|
|
||||||
|
|
||||||
def test_query_filter_builder_json():
|
def test_query_filter_builder_json():
|
||||||
@@ -74,3 +81,118 @@ def test_query_filter_builder_json_uses_raw_value():
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PrivateColumn tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_private_column_rejects_list_type():
|
||||||
|
"""PrivateColumn[list[X]] must raise TypeError at definition time to prevent misuse on relationships."""
|
||||||
|
with pytest.raises(TypeError, match="relationship"):
|
||||||
|
PrivateColumn[list[User]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_private_field_user_password_raises():
|
||||||
|
"""Filtering on User.password (PrivateColumn) should raise ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="private field"):
|
||||||
|
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)
|
||||||
|
|
||||||
|
|
||||||
|
def test_private_field_long_live_token_raises():
|
||||||
|
"""Filtering on LongLiveToken.token (PrivateColumn) should raise ValueError."""
|
||||||
|
with pytest.raises(ValueError, match="private field"):
|
||||||
|
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("token", LongLiveToken)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_private_field_does_not_raise():
|
||||||
|
"""Filtering on a normal field should not raise."""
|
||||||
|
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("full_name", User)
|
||||||
|
assert model is User
|
||||||
|
assert attr is User.full_name
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# __filter_restricted__ traversal tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_restricted_traversal_blocked_when_disallowed():
|
||||||
|
"""Traversing into User (restricted) via RecipeModel.user should raise when the ContextVar is False."""
|
||||||
|
allow_filter_restricted.set(False)
|
||||||
|
try:
|
||||||
|
with pytest.raises(ValueError, match="restricted model"):
|
||||||
|
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||||
|
finally:
|
||||||
|
allow_filter_restricted.set(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_association_proxy_through_restricted_model_allowed():
|
||||||
|
"""Association proxies (e.g. household_id) traverse through User but are intentional
|
||||||
|
exposures on the source model and must NOT be blocked even when the ContextVar is False."""
|
||||||
|
allow_filter_restricted.set(False)
|
||||||
|
try:
|
||||||
|
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("household_id", RecipeModel)
|
||||||
|
assert model is User
|
||||||
|
finally:
|
||||||
|
allow_filter_restricted.set(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_restricted_traversal_allowed_by_default():
|
||||||
|
"""Traversing into User via RecipeModel.user should succeed when the ContextVar is True (default)."""
|
||||||
|
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||||
|
assert model is User
|
||||||
|
assert attr is User.email
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ContextVar tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_allow_filter_restricted_default_is_true():
|
||||||
|
"""The ContextVar default must be True so authenticated requests are unrestricted."""
|
||||||
|
assert allow_filter_restricted.get() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_query_respects_context_var_false(monkeypatch):
|
||||||
|
"""filter_query should block restricted traversal when the ContextVar is False."""
|
||||||
|
allow_filter_restricted.set(False)
|
||||||
|
try:
|
||||||
|
query = sa.select(RecipeModel)
|
||||||
|
builder = QueryFilterBuilder("user.email = 'test@example.com'")
|
||||||
|
with pytest.raises(ValueError, match="restricted model"):
|
||||||
|
builder.filter_query(query, RecipeModel)
|
||||||
|
finally:
|
||||||
|
allow_filter_restricted.set(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_query_respects_context_var_true():
|
||||||
|
"""filter_query should allow restricted traversal when the ContextVar is True (default)."""
|
||||||
|
allow_filter_restricted.set(True)
|
||||||
|
query = sa.select(RecipeModel)
|
||||||
|
builder = QueryFilterBuilder("user.email = 'test@example.com'")
|
||||||
|
# Should not raise
|
||||||
|
builder.filter_query(query, RecipeModel)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# orderBy restricted traversal tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_order_by_restricted_traversal_blocked():
|
||||||
|
"""orderBy into a restricted model is blocked when the ContextVar is False."""
|
||||||
|
allow_filter_restricted.set(False)
|
||||||
|
try:
|
||||||
|
with pytest.raises(ValueError, match="restricted model"):
|
||||||
|
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||||
|
finally:
|
||||||
|
allow_filter_restricted.set(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_order_by_private_field_blocked():
|
||||||
|
"""Ordering by a PrivateColumn field should always raise regardless of the ContextVar."""
|
||||||
|
with pytest.raises(ValueError, match="private field"):
|
||||||
|
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)
|
||||||
|
|||||||
Reference in New Issue
Block a user