Compare commits

..

14 Commits

Author SHA1 Message Date
Michael Genson
b436ff0dbf added runtime check for this 2026-05-14 21:12:25 +00:00
Michael Genson
891508a199 remove over-zealous PrivateColumn on relationships 2026-05-14 21:09:48 +00:00
Michael Genson
000fec4681 simplify contextvar handling 2026-05-14 19:33:57 +00:00
Michael Genson
7de7fc3177 Merge branch 'mealie-next' into fix/prevent-querying-sensitive-fields 2026-05-14 14:31:02 -05:00
Michael Genson
3e172dccef allow association proxies to pass the restricted filter 2026-05-14 19:29:10 +00:00
Michael Genson
3e2a60ad14 protect order_by too 2026-05-14 19:22:31 +00:00
Michael Genson
140bd75412 add tests 2026-05-14 19:15:36 +00:00
Michael Genson
b2497295c3 re-implement PrivateColumn to make sqlalchemy happy 2026-05-14 19:13:45 +00:00
Michael Genson
37fd0c8510 context var for disabling restricted models 2026-05-14 19:06:18 +00:00
Michael Genson
f4dced3623 raise ValueError on restricted models 2026-05-14 18:54:27 +00:00
Michael Genson
85695c2529 add __filter_restricted__ to User table 2026-05-14 18:53:58 +00:00
Michael Genson
7eb8836c14 raise ValueError when querying on private columns 2026-05-14 18:45:47 +00:00
Michael Genson
b0acd415af Add private columns for sensitive data 2026-05-14 18:44:25 +00:00
Michael Genson
81ec849cc6 define PrivateColumn 2026-05-14 18:07:54 +00:00
8 changed files with 187 additions and 14 deletions

View File

@@ -1,5 +1,6 @@
import string import string
from datetime import datetime from datetime import datetime
from typing import Annotated, ClassVar, get_origin
from sqlalchemy import Integer from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column, synonym from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column, synonym
@@ -14,7 +15,28 @@ NORMALIZE_PUNCTUATION = string.punctuation.replace("'", "").replace('"', "")
_NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION)) _NORMALIZE_PUNCTUATION_TABLE = str.maketrans(NORMALIZE_PUNCTUATION, " " * len(NORMALIZE_PUNCTUATION))
class PrivateColumn[T]:
"""
Drop-in replacement for `Mapped[]` that marks a column as private.
Private columns cannot be used in query filter expressions.
Only valid on scalar column fields. Using it on a relationship type (e.g. `list[Model]`)
will raise a `TypeError` at class definition time.
"""
def __class_getitem__(cls, item: type) -> type:
if get_origin(item) is list or item is list:
raise TypeError(
f"PrivateColumn cannot be used on relationship fields (got {item!r}). "
"Annotate the related model's scalar column directly instead."
)
return Mapped[Annotated[item, mapped_column(info={"private": True})]]
class SqlAlchemyBase(DeclarativeBase): class SqlAlchemyBase(DeclarativeBase):
__filter_restricted__: ClassVar[bool] = False
"""When True, the query filter API will block traversal into this model unless explicitly allowed."""
id: Mapped[int] = mapped_column(Integer, primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
created_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, index=True) created_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, index=True)
update_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, onupdate=get_utc_now) update_at: Mapped[datetime | None] = mapped_column(NaiveDateTime, default=get_utc_now, onupdate=get_utc_now)

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, Integer, String, orm from sqlalchemy import ForeignKey, Integer, String, orm
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from .._model_base import BaseMixins, SqlAlchemyBase from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
from .._model_utils import guid from .._model_utils import guid
from .._model_utils.auto_init import auto_init from .._model_utils.auto_init import auto_init
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
class GroupInviteToken(SqlAlchemyBase, BaseMixins): class GroupInviteToken(SqlAlchemyBase, BaseMixins):
__tablename__ = "invite_tokens" __tablename__ = "invite_tokens"
token: Mapped[str] = mapped_column(String, index=True, nullable=False, unique=True) token: PrivateColumn[str] = mapped_column(String, index=True, nullable=False, unique=True)
uses_left: Mapped[int] = mapped_column(Integer, nullable=False, default=1) uses_left: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
group_id: Mapped[guid.GUID | None] = mapped_column(guid.GUID, ForeignKey("groups.id"), index=True) group_id: Mapped[guid.GUID | None] = mapped_column(guid.GUID, ForeignKey("groups.id"), index=True)

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, String, orm from sqlalchemy import ForeignKey, String, orm
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from .._model_base import BaseMixins, SqlAlchemyBase from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
from .._model_utils.guid import GUID from .._model_utils.guid import GUID
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -15,7 +15,7 @@ class PasswordResetModel(SqlAlchemyBase, BaseMixins):
user_id: Mapped[GUID] = mapped_column(GUID, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[GUID] = mapped_column(GUID, ForeignKey("users.id"), nullable=False, index=True)
user: Mapped["User"] = orm.relationship("User", back_populates="password_reset_tokens", uselist=False) user: Mapped["User"] = orm.relationship("User", back_populates="password_reset_tokens", uselist=False)
token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False) token: PrivateColumn[str] = mapped_column(String(64), unique=True, nullable=False)
def __init__(self, user_id, token, **_): def __init__(self, user_id, token, **_):
self.user_id = user_id self.user_id = user_id

View File

@@ -13,7 +13,7 @@ from mealie.db.models._model_utils.auto_init import auto_init
from mealie.db.models._model_utils.datetime import NaiveDateTime from mealie.db.models._model_utils.datetime import NaiveDateTime
from mealie.db.models._model_utils.guid import GUID from mealie.db.models._model_utils.guid import GUID
from .._model_base import BaseMixins, SqlAlchemyBase from .._model_base import BaseMixins, PrivateColumn, SqlAlchemyBase
from .user_to_recipe import UserToRecipe from .user_to_recipe import UserToRecipe
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class LongLiveToken(SqlAlchemyBase, BaseMixins): class LongLiveToken(SqlAlchemyBase, BaseMixins):
__tablename__ = "long_live_tokens" __tablename__ = "long_live_tokens"
name: Mapped[str] = mapped_column(String, nullable=False) name: Mapped[str] = mapped_column(String, nullable=False)
token: Mapped[str] = mapped_column(String, nullable=False, index=True) token: PrivateColumn[str] = mapped_column(String, nullable=False, index=True)
user_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("users.id"), index=True) user_id: Mapped[GUID | None] = mapped_column(GUID, ForeignKey("users.id"), index=True)
user: Mapped[Optional["User"]] = orm.relationship("User") user: Mapped[Optional["User"]] = orm.relationship("User")
@@ -50,11 +50,13 @@ class AuthMethod(enum.Enum):
class User(SqlAlchemyBase, BaseMixins): class User(SqlAlchemyBase, BaseMixins):
__tablename__ = "users" __tablename__ = "users"
__filter_restricted__ = True
id: Mapped[GUID] = mapped_column(GUID, primary_key=True, default=GUID.generate) id: Mapped[GUID] = mapped_column(GUID, primary_key=True, default=GUID.generate)
full_name: Mapped[str | None] = mapped_column(String, index=True) full_name: Mapped[str | None] = mapped_column(String, index=True)
username: Mapped[str | None] = mapped_column(String, index=True, unique=True) username: Mapped[str | None] = mapped_column(String, index=True, unique=True)
email: Mapped[str | None] = mapped_column(String, unique=True, index=True) email: Mapped[str | None] = mapped_column(String, unique=True, index=True)
password: Mapped[str | None] = mapped_column(String) password: PrivateColumn[str | None] = mapped_column(String)
auth_method: Mapped[Enum[AuthMethod]] = mapped_column(Enum(AuthMethod), default=AuthMethod.MEALIE) auth_method: Mapped[Enum[AuthMethod]] = mapped_column(Enum(AuthMethod), default=AuthMethod.MEALIE)
admin: Mapped[bool | None] = mapped_column(Boolean, default=False) admin: Mapped[bool | None] = mapped_column(Boolean, default=False)
advanced: Mapped[bool | None] = mapped_column(Boolean, default=False) advanced: Mapped[bool | None] = mapped_column(Boolean, default=False)

View File

@@ -27,6 +27,12 @@ from mealie.schema.household.household import HouseholdInDB
from mealie.schema.user.user import GroupInDB, PrivateUser from mealie.schema.user.user import GroupInDB, PrivateUser
from mealie.services.event_bus_service.event_bus_service import EventBusService from mealie.services.event_bus_service.event_bus_service import EventBusService
from mealie.services.event_bus_service.event_types import EventDocumentDataBase, EventTypes from mealie.services.event_bus_service.event_types import EventDocumentDataBase, EventTypes
from mealie.services.query_filter.context import allow_filter_restricted
def _set_no_restricted_filter() -> None:
"""FastAPI dependency that disables restricted model traversal for the current request."""
allow_filter_restricted.set(False)
class _BaseController(ABC): # noqa: B024 class _BaseController(ABC): # noqa: B024
@@ -94,6 +100,7 @@ class BasePublicGroupExploreController(BasePublicController):
""" """
group: GroupInDB = Depends(get_public_group) group: GroupInDB = Depends(get_public_group)
_no_restricted_filter: None = Depends(_set_no_restricted_filter)
@property @property
def group_id(self) -> UUID4 | None | NotSet: def group_id(self) -> UUID4 | None | NotSet:

View File

@@ -18,6 +18,7 @@ from mealie.db.models._model_utils.datetime import NaiveDateTime
from mealie.db.models._model_utils.guid import GUID from mealie.db.models._model_utils.guid import GUID
from mealie.schema._mealie.mealie_model import MealieModel from mealie.schema._mealie.mealie_model import MealieModel
from .context import allow_filter_restricted
from .keywords import PlaceholderKeyword, RelationalKeyword from .keywords import PlaceholderKeyword, RelationalKeyword
from .operators import LogicalOperator, RelationalOperator from .operators import LogicalOperator, RelationalOperator
@@ -199,10 +200,18 @@ class QueryFilterBuilder:
if i == len(group) - 1: if i == len(group) - 1:
return consolidated_group_builder.self_group() return consolidated_group_builder.self_group()
@classmethod
def _get_model_attr(cls, model: type[SqlAlchemyBase], attr_name: str) -> InstrumentedAttribute:
model_attr: InstrumentedAttribute = getattr(model, attr_name)
if getattr(model_attr, "info", {}).get("private"):
raise ValueError(f"cannot filter on private field '{model.__name__}.{attr_name}'")
return model_attr
@classmethod @classmethod
def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase]( def get_model_and_model_attr_from_attr_string[Model: SqlAlchemyBase](
cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None cls, attr_string: str, model: type[Model], *, query: sa.Select | None = None
) -> tuple[SqlAlchemyBase, InstrumentedAttribute, sa.Select | None]: ) -> tuple[type[SqlAlchemyBase], InstrumentedAttribute, sa.Select | None]:
""" """
Take an attribute string and traverse a database model and its relationships to get the desired Take an attribute string and traverse a database model and its relationships to get the desired
model and model attribute. Optionally provide a query to apply the necessary table joins. model and model attribute. Optionally provide a query to apply the necessary table joins.
@@ -222,17 +231,18 @@ class QueryFilterBuilder:
if not attribute_chain: if not attribute_chain:
raise ValueError("invalid query string: attribute name cannot be empty") raise ValueError("invalid query string: attribute name cannot be empty")
current_model: SqlAlchemyBase = model # type: ignore current_model: type[SqlAlchemyBase] = model
allow_restricted = allow_filter_restricted.get()
for i, attribute_link in enumerate(attribute_chain): for i, attribute_link in enumerate(attribute_chain):
try: try:
model_attr = getattr(current_model, attribute_link) model_attr = cls._get_model_attr(current_model, attribute_link)
# proxied attributes can't be joined to the query directly, so we need to inspect the proxy # proxied attributes can't be joined to the query directly, so we need to inspect the proxy
# and get the actual model and its attribute # and get the actual model and its attribute
if isinstance(model_attr, AssociationProxyInstance): if isinstance(model_attr, AssociationProxyInstance):
proxied_attribute_link = model_attr.target_collection proxied_attribute_link = model_attr.target_collection
next_attribute_link = model_attr.value_attr next_attribute_link = model_attr.value_attr
model_attr = getattr(current_model, proxied_attribute_link) model_attr = cls._get_model_attr(current_model, proxied_attribute_link)
if query is not None: if query is not None:
query = query.join(model_attr, isouter=True) query = query.join(model_attr, isouter=True)
@@ -240,7 +250,10 @@ class QueryFilterBuilder:
mapper = sa.inspect(current_model) mapper = sa.inspect(current_model)
relationship = mapper.relationships[proxied_attribute_link] relationship = mapper.relationships[proxied_attribute_link]
current_model = relationship.mapper.class_ current_model = relationship.mapper.class_
model_attr = getattr(current_model, next_attribute_link)
# Association proxies are intentional field exposures defined on the source model,
# so we do not apply the __filter_restricted__ check here.
model_attr = cls._get_model_attr(current_model, next_attribute_link)
# at the end of the chain there are no more relationships to inspect # at the end of the chain there are no more relationships to inspect
if i == len(attribute_chain) - 1: if i == len(attribute_chain) - 1:
@@ -252,6 +265,8 @@ class QueryFilterBuilder:
mapper = sa.inspect(current_model) mapper = sa.inspect(current_model)
relationship = mapper.relationships[attribute_link] relationship = mapper.relationships[attribute_link]
current_model = relationship.mapper.class_ current_model = relationship.mapper.class_
if not allow_restricted and current_model.__filter_restricted__:
raise ValueError(f"cannot traverse into restricted model '{current_model.__name__}'")
except (AttributeError, KeyError) as e: except (AttributeError, KeyError) as e:
raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e raise ValueError(f"invalid attribute string: '{attr_string}' does not exist on this schema") from e
@@ -299,7 +314,9 @@ class QueryFilterBuilder:
if len(value) == 1: if len(value) == 1:
element = model_attr.in_(value) element = model_attr.in_(value)
else: else:
primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) primary_model_attr: InstrumentedAttribute = cls._get_model_attr(
model, component.attribute_name.split(".")[0]
)
element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value)) element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value))
elif component.relationship is RelationalKeyword.LIKE: elif component.relationship is RelationalKeyword.LIKE:
element = model_attr.ilike(value) element = model_attr.ilike(value)
@@ -368,7 +385,7 @@ class QueryFilterBuilder:
else: else:
component = cast(QueryFilterBuilderComponent, component) component = cast(QueryFilterBuilderComponent, component)
base_attribute_name = component.attribute_name.split(".")[-1] base_attribute_name = component.attribute_name.split(".")[-1]
model_attr = getattr(attr_model_map[i], base_attribute_name) model_attr = self._get_model_attr(attr_model_map[i], base_attribute_name)
if (column_alias := column_aliases.get(base_attribute_name)) is not None: if (column_alias := column_aliases.get(base_attribute_name)) is not None:
model_attr = column_alias model_attr = column_alias

View File

@@ -0,0 +1,3 @@
from contextvars import ContextVar
allow_filter_restricted: ContextVar[bool] = ContextVar("allow_filter_restricted", default=True)

View File

@@ -1,3 +1,9 @@
import pytest
import sqlalchemy as sa
from mealie.db.models._model_base import PrivateColumn
from mealie.db.models.recipe.recipe import RecipeModel
from mealie.db.models.users.users import LongLiveToken, User
from mealie.services.query_filter.builder import ( from mealie.services.query_filter.builder import (
LogicalOperator, LogicalOperator,
QueryFilterBuilder, QueryFilterBuilder,
@@ -6,6 +12,7 @@ from mealie.services.query_filter.builder import (
RelationalKeyword, RelationalKeyword,
RelationalOperator, RelationalOperator,
) )
from mealie.services.query_filter.context import allow_filter_restricted
def test_query_filter_builder_json(): def test_query_filter_builder_json():
@@ -74,3 +81,118 @@ def test_query_filter_builder_json_uses_raw_value():
), ),
] ]
) )
# ---------------------------------------------------------------------------
# PrivateColumn tests
# ---------------------------------------------------------------------------
def test_private_column_rejects_list_type():
"""PrivateColumn[list[X]] must raise TypeError at definition time to prevent misuse on relationships."""
with pytest.raises(TypeError, match="relationship"):
PrivateColumn[list[User]]
def test_private_field_user_password_raises():
"""Filtering on User.password (PrivateColumn) should raise ValueError."""
with pytest.raises(ValueError, match="private field"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)
def test_private_field_long_live_token_raises():
"""Filtering on LongLiveToken.token (PrivateColumn) should raise ValueError."""
with pytest.raises(ValueError, match="private field"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("token", LongLiveToken)
def test_non_private_field_does_not_raise():
"""Filtering on a normal field should not raise."""
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("full_name", User)
assert model is User
assert attr is User.full_name
# ---------------------------------------------------------------------------
# __filter_restricted__ traversal tests
# ---------------------------------------------------------------------------
def test_restricted_traversal_blocked_when_disallowed():
"""Traversing into User (restricted) via RecipeModel.user should raise when the ContextVar is False."""
allow_filter_restricted.set(False)
try:
with pytest.raises(ValueError, match="restricted model"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
finally:
allow_filter_restricted.set(True)
def test_association_proxy_through_restricted_model_allowed():
"""Association proxies (e.g. household_id) traverse through User but are intentional
exposures on the source model and must NOT be blocked even when the ContextVar is False."""
allow_filter_restricted.set(False)
try:
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("household_id", RecipeModel)
assert model is User
finally:
allow_filter_restricted.set(True)
def test_restricted_traversal_allowed_by_default():
"""Traversing into User via RecipeModel.user should succeed when the ContextVar is True (default)."""
model, attr, _ = QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
assert model is User
assert attr is User.email
# ---------------------------------------------------------------------------
# ContextVar tests
# ---------------------------------------------------------------------------
def test_allow_filter_restricted_default_is_true():
"""The ContextVar default must be True so authenticated requests are unrestricted."""
assert allow_filter_restricted.get() is True
def test_filter_query_respects_context_var_false(monkeypatch):
"""filter_query should block restricted traversal when the ContextVar is False."""
allow_filter_restricted.set(False)
try:
query = sa.select(RecipeModel)
builder = QueryFilterBuilder("user.email = 'test@example.com'")
with pytest.raises(ValueError, match="restricted model"):
builder.filter_query(query, RecipeModel)
finally:
allow_filter_restricted.set(True)
def test_filter_query_respects_context_var_true():
"""filter_query should allow restricted traversal when the ContextVar is True (default)."""
allow_filter_restricted.set(True)
query = sa.select(RecipeModel)
builder = QueryFilterBuilder("user.email = 'test@example.com'")
# Should not raise
builder.filter_query(query, RecipeModel)
# ---------------------------------------------------------------------------
# orderBy restricted traversal tests
# ---------------------------------------------------------------------------
def test_order_by_restricted_traversal_blocked():
"""orderBy into a restricted model is blocked when the ContextVar is False."""
allow_filter_restricted.set(False)
try:
with pytest.raises(ValueError, match="restricted model"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("user.email", RecipeModel)
finally:
allow_filter_restricted.set(True)
def test_order_by_private_field_blocked():
"""Ordering by a PrivateColumn field should always raise regardless of the ContextVar."""
with pytest.raises(ValueError, match="private field"):
QueryFilterBuilder.get_model_and_model_attr_from_attr_string("password", User)