mirror of
https://github.com/mealie-recipes/mealie.git
synced 2026-05-15 14:27:31 -04:00
Compare commits
14 Commits
mealie-nex
...
fix/preven
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b436ff0dbf | ||
|
|
891508a199 | ||
|
|
000fec4681 | ||
|
|
7de7fc3177 | ||
|
|
3e172dccef | ||
|
|
3e2a60ad14 | ||
|
|
140bd75412 | ||
|
|
b2497295c3 | ||
|
|
37fd0c8510 | ||
|
|
f4dced3623 | ||
|
|
85695c2529 | ||
|
|
7eb8836c14 | ||
|
|
b0acd415af | ||
|
|
81ec849cc6 |
@@ -1,5 +1,6 @@
|
||||
import string
|
||||
from datetime import datetime
|
||||
from typing import Annotated, ClassVar, get_origin
|
||||
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column, synonym
|
||||
@@ -14,7 +15,28 @@ NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
|
||||
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
|
||||
|
||||
|
||||
class PrivateColumn[T]:
|
||||
"""
|
||||
Drop-in replacement for `Mapped[]` that marks a column as private.
|
||||
Private columns cannot be used in query filter expressions.
|
||||
|
||||
Only valid on scalar column fields. Using it on a relationship type (e.g. `list[Model]`)
|
||||
will raise a `TypeError` at class definition time.
|
||||
"""
|
||||
|
||||
def __class_getitem__(cls, item: type) -> type:
|
||||
if get_origin(item) is list or item is list:
|
||||
raise TypeError(
|
||||
f"PrivateColumn cannot be used on relationship fields (got {item!r}). "
|
||||
"Annotate the related model's scalar column directly instead."
|
||||
)
|
||||
return Mapped[Annotated[item, mapped_column(info={"private": True})]]
|
||||
|
||||
|
||||
class SqlAlchemyBase(DeclarativeBase):
|
||||
__filter_restricted__: ClassVar[bool] = False
|
||||
"""When True, the query filter API will block traversal into this model unless explicitly allowed."""
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
created_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, index=True)
|
||||
update_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, onupdate=get_utc_now)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
from sqlalchemy import ForeignKey, Integer, String, orm
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
||||
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||
from .._model_utils import guid
|
||||
from .._model_utils.auto_init import auto_init
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class GroupInviteToken(SqlAlchemyBase, BaseMixins):
|
||||
__tablename__ = "invite_tokens"
|
||||
token: Mapped[str] = mapped_column(String, index=True, nullable=False, unique=True)
|
||||
token: PrivateColumn[str] = mapped_column(String, index=True, nullable=False, unique=True)
|
||||
uses_left: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
group_id: Mapped[guid.GUID | None] = mapped_column(guid.GUID, ForeignKey("groups.id"), index=True)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
|
||||
from sqlalchemy import ForeignKey, String, orm
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
||||
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||
from .._model_utils.guid import GUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -15,7 +15,7 @@ class PasswordResetModel(SqlAlchemyBase, BaseMixins):
|
||||
|
||||
user_id: Mapped[GUID] = mapped_column(GUID, ForeignKey("users.id"), nullable=False, index=True)
|
||||
user: Mapped["User"] = orm.relationship("User", back_populates="password_reset_tokens", uselist=False)
|
||||
token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
token: PrivateColumn[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
|
||||
def __init__(self, user_id, token, **_):
|
||||
self.user_id = user_id
|
||||
|
||||
@@ -13,7 +13,7 @@ from mealie.db.models._model_utils.auto_init import auto_init
|
||||
from mealie.db.models._model_utils.datetime import NaiveDateTime
|
||||
from mealie.db.models._model_utils.guid import GUID
|
||||
|
||||
from .._model_base import BaseMixins, SqlAlchemyBase
|
||||
from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
|
||||
from .user_to_recipe import UserToRecipe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
class LongLiveToken(SqlAlchemyBase, BaseMixins):
|
||||
__tablename__ = "long_live_tokens"
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
token: Mapped[str] = mapped_column(String, nullable=False, index=True)
|
||||
token: PrivateColumn[str] = mapped_column(String, nullable=False, index=True)
|
||||
|
||||
user_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("users.id"), index=True)
|
||||
user: Mapped[Optional["User"]] = orm.relationship("User")
|
||||
@@ -50,11 +50,13 @@ class AuthMethod(enum.Enum):
|
||||
|
||||
class User(SqlAlchemyBase, BaseMixins):
|
||||
__tablename__ = "users"
|
||||
__filter_restricted__ = True
|
||||
|
||||
id: Mapped[GUID] = mapped_column(GUID, primary_key=True, default=GUID.generate)
|
||||
full_name: Mapped[str | None] = mapped_column(String, index=True)
|
||||
username: Mapped[str | None] = mapped_column(String, index=True, unique=True)
|
||||
email: Mapped[str | None] = mapped_column(String, unique=True, index=True)
|
||||
password: Mapped[str | None] = mapped_column(String)
|
||||
password: PrivateColumn[str | None] = mapped_column(String)
|
||||
auth_method: Mapped[Enum[AuthMethod]] = mapped_column(Enum(AuthMethod), default=AuthMethod.MEALIE)
|
||||
admin: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
||||
advanced: Mapped[bool | None] = mapped_column(Boolean, default=False)
|
||||
|
||||
@@ -27,6 +27,12 @@ from mealie.schema.household.household import HouseholdInDB
|
||||
from mealie.schema.user.user import GroupInDB, PrivateUser
|
||||
from mealie.services.event_bus_service.event_bus_service import EventBusService
|
||||
from mealie.services.event_bus_service.event_types import EventDocumentDataBase, EventTypes
|
||||
from mealie.services.query_filter.context import allow_filter_restricted
|
||||
|
||||
|
||||
def _set_no_restricted_filter() -> None:
|
||||
"""FastAPI dependency that disables restricted model traversal for the current request."""
|
||||
allow_filter_restricted.set(False)
|
||||
|
||||
|
||||
class _BaseController(ABC): # noqa: B024
|
||||
@@ -94,6 +100,7 @@ class BasePublicGroupExploreController(BasePublicController):
|
||||
"""
|
||||
|
||||
group: GroupInDB = Depends(get_public_group)
|
||||
_no_restricted_filter: None = Depends(_set_no_restricted_filter)
|
||||
|
||||
@property
|
||||
def group_id(self) -> UUID4 | None | NotSet:
|
||||
|
||||
@@ -18,6 +18,7 @@ from mealie.db.models._model_utils.datetime import NaiveDateTime
|
||||
from mealie.db.models._model_utils.guid import GUID
|
||||
from mealie.schema._mealie.mealie_model import MealieModel
|
||||
|
||||
from .context import allow_filter_restricted
|
||||
from .keywords import PlaceholderKeyword, RelationalKeyword
|
||||
from .operators import LogicalOperator, RelationalOperator
|
||||
|
||||
@@ -199,10 +200,18 @@ class QueryFilterBuilder:
|
||||
if i == len(group) - 1:
|
||||
return consolidated_group_builder.self_group()
|
||||
|
||||
@classmethod
|
||||
def _get_model_attr(cls, model: type[SqlAlchemyBase], attr_name: str) -> InstrumentedAttribute:
|
||||
model_attr: InstrumentedAttribute = getattr(model, attr_name)
|
||||
if getattr(model_attr, "info", {}).get("private"):
|
||||
raise ValueError(f"cannot filter on private field '{model.__name__}.{attr_name}'")
|
||||
|
||||
return model_attr
|
||||
|
||||
@classmethod
|
||||
def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase](
|
||||
cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None
|
||||
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, sa.Select | None]:
|
||||
) -> tuple[type[SqlAlchemyBase], InstrumentedAttribute, sa.Select | None]:
|
||||
"""
|
||||
Take an attribute string and traverse a database model and its relationships to get the desired
|
||||
model and model attribute. Optionally provide a query to apply the necessary table joins.
|
||||
@@ -222,17 +231,18 @@ class QueryFilterBuilder:
|
||||
if not attribute_chain:
|
||||
raise ValueError("invalid query string: attribute name cannot be empty")
|
||||
|
||||
current_model: SqlAlchemyBase = model # type: ignore
|
||||
current_model: type[SqlAlchemyBase] = model
|
||||
allow_restricted = allow_filter_restricted.get()
|
||||
for i, attribute_link in enumerate(attribute_chain):
|
||||
try:
|
||||
model_attr = getattr(current_model, attribute_link)
|
||||
model_attr = cls._get_model_attr(current_model, attribute_link)
|
||||
|
||||
# proxied attributes can't be joined to the query directly, so we need to inspect the proxy
|
||||
# and get the actual model and its attribute
|
||||
if isinstance(model_attr, AssociationProxyInstance):
|
||||
proxied_attribute_link = model_attr.target_collection
|
||||
next_attribute_link = model_attr.value_attr
|
||||
model_attr = getattr(current_model, proxied_attribute_link)
|
||||
model_attr = cls._get_model_attr(current_model, proxied_attribute_link)
|
||||
|
||||
if query is not None:
|
||||
query = query.join(model_attr, isouter=True)
|
||||
@@ -240,7 +250,10 @@ class QueryFilterBuilder:
|
||||
mapper = sa.inspect(current_model)
|
||||
relationship = mapper.relationships[proxied_attribute_link]
|
||||
current_model = relationship.mapper.class_
|
||||
model_attr = getattr(current_model, next_attribute_link)
|
||||
|
||||
# Association proxies are intentional field exposures defined on the source model,
|
||||
# so we do not apply the __filter_restricted__ check here.
|
||||
model_attr = cls._get_model_attr(current_model, next_attribute_link)
|
||||
|
||||
# at the end of the chain there are no more relationships to inspect
|
||||
if i == len(attribute_chain) - 1:
|
||||
@@ -252,6 +265,8 @@ class QueryFilterBuilder:
|
||||
mapper = sa.inspect(current_model)
|
||||
relationship = mapper.relationships[attribute_link]
|
||||
current_model = relationship.mapper.class_
|
||||
if not allow_restricted and current_model.__filter_restricted__:
|
||||
raise ValueError(f"cannot traverse into restricted model '{current_model.__name__}'")
|
||||
|
||||
except (AttributeError, KeyError) as e:
|
||||
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
|
||||
@@ -299,7 +314,9 @@ class QueryFilterBuilder:
|
||||
if len(value) == 1:
|
||||
element = model_attr.in_(value)
|
||||
else:
|
||||
primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0])
|
||||
primary_model_attr: InstrumentedAttribute = cls._get_model_attr(
|
||||
model, component.attribute_name.split(".")[0]
|
||||
)
|
||||
element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value))
|
||||
elif component.relationship is RelationalKeyword.LIKE:
|
||||
element = model_attr.ilike(value)
|
||||
@@ -368,7 +385,7 @@ class QueryFilterBuilder:
|
||||
else:
|
||||
component = cast(QueryFilterBuilderComponent, component)
|
||||
base_attribute_name = component.attribute_name.split(".")[-1]
|
||||
model_attr = getattr(attr_model_map[i], base_attribute_name)
|
||||
model_attr = self._get_model_attr(attr_model_map[i], base_attribute_name)
|
||||
|
||||
if (column_alias := column_aliases.get(base_attribute_name)) is not None:
|
||||
model_attr = column_alias
|
||||
|
||||
3
mealie/services/query_filter/context.py
Normal file
3
mealie/services/query_filter/context.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
allow_filter_restricted: ContextVar[bool] = ContextVar("allow_filter_restricted", default=True)
|
||||
@@ -1,3 +1,9 @@
|
||||
import pytest
|
||||
import sqlalchemy as sa
|
||||
|
||||
from mealie.db.models._model_base import PrivateColumn
|
||||
from mealie.db.models.recipe.recipe import RecipeModel
|
||||
from mealie.db.models.users.users import LongLiveToken, User
|
||||
from mealie.services.query_filter.builder import (
|
||||
LogicalOperator,
|
||||
QueryFilterBuilder,
|
||||
@@ -6,6 +12,7 @@ from mealie.services.query_filter.builder import (
|
||||
RelationalKeyword,
|
||||
RelationalOperator,
|
||||
)
|
||||
from mealie.services.query_filter.context import allow_filter_restricted
|
||||
|
||||
|
||||
def test_query_filter_builder_json():
|
||||
@@ -74,3 +81,118 @@ def test_query_filter_builder_json_uses_raw_value():
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PrivateColumn tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_private_column_rejects_list_type():
|
||||
"""PrivateColumn[list[X]] must raise TypeError at definition time to prevent misuse on relationships."""
|
||||
with pytest.raises(TypeError, match="relationship"):
|
||||
PrivateColumn[list[User]]
|
||||
|
||||
|
||||
def test_private_field_user_password_raises():
|
||||
"""Filtering on User.password (PrivateColumn) should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="private field"):
|
||||
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)
|
||||
|
||||
|
||||
def test_private_field_long_live_token_raises():
|
||||
"""Filtering on LongLiveToken.token (PrivateColumn) should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="private field"):
|
||||
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("token", LongLiveToken)
|
||||
|
||||
|
||||
def test_non_private_field_does_not_raise():
|
||||
"""Filtering on a normal field should not raise."""
|
||||
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("full_name", User)
|
||||
assert model is User
|
||||
assert attr is User.full_name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __filter_restricted__ traversal tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_restricted_traversal_blocked_when_disallowed():
|
||||
"""Traversing into User (restricted) via RecipeModel.user should raise when the ContextVar is False."""
|
||||
allow_filter_restricted.set(False)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="restricted model"):
|
||||
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||
finally:
|
||||
allow_filter_restricted.set(True)
|
||||
|
||||
|
||||
def test_association_proxy_through_restricted_model_allowed():
|
||||
"""Association proxies (e.g. household_id) traverse through User but are intentional
|
||||
exposures on the source model and must NOT be blocked even when the ContextVar is False."""
|
||||
allow_filter_restricted.set(False)
|
||||
try:
|
||||
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("household_id", RecipeModel)
|
||||
assert model is User
|
||||
finally:
|
||||
allow_filter_restricted.set(True)
|
||||
|
||||
|
||||
def test_restricted_traversal_allowed_by_default():
|
||||
"""Traversing into User via RecipeModel.user should succeed when the ContextVar is True (default)."""
|
||||
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||
assert model is User
|
||||
assert attr is User.email
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ContextVar tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_allow_filter_restricted_default_is_true():
|
||||
"""The ContextVar default must be True so authenticated requests are unrestricted."""
|
||||
assert allow_filter_restricted.get() is True
|
||||
|
||||
|
||||
def test_filter_query_respects_context_var_false(monkeypatch):
|
||||
"""filter_query should block restricted traversal when the ContextVar is False."""
|
||||
allow_filter_restricted.set(False)
|
||||
try:
|
||||
query = sa.select(RecipeModel)
|
||||
builder = QueryFilterBuilder("user.email = 'test@example.com'")
|
||||
with pytest.raises(ValueError, match="restricted model"):
|
||||
builder.filter_query(query, RecipeModel)
|
||||
finally:
|
||||
allow_filter_restricted.set(True)
|
||||
|
||||
|
||||
def test_filter_query_respects_context_var_true():
|
||||
"""filter_query should allow restricted traversal when the ContextVar is True (default)."""
|
||||
allow_filter_restricted.set(True)
|
||||
query = sa.select(RecipeModel)
|
||||
builder = QueryFilterBuilder("user.email = 'test@example.com'")
|
||||
# Should not raise
|
||||
builder.filter_query(query, RecipeModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# orderBy restricted traversal tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_order_by_restricted_traversal_blocked():
|
||||
"""orderBy into a restricted model is blocked when the ContextVar is False."""
|
||||
allow_filter_restricted.set(False)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="restricted model"):
|
||||
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
|
||||
finally:
|
||||
allow_filter_restricted.set(True)
|
||||
|
||||
|
||||
def test_order_by_private_field_blocked():
|
||||
"""Ordering by a PrivateColumn field should always raise regardless of the ContextVar."""
|
||||
with pytest.raises(ValueError, match="private field"):
|
||||
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)
|
||||
|
||||
Reference in New Issue
Block a user