mirror of
				https://github.com/mealie-recipes/mealie.git
				synced 2025-10-31 02:03:35 -04:00 
			
		
		
		
	feature: query filter support for common SQL keywords (#2366)
* added support for SQL keywords IS, IN, LIKE, NOT deprecated datetime workaround for "<> null" updated frontend reference for "<> null" to "IS NOT NULL" * tests * refactored query filtering to leverage orm * added CONTAINS ALL keyword * tests * fixed bug where "and" or "or" was in an attr name * more tests * linter fixes * TIL this works
This commit is contained in:
		| @@ -217,7 +217,7 @@ export default defineComponent({ | |||||||
|  |  | ||||||
|     const queryFilter = computed(() => { |     const queryFilter = computed(() => { | ||||||
|       const orderBy = props.query?.orderBy || preferences.value.orderBy; |       const orderBy = props.query?.orderBy || preferences.value.orderBy; | ||||||
|       return preferences.value.filterNull && orderBy ? `${orderBy} <> null` : null; |       return preferences.value.filterNull && orderBy ? `${orderBy} IS NOT NULL` : null; | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|     async function fetchRecipes(pageCount = 1) { |     async function fetchRecipes(pageCount = 1) { | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| import datetime |  | ||||||
| import re | import re | ||||||
|  | from collections import deque | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from typing import Any, TypeVar, cast | from typing import Any, TypeVar, cast | ||||||
| from uuid import UUID | from uuid import UUID | ||||||
| @@ -9,16 +9,66 @@ from uuid import UUID | |||||||
| from dateutil import parser as date_parser | from dateutil import parser as date_parser | ||||||
| from dateutil.parser import ParserError | from dateutil.parser import ParserError | ||||||
| from humps import decamelize | from humps import decamelize | ||||||
| from sqlalchemy import Select, bindparam, inspect, text | from sqlalchemy import ColumnElement, Select, and_, inspect, or_ | ||||||
| from sqlalchemy.orm import Mapper | from sqlalchemy.orm import InstrumentedAttribute, Mapper | ||||||
| from sqlalchemy.sql import sqltypes | from sqlalchemy.sql import sqltypes | ||||||
| from sqlalchemy.sql.expression import BindParameter |  | ||||||
|  |  | ||||||
| from mealie.db.models._model_utils.guid import GUID | from mealie.db.models._model_utils.guid import GUID | ||||||
|  |  | ||||||
| Model = TypeVar("Model") | Model = TypeVar("Model") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class RelationalKeyword(Enum): | ||||||
|  |     IS = "IS" | ||||||
|  |     IS_NOT = "IS NOT" | ||||||
|  |     IN = "IN" | ||||||
|  |     NOT_IN = "NOT IN" | ||||||
|  |     CONTAINS_ALL = "CONTAINS ALL" | ||||||
|  |     LIKE = "LIKE" | ||||||
|  |     NOT_LIKE = "NOT LIKE" | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def parse_component(cls, component: str) -> list[str] | None: | ||||||
|  |         """ | ||||||
|  |         Try to parse a component using a relational keyword | ||||||
|  |  | ||||||
|  |         If no matching keyword is found, returns None | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         # extract the attribute name from the component | ||||||
|  |         parsed_component = component.split(maxsplit=1) | ||||||
|  |         if len(parsed_component) < 2: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # assume the component has already filtered out the value and try to match a keyword | ||||||
|  |         # if we try to filter out the value without checking first, keywords with spaces won't parse correctly | ||||||
|  |         possible_keyword = parsed_component[1].strip().lower() | ||||||
|  |         for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True): | ||||||
|  |             if rel_kw.lower() != possible_keyword: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             parsed_component[1] = rel_kw | ||||||
|  |             return parsed_component | ||||||
|  |  | ||||||
|  |         # there was no match, so the component may still have the value in it | ||||||
|  |         try: | ||||||
|  |             _possible_keyword, _value = parsed_component[-1].rsplit(maxsplit=1) | ||||||
|  |             parsed_component = [parsed_component[0], _possible_keyword, _value] | ||||||
|  |         except ValueError: | ||||||
|  |             # the component has no value to filter out | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         possible_keyword = parsed_component[1].strip().lower() | ||||||
|  |         for rel_kw in sorted([keyword.value for keyword in cls], key=len, reverse=True): | ||||||
|  |             if rel_kw.lower() != possible_keyword: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             parsed_component[1] = rel_kw | ||||||
|  |             return parsed_component | ||||||
|  |  | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| class RelationalOperator(Enum): | class RelationalOperator(Enum): | ||||||
|     EQ = "=" |     EQ = "=" | ||||||
|     NOTEQ = "<>" |     NOTEQ = "<>" | ||||||
| @@ -27,6 +77,24 @@ class RelationalOperator(Enum): | |||||||
|     GTE = ">=" |     GTE = ">=" | ||||||
|     LTE = "<=" |     LTE = "<=" | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def parse_component(cls, component: str) -> list[str] | None: | ||||||
|  |         """ | ||||||
|  |         Try to parse a component using a relational operator | ||||||
|  |  | ||||||
|  |         If no matching operator is found, returns None | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         for rel_op in sorted([operator.value for operator in cls], key=len, reverse=True): | ||||||
|  |             if rel_op not in component: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             parsed_component = [base_component.strip() for base_component in component.split(rel_op) if base_component] | ||||||
|  |             parsed_component.insert(1, rel_op) | ||||||
|  |             return parsed_component | ||||||
|  |  | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| class LogicalOperator(Enum): | class LogicalOperator(Enum): | ||||||
|     AND = "AND" |     AND = "AND" | ||||||
| @@ -36,31 +104,107 @@ class LogicalOperator(Enum): | |||||||
| class QueryFilterComponent: | class QueryFilterComponent: | ||||||
|     """A single relational statement""" |     """A single relational statement""" | ||||||
|  |  | ||||||
|     def __init__(self, attribute_name: str, relational_operator: RelationalOperator, value: str) -> None: |     @staticmethod | ||||||
|  |     def strip_quotes_from_string(val: str) -> str: | ||||||
|  |         if len(val) > 2 and val[0] == '"' and val[-1] == '"': | ||||||
|  |             return val[1:-1] | ||||||
|  |         else: | ||||||
|  |             return val | ||||||
|  |  | ||||||
|  |     def __init__( | ||||||
|  |         self, attribute_name: str, relationship: RelationalKeyword | RelationalOperator, value: str | list[str] | ||||||
|  |     ) -> None: | ||||||
|         self.attribute_name = decamelize(attribute_name) |         self.attribute_name = decamelize(attribute_name) | ||||||
|         self.relational_operator = relational_operator |         self.relationship = relationship | ||||||
|         self.value = value |  | ||||||
|  |  | ||||||
|         # remove encasing quotes |         # remove encasing quotes | ||||||
|         if len(value) > 2 and value[0] == '"' and value[-1] == '"': |         if isinstance(value, str): | ||||||
|             self.value = value[1:-1] |             value = self.strip_quotes_from_string(value) | ||||||
|  |  | ||||||
|  |         elif isinstance(value, list): | ||||||
|  |             value = [self.strip_quotes_from_string(v) for v in value] | ||||||
|  |  | ||||||
|  |         # validate relationship/value pairs | ||||||
|  |         if relationship in [ | ||||||
|  |             RelationalKeyword.IN, | ||||||
|  |             RelationalKeyword.NOT_IN, | ||||||
|  |             RelationalKeyword.CONTAINS_ALL, | ||||||
|  |         ] and not isinstance(value, list): | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"invalid query string: {relationship.value} must be given a list of values" | ||||||
|  |                 f"enclosed by {QueryFilter.l_list_sep} and {QueryFilter.r_list_sep}" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         if relationship is RelationalKeyword.IS or relationship is RelationalKeyword.IS_NOT: | ||||||
|  |             if not isinstance(value, str) or value.lower() not in ["null", "none"]: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     f'invalid query string: "{relationship.value}" can only be used with "NULL", not "{value}"' | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |             self.value = None | ||||||
|  |         else: | ||||||
|  |             self.value = value | ||||||
|  |  | ||||||
|     def __repr__(self) -> str: |     def __repr__(self) -> str: | ||||||
|         return f"[{self.attribute_name} {self.relational_operator.value} {self.value}]" |         return f"[{self.attribute_name} {self.relationship.value} {self.value}]" | ||||||
|  |  | ||||||
|  |     def validate(self, model_attr_type: Any) -> Any: | ||||||
|  |         """Validate value against an model attribute's type and return a validated value, or raise a ValueError""" | ||||||
|  |  | ||||||
|  |         sanitized_values: list[Any] | ||||||
|  |         if not isinstance(self.value, list): | ||||||
|  |             sanitized_values = [self.value] | ||||||
|  |         else: | ||||||
|  |             sanitized_values = self.value | ||||||
|  |  | ||||||
|  |         for i, v in enumerate(sanitized_values): | ||||||
|  |             # always allow querying for null values | ||||||
|  |             if v is None: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if self.relationship is RelationalKeyword.LIKE or self.relationship is RelationalKeyword.NOT_LIKE: | ||||||
|  |                 if not isinstance(model_attr_type, sqltypes.String): | ||||||
|  |                     raise ValueError( | ||||||
|  |                         f'invalid query string: "{self.relationship.value}" can only be used with string columns' | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |             if isinstance(model_attr_type, (GUID)): | ||||||
|  |                 try: | ||||||
|  |                     # we don't set value since a UUID is functionally identical to a string here | ||||||
|  |                     UUID(v) | ||||||
|  |                 except ValueError as e: | ||||||
|  |                     raise ValueError(f"invalid query string: invalid UUID '{v}'") from e | ||||||
|  |  | ||||||
|  |             if isinstance(model_attr_type, sqltypes.Date | sqltypes.DateTime): | ||||||
|  |                 try: | ||||||
|  |                     sanitized_values[i] = date_parser.parse(v) | ||||||
|  |                 except ParserError as e: | ||||||
|  |                     raise ValueError(f"invalid query string: unknown date or datetime format '{v}'") from e | ||||||
|  |  | ||||||
|  |             if isinstance(model_attr_type, sqltypes.Boolean): | ||||||
|  |                 try: | ||||||
|  |                     sanitized_values[i] = v.lower()[0] in ["t", "y"] or v == "1" | ||||||
|  |                 except IndexError as e: | ||||||
|  |                     raise ValueError("invalid query string") from e | ||||||
|  |  | ||||||
|  |         return sanitized_values if isinstance(self.value, list) else sanitized_values[0] | ||||||
|  |  | ||||||
|  |  | ||||||
| class QueryFilter: | class QueryFilter: | ||||||
|     lsep: str = "(" |     l_group_sep: str = "(" | ||||||
|     rsep: str = ")" |     r_group_sep: str = ")" | ||||||
|  |     group_seps: set[str] = {l_group_sep, r_group_sep} | ||||||
|  |  | ||||||
|     seps: set[str] = {lsep, rsep} |     l_list_sep: str = "[" | ||||||
|  |     r_list_sep: str = "]" | ||||||
|  |     list_item_sep: str = "," | ||||||
|  |  | ||||||
|     def __init__(self, filter_string: str) -> None: |     def __init__(self, filter_string: str) -> None: | ||||||
|         # parse filter string |         # parse filter string | ||||||
|         components = QueryFilter._break_filter_string_into_components(filter_string) |         components = QueryFilter._break_filter_string_into_components(filter_string) | ||||||
|         base_components = QueryFilter._break_components_into_base_components(components) |         base_components = QueryFilter._break_components_into_base_components(components) | ||||||
|         if base_components.count(QueryFilter.lsep) != base_components.count(QueryFilter.rsep): |         if base_components.count(QueryFilter.l_group_sep) != base_components.count(QueryFilter.r_group_sep): | ||||||
|             raise ValueError("invalid filter string: parenthesis are unbalanced") |             raise ValueError("invalid query string: parenthesis are unbalanced") | ||||||
|  |  | ||||||
|         # parse base components into a filter group |         # parse base components into a filter group | ||||||
|         self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components) |         self.filter_components = QueryFilter._parse_base_components_into_filter_components(base_components) | ||||||
| @@ -75,97 +219,125 @@ class QueryFilter: | |||||||
|  |  | ||||||
|         return f"<<{joined}>>" |         return f"<<{joined}>>" | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def _consolidate_group(cls, group: list[ColumnElement], logical_operators: deque[LogicalOperator]) -> ColumnElement: | ||||||
|  |         consolidated_group_builder: ColumnElement | None = None | ||||||
|  |         for i, element in enumerate(reversed(group)): | ||||||
|  |             if not i: | ||||||
|  |                 consolidated_group_builder = element | ||||||
|  |             else: | ||||||
|  |                 operator = logical_operators.pop() | ||||||
|  |                 if operator is LogicalOperator.AND: | ||||||
|  |                     consolidated_group_builder = and_(consolidated_group_builder, element) | ||||||
|  |                 elif operator is LogicalOperator.OR: | ||||||
|  |                     consolidated_group_builder = or_(consolidated_group_builder, element) | ||||||
|  |                 else: | ||||||
|  |                     raise ValueError(f"invalid logical operator {operator}") | ||||||
|  |  | ||||||
|  |             if i == len(group) - 1: | ||||||
|  |                 return consolidated_group_builder.self_group() | ||||||
|  |  | ||||||
|     def filter_query(self, query: Select, model: type[Model]) -> Select: |     def filter_query(self, query: Select, model: type[Model]) -> Select: | ||||||
|         segments: list[str] = [] |         # join tables and build model chain | ||||||
|         params: list[BindParameter] = [] |         attr_model_map: dict[int, Any] = {} | ||||||
|  |         model_attr: InstrumentedAttribute | ||||||
|         for i, component in enumerate(self.filter_components): |         for i, component in enumerate(self.filter_components): | ||||||
|             if component in QueryFilter.seps: |             if not isinstance(component, QueryFilterComponent): | ||||||
|                 segments.append(component)  # type: ignore |  | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             if isinstance(component, LogicalOperator): |  | ||||||
|                 segments.append(component.value) |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             # for some reason typing doesn't like the lsep and rsep literals, so |  | ||||||
|             # we explicitly mark this as a filter component instead cast doesn't |  | ||||||
|             # actually do anything at runtime |  | ||||||
|             component = cast(QueryFilterComponent, component) |  | ||||||
|             attribute_chain = component.attribute_name.split(".") |             attribute_chain = component.attribute_name.split(".") | ||||||
|             if not attribute_chain: |             if not attribute_chain: | ||||||
|                 raise ValueError("invalid query string: attribute name cannot be empty") |                 raise ValueError("invalid query string: attribute name cannot be empty") | ||||||
|  |  | ||||||
|             attr_model: Any = model |             current_model = model | ||||||
|             for j, attribute_link in enumerate(attribute_chain): |             for j, attribute_link in enumerate(attribute_chain): | ||||||
|                 # last element |  | ||||||
|                 if j == len(attribute_chain) - 1: |  | ||||||
|                     if not hasattr(attr_model, attribute_link): |  | ||||||
|                         raise ValueError( |  | ||||||
|                             f"invalid query string: '{component.attribute_name}' does not exist on this schema" |  | ||||||
|                         ) |  | ||||||
|  |  | ||||||
|                     attr_value = attribute_link |  | ||||||
|                     if j: |  | ||||||
|                         # use the nested table name, rather than the dot notation |  | ||||||
|                         component.attribute_name = f"{attr_model.__table__.name}.{attr_value}" |  | ||||||
|  |  | ||||||
|                     continue |  | ||||||
|  |  | ||||||
|                 # join on nested model |  | ||||||
|                 try: |                 try: | ||||||
|                     query = query.join(getattr(attr_model, attribute_link)) |                     model_attr = getattr(current_model, attribute_link) | ||||||
|  |  | ||||||
|                     mapper: Mapper = inspect(attr_model) |                     # at the end of the chain there are no more relationships to inspect | ||||||
|  |                     if j == len(attribute_chain) - 1: | ||||||
|  |                         break | ||||||
|  |  | ||||||
|  |                     query = query.join(model_attr) | ||||||
|  |                     mapper: Mapper = inspect(current_model) | ||||||
|                     relationship = mapper.relationships[attribute_link] |                     relationship = mapper.relationships[attribute_link] | ||||||
|                     attr_model = relationship.mapper.class_ |                     current_model = relationship.mapper.class_ | ||||||
|  |  | ||||||
|                 except (AttributeError, KeyError) as e: |                 except (AttributeError, KeyError) as e: | ||||||
|                     raise ValueError( |                     raise ValueError( | ||||||
|                         f"invalid query string: '{component.attribute_name}' does not exist on this schema" |                         f"invalid query string: '{component.attribute_name}' does not exist on this schema" | ||||||
|                     ) from e |                     ) from e | ||||||
|  |             attr_model_map[i] = current_model | ||||||
|  |  | ||||||
|             # convert values to their proper types |         # build query filter | ||||||
|             attr = getattr(attr_model, attr_value) |         partial_group: list[ColumnElement] = [] | ||||||
|             value: Any = component.value |         partial_group_stack: deque[list[ColumnElement]] = deque() | ||||||
|  |         logical_operator_stack: deque[LogicalOperator] = deque() | ||||||
|  |         for i, component in enumerate(self.filter_components): | ||||||
|  |             if component == self.l_group_sep: | ||||||
|  |                 partial_group_stack.append(partial_group) | ||||||
|  |                 partial_group = [] | ||||||
|  |  | ||||||
|             if isinstance(attr.type, (GUID)): |             elif component == self.r_group_sep: | ||||||
|                 try: |                 if partial_group: | ||||||
|                     # we don't set value since a UUID is functionally identical to a string here |                     complete_group = self._consolidate_group(partial_group, logical_operator_stack) | ||||||
|                     UUID(value) |                     partial_group = partial_group_stack.pop() | ||||||
|  |                     partial_group.append(complete_group) | ||||||
|  |                 else: | ||||||
|  |                     partial_group = partial_group_stack.pop() | ||||||
|  |  | ||||||
|                 except ValueError as e: |             elif isinstance(component, LogicalOperator): | ||||||
|                     raise ValueError(f"invalid query string: invalid UUID '{component.value}'") from e |                 logical_operator_stack.append(component) | ||||||
|  |  | ||||||
|             if isinstance(attr.type, sqltypes.Date | sqltypes.DateTime): |  | ||||||
|                 # TODO: add support for IS NULL and IS NOT NULL |  | ||||||
|                 # in the meantime, this will work for the specific usecase of non-null dates/datetimes |  | ||||||
|                 if value in ["none", "null"] and component.relational_operator == RelationalOperator.NOTEQ: |  | ||||||
|                     component.relational_operator = RelationalOperator.GTE |  | ||||||
|                     value = datetime.datetime(datetime.MINYEAR, 1, 1) |  | ||||||
|  |  | ||||||
|             else: |             else: | ||||||
|                     try: |                 component = cast(QueryFilterComponent, component) | ||||||
|                         value = date_parser.parse(component.value) |                 model_attr = getattr(attr_model_map[i], component.attribute_name.split(".")[-1]) | ||||||
|  |  | ||||||
|                     except ParserError as e: |                 # Keywords | ||||||
|                         raise ValueError( |                 if component.relationship is RelationalKeyword.IS: | ||||||
|                             f"invalid query string: unknown date or datetime format '{component.value}'" |                     element = model_attr.is_(component.validate(model_attr.type)) | ||||||
|                         ) from e |                 elif component.relationship is RelationalKeyword.IS_NOT: | ||||||
|  |                     element = model_attr.is_not(component.validate(model_attr.type)) | ||||||
|  |                 elif component.relationship is RelationalKeyword.IN: | ||||||
|  |                     element = model_attr.in_(component.validate(model_attr.type)) | ||||||
|  |                 elif component.relationship is RelationalKeyword.NOT_IN: | ||||||
|  |                     element = model_attr.not_in(component.validate(model_attr.type)) | ||||||
|  |                 elif component.relationship is RelationalKeyword.CONTAINS_ALL: | ||||||
|  |                     primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) | ||||||
|  |                     element = and_() | ||||||
|  |                     for v in component.validate(model_attr.type): | ||||||
|  |                         element = and_(element, primary_model_attr.any(model_attr == v)) | ||||||
|  |                 elif component.relationship is RelationalKeyword.LIKE: | ||||||
|  |                     element = model_attr.like(component.validate(model_attr.type)) | ||||||
|  |                 elif component.relationship is RelationalKeyword.NOT_LIKE: | ||||||
|  |                     element = model_attr.not_like(component.validate(model_attr.type)) | ||||||
|  |  | ||||||
|             if isinstance(attr.type, sqltypes.Boolean): |                 # Operators | ||||||
|                 try: |                 elif component.relationship is RelationalOperator.EQ: | ||||||
|                     value = component.value.lower()[0] in ["t", "y"] or component.value == "1" |                     element = model_attr == component.validate(model_attr.type) | ||||||
|  |                 elif component.relationship is RelationalOperator.NOTEQ: | ||||||
|  |                     element = model_attr != component.validate(model_attr.type) | ||||||
|  |                 elif component.relationship is RelationalOperator.GT: | ||||||
|  |                     element = model_attr > component.validate(model_attr.type) | ||||||
|  |                 elif component.relationship is RelationalOperator.LT: | ||||||
|  |                     element = model_attr < component.validate(model_attr.type) | ||||||
|  |                 elif component.relationship is RelationalOperator.GTE: | ||||||
|  |                     element = model_attr >= component.validate(model_attr.type) | ||||||
|  |                 elif component.relationship is RelationalOperator.LTE: | ||||||
|  |                     element = model_attr <= component.validate(model_attr.type) | ||||||
|  |                 else: | ||||||
|  |                     raise ValueError(f"invalid relationship {component.relationship}") | ||||||
|  |  | ||||||
|                 except IndexError as e: |                 partial_group.append(element) | ||||||
|                     raise ValueError("invalid query string") from e |  | ||||||
|  |  | ||||||
|             paramkey = f"P{i+1}" |         # combine the completed groups into one filter | ||||||
|             segments.append(" ".join([component.attribute_name, component.relational_operator.value, f":{paramkey}"])) |         while True: | ||||||
|             params.append(bindparam(paramkey, value, attr.type)) |             consolidated_group = self._consolidate_group(partial_group, logical_operator_stack) | ||||||
|  |             if not partial_group_stack: | ||||||
|         qs = text(" ".join(segments)).bindparams(*params) |                 return query.filter(consolidated_group) | ||||||
|         query = query.filter(qs) |             else: | ||||||
|         return query |                 partial_group = partial_group_stack.pop() | ||||||
|  |                 partial_group.append(consolidated_group) | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _break_filter_string_into_components(filter_string: str) -> list[str]: |     def _break_filter_string_into_components(filter_string: str) -> list[str]: | ||||||
| @@ -176,7 +348,7 @@ class QueryFilter: | |||||||
|             subcomponents = [] |             subcomponents = [] | ||||||
|             for component in components: |             for component in components: | ||||||
|                 # don't parse components comprised of only a separator |                 # don't parse components comprised of only a separator | ||||||
|                 if component in QueryFilter.seps: |                 if component in QueryFilter.group_seps: | ||||||
|                     subcomponents.append(component) |                     subcomponents.append(component) | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
| @@ -187,7 +359,7 @@ class QueryFilter: | |||||||
|                     if c == '"': |                     if c == '"': | ||||||
|                         in_quotes = not in_quotes |                         in_quotes = not in_quotes | ||||||
|  |  | ||||||
|                     if c in QueryFilter.seps and not in_quotes: |                     if c in QueryFilter.group_seps and not in_quotes: | ||||||
|                         if new_component: |                         if new_component: | ||||||
|                             subcomponents.append(new_component) |                             subcomponents.append(new_component) | ||||||
|  |  | ||||||
| @@ -208,25 +380,50 @@ class QueryFilter: | |||||||
|         return components |         return components | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _break_components_into_base_components(components: list[str]) -> list[str]: |     def _break_components_into_base_components(components: list[str]) -> list[str | list[str]]: | ||||||
|         """Further break down components by splitting at relational and logical operators""" |         """Further break down components by splitting at relational and logical operators""" | ||||||
|         logical_operators = re.compile( |         pattern = "|".join([f"\\b{operator.value}\\b" for operator in LogicalOperator]) | ||||||
|             f'({"|".join(operator.value for operator in LogicalOperator)})', flags=re.IGNORECASE |         logical_operators = re.compile(f"({pattern})", flags=re.IGNORECASE) | ||||||
|  |  | ||||||
|  |         in_list = False | ||||||
|  |         base_components: list[str | list] = [] | ||||||
|  |         list_value_components = [] | ||||||
|  |         for component in components: | ||||||
|  |             # parse out lists as their own singular sub component | ||||||
|  |             subcomponents = component.split(QueryFilter.l_list_sep) | ||||||
|  |             for i, subcomponent in enumerate(subcomponents): | ||||||
|  |                 if not i: | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |                 for j, list_value_string in enumerate(subcomponent.split(QueryFilter.r_list_sep)): | ||||||
|  |                     if j % 2: | ||||||
|  |                         continue | ||||||
|  |  | ||||||
|  |                     list_value_components.append( | ||||||
|  |                         [val.strip() for val in list_value_string.split(QueryFilter.list_item_sep)] | ||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|         base_components = [] |             quote_offset = 0 | ||||||
|         for component in components: |  | ||||||
|             offset = 0 |  | ||||||
|             subcomponents = component.split('"') |             subcomponents = component.split('"') | ||||||
|             for i, subcomponent in enumerate(subcomponents): |             for i, subcomponent in enumerate(subcomponents): | ||||||
|  |                 # we are in a list subcomponent, which is already handled | ||||||
|  |                 if in_list: | ||||||
|  |                     if QueryFilter.r_list_sep in subcomponent: | ||||||
|  |                         # filter out the remainder of the list subcomponent and continue parsing | ||||||
|  |                         base_components.append(list_value_components.pop(0)) | ||||||
|  |                         subcomponent = subcomponent.split(QueryFilter.r_list_sep, maxsplit=1)[-1].strip() | ||||||
|  |                         in_list = False | ||||||
|  |                     else: | ||||||
|  |                         continue | ||||||
|  |  | ||||||
|                 # don't parse components comprised of only a separator |                 # don't parse components comprised of only a separator | ||||||
|                 if subcomponent in QueryFilter.seps: |                 if subcomponent in QueryFilter.group_seps: | ||||||
|                     offset += 1 |                     quote_offset += 1 | ||||||
|                     base_components.append(subcomponent) |                     base_components.append(subcomponent) | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|                 # this subscomponent was surrounded in quotes, so we keep it as-is |                 # this subcomponent was surrounded in quotes, so we keep it as-is | ||||||
|                 if (i + offset) % 2: |                 if (i + quote_offset) % 2: | ||||||
|                     base_components.append(f'"{subcomponent.strip()}"') |                     base_components.append(f'"{subcomponent.strip()}"') | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
| @@ -234,53 +431,70 @@ class QueryFilter: | |||||||
|                 if not subcomponent: |                 if not subcomponent: | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|  |                 # continue parsing this subcomponent up to the list, then skip over subsequent subcomponents | ||||||
|  |                 if not in_list and QueryFilter.l_list_sep in subcomponent: | ||||||
|  |                     subcomponent, _new_sub_component = subcomponent.split(QueryFilter.l_list_sep, maxsplit=1) | ||||||
|  |                     subcomponent = subcomponent.strip() | ||||||
|  |                     subcomponents.insert(i + 1, _new_sub_component) | ||||||
|  |                     quote_offset += 1 | ||||||
|  |                     in_list = True | ||||||
|  |  | ||||||
|                 # parse out logical operators |                 # parse out logical operators | ||||||
|                 new_components = [ |                 new_components = [ | ||||||
|                     base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component |                     base_component.strip() for base_component in logical_operators.split(subcomponent) if base_component | ||||||
|                 ] |                 ] | ||||||
|  |  | ||||||
|                 # parse out relational operators; each base_subcomponent has exactly zero or one relational operator |                 # parse out relational keywords and operators | ||||||
|                 # we do them one at a time in descending length since some operators overlap (e.g. :> and >) |                 # each base_subcomponent has exactly zero or one keyword or operator | ||||||
|                 for component in new_components: |                 for component in new_components: | ||||||
|                     if not component: |                     if not component: | ||||||
|                         continue |                         continue | ||||||
|  |  | ||||||
|                     added_to_base_components = False |                     # we try relational operators first since they aren't required to be surrounded by spaces | ||||||
|                     for rel_op in sorted([operator.value for operator in RelationalOperator], key=len, reverse=True): |                     parsed_component = RelationalOperator.parse_component(component) | ||||||
|                         if rel_op in component: |                     if parsed_component is not None: | ||||||
|                             new_base_components = [ |                         base_components.extend(parsed_component) | ||||||
|                                 base_component.strip() for base_component in component.split(rel_op) if base_component |                         continue | ||||||
|                             ] |  | ||||||
|                             new_base_components.insert(1, rel_op) |  | ||||||
|                             base_components.extend(new_base_components) |  | ||||||
|  |  | ||||||
|                             added_to_base_components = True |                     parsed_component = RelationalKeyword.parse_component(component) | ||||||
|                             break |                     if parsed_component is not None: | ||||||
|  |                         base_components.extend(parsed_component) | ||||||
|  |                         continue | ||||||
|  |  | ||||||
|                     if not added_to_base_components: |                     # this component does not have any keywords or operators, so we just add it as-is | ||||||
|                     base_components.append(component) |                     base_components.append(component) | ||||||
|  |  | ||||||
|         return base_components |         return base_components | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def _parse_base_components_into_filter_components( |     def _parse_base_components_into_filter_components( | ||||||
|         base_components: list[str], |         base_components: list[str | list[str]], | ||||||
|     ) -> list[str | QueryFilterComponent | LogicalOperator]: |     ) -> list[str | QueryFilterComponent | LogicalOperator]: | ||||||
|         """Walk through base components and construct filter collections""" |         """Walk through base components and construct filter collections""" | ||||||
|  |         relational_keywords = [kw.value for kw in RelationalKeyword] | ||||||
|         relational_operators = [op.value for op in RelationalOperator] |         relational_operators = [op.value for op in RelationalOperator] | ||||||
|         logical_operators = [op.value for op in LogicalOperator] |         logical_operators = [op.value for op in LogicalOperator] | ||||||
|  |  | ||||||
|         # parse QueryFilterComponents and logical operators |         # parse QueryFilterComponents and logical operators | ||||||
|         components: list[str | QueryFilterComponent | LogicalOperator] = [] |         components: list[str | QueryFilterComponent | LogicalOperator] = [] | ||||||
|         for i, base_component in enumerate(base_components): |         for i, base_component in enumerate(base_components): | ||||||
|             if base_component in QueryFilter.seps: |             if isinstance(base_component, list): | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if base_component in QueryFilter.group_seps: | ||||||
|                 components.append(base_component) |                 components.append(base_component) | ||||||
|  |  | ||||||
|             elif base_component in relational_operators: |             elif base_component in relational_keywords or base_component in relational_operators: | ||||||
|  |                 relationship: RelationalKeyword | RelationalOperator | ||||||
|  |                 if base_component in relational_keywords: | ||||||
|  |                     relationship = RelationalKeyword(base_components[i]) | ||||||
|  |                 else: | ||||||
|  |                     relationship = RelationalOperator(base_components[i]) | ||||||
|  |  | ||||||
|                 components.append( |                 components.append( | ||||||
|                     QueryFilterComponent( |                     QueryFilterComponent( | ||||||
|                         attribute_name=base_components[i - 1], |                         attribute_name=base_components[i - 1],  # type: ignore | ||||||
|                         relational_operator=RelationalOperator(base_components[i]), |                         relationship=relationship, | ||||||
|                         value=base_components[i + 1], |                         value=base_components[i + 1], | ||||||
|                     ) |                     ) | ||||||
|                 ) |                 ) | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import time | import time | ||||||
| from collections import defaultdict | from collections import defaultdict | ||||||
|  | from datetime import datetime | ||||||
| from random import randint | from random import randint | ||||||
| from urllib.parse import parse_qsl, urlsplit | from urllib.parse import parse_qsl, urlsplit | ||||||
|  |  | ||||||
| @@ -9,7 +10,10 @@ from humps import camelize | |||||||
|  |  | ||||||
| from mealie.repos.repository_factory import AllRepositories | from mealie.repos.repository_factory import AllRepositories | ||||||
| from mealie.repos.repository_units import RepositoryUnit | from mealie.repos.repository_units import RepositoryUnit | ||||||
|  | from mealie.schema.recipe import Recipe | ||||||
|  | from mealie.schema.recipe.recipe_category import CategorySave, TagSave | ||||||
| from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit | from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit | ||||||
|  | from mealie.schema.recipe.recipe_tool import RecipeToolSave | ||||||
| from mealie.schema.response.pagination import PaginationQuery | from mealie.schema.response.pagination import PaginationQuery | ||||||
| from mealie.services.seeder.seeder_service import SeederService | from mealie.services.seeder.seeder_service import SeederService | ||||||
| from tests.utils import api_routes | from tests.utils import api_routes | ||||||
| @@ -172,6 +176,256 @@ def test_pagination_filter_basic(query_units: tuple[RepositoryUnit, IngredientUn | |||||||
|     assert unit_results[0].id == unit_2.id |     assert unit_results[0].id == unit_2.id | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_null(database: AllRepositories, unique_user: TestUser): | ||||||
|  |     recipe_not_made_1 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string()) | ||||||
|  |     ) | ||||||
|  |     recipe_not_made_2 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string()) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # give one recipe a last made date | ||||||
|  |     recipe_made = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), last_made=datetime.now() | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     recipe_repo = database.recipes.by_group(unique_user.group_id)  # type: ignore | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NONE") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 2 | ||||||
|  |     result_ids = {result.id for result in recipe_results} | ||||||
|  |     assert recipe_not_made_1.id in result_ids | ||||||
|  |     assert recipe_not_made_2.id in result_ids | ||||||
|  |     assert recipe_made.id not in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NULL") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 2 | ||||||
|  |     result_ids = {result.id for result in recipe_results} | ||||||
|  |     assert recipe_not_made_1.id in result_ids | ||||||
|  |     assert recipe_not_made_2.id in result_ids | ||||||
|  |     assert recipe_made.id not in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NONE") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     result_ids = {result.id for result in recipe_results} | ||||||
|  |     assert recipe_not_made_1.id not in result_ids | ||||||
|  |     assert recipe_not_made_2.id not in result_ids | ||||||
|  |     assert recipe_made.id in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="lastMade IS NOT NULL") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     result_ids = {result.id for result in recipe_results} | ||||||
|  |     assert recipe_not_made_1.id not in result_ids | ||||||
|  |     assert recipe_not_made_2.id not in result_ids | ||||||
|  |     assert recipe_made.id in result_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): | ||||||
|  |     units_repo, unit_1, unit_2, unit_3 = query_units | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]") | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 2 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id in result_ids | ||||||
|  |     assert unit_2.id in result_ids | ||||||
|  |     assert unit_3.id not in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]") | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 1 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id not in result_ids | ||||||
|  |     assert unit_2.id not in result_ids | ||||||
|  |     assert unit_3.id in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]') | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 1 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id not in result_ids | ||||||
|  |     assert unit_2.id not in result_ids | ||||||
|  |     assert unit_3.id in result_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_in_advanced(database: AllRepositories, unique_user: TestUser): | ||||||
|  |     slug1, slug2 = (random_string(10) for _ in range(2)) | ||||||
|  |  | ||||||
|  |     tags = [ | ||||||
|  |         TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1), | ||||||
|  |         TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     tag_1, tag_2 = [database.tags.create(tag) for tag in tags] | ||||||
|  |  | ||||||
|  |     # Bootstrap the database with recipes | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_0 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[]) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_1 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1]) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_2 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_2]) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_1_2 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug, tags=[tag_1, tag_2]) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}]") | ||||||
|  |     recipe_results = database.recipes.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 2 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_0.id not in recipe_ids | ||||||
|  |     assert recipe_1.id in recipe_ids | ||||||
|  |     assert recipe_2.id not in recipe_ids | ||||||
|  |     assert recipe_1_2.id in recipe_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name IN [{tag_1.name}, {tag_2.name}]") | ||||||
|  |     recipe_results = database.recipes.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 3 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_0.id not in recipe_ids | ||||||
|  |     assert recipe_1.id in recipe_ids | ||||||
|  |     assert recipe_2.id in recipe_ids | ||||||
|  |     assert recipe_1_2.id in recipe_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f"tags.name CONTAINS ALL [{tag_1.name}, {tag_2.name}]") | ||||||
|  |     recipe_results = database.recipes.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_0.id not in recipe_ids | ||||||
|  |     assert recipe_1.id not in recipe_ids | ||||||
|  |     assert recipe_2.id not in recipe_ids | ||||||
|  |     assert recipe_1_2.id in recipe_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_like(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): | ||||||
|  |     units_repo, unit_1, unit_2, unit_3 = query_units | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "test u_it%"') | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 3 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id in result_ids | ||||||
|  |     assert unit_2.id in result_ids | ||||||
|  |     assert unit_3.id in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=r'name LIKE "%unit 1"') | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 1 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id in result_ids | ||||||
|  |     assert unit_2.id not in result_ids | ||||||
|  |     assert unit_3.id not in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=r'name NOT LIKE %t_1"') | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 2 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id not in result_ids | ||||||
|  |     assert unit_2.id in result_ids | ||||||
|  |     assert unit_3.id in result_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_keyword_namespace_conflict(database: AllRepositories, unique_user: TestUser): | ||||||
|  |     recipe_rating_1 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=1) | ||||||
|  |     ) | ||||||
|  |     recipe_rating_2 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=2) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     recipe_rating_3 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string(), rating=3) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     recipe_repo = database.recipes.by_group(unique_user.group_id)  # type: ignore | ||||||
|  |  | ||||||
|  |     # "rating" contains the word "in", but we should not parse this as the keyword "IN" | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="rating > 2") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     result_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_rating_1.id not in result_ids | ||||||
|  |     assert recipe_rating_2.id not in result_ids | ||||||
|  |     assert recipe_rating_3.id in result_ids | ||||||
|  |  | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter="rating in [1, 3]") | ||||||
|  |     recipe_results = recipe_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(recipe_results) == 2 | ||||||
|  |     result_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_rating_1.id in result_ids | ||||||
|  |     assert recipe_rating_2.id not in result_ids | ||||||
|  |     assert recipe_rating_3.id in result_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_logical_namespace_conflict(database: AllRepositories, unique_user: TestUser): | ||||||
|  |     categories = [ | ||||||
|  |         CategorySave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |         CategorySave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |     ] | ||||||
|  |     category_1, category_2 = [database.categories.create(category) for category in categories] | ||||||
|  |  | ||||||
|  |     # Bootstrap the database with recipes | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_category_0 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_category_1 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_category_2 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_2], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # "recipeCategory" has the substring "or" in it, which shouldn't break queries | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=f'recipeCategory.id = "{category_1.id}"') | ||||||
|  |     recipe_results = database.recipes.by_group(unique_user.group_id).page_all(query).items  # type: ignore | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_category_0.id not in recipe_ids | ||||||
|  |     assert recipe_category_1.id in recipe_ids | ||||||
|  |     assert recipe_category_2.id not in recipe_ids | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_pagination_filter_datetimes( | def test_pagination_filter_datetimes( | ||||||
|     query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit] |     query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit] | ||||||
| ): | ): | ||||||
| @@ -197,15 +451,183 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien | |||||||
|  |  | ||||||
|  |  | ||||||
| def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): | def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): | ||||||
|     units_repo = query_units[0] |     units_repo, unit_1, unit_2, unit_3 = query_units | ||||||
|     unit_3 = query_units[3] |  | ||||||
|  |  | ||||||
|     dt = str(unit_3.created_at.isoformat())  # type: ignore |     dt = str(unit_3.created_at.isoformat())  # type: ignore | ||||||
|     qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="test unit 2" OR createdAt > "{dt}"))' |     qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="{unit_2.name}" OR createdAt > "{dt}"))' | ||||||
|     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|     unit_results = units_repo.page_all(query).items |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|     assert len(unit_results) == 2 |     assert len(unit_results) == 2 | ||||||
|     assert unit_3.id not in [unit.id for unit in unit_results] |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id in result_ids | ||||||
|  |     assert unit_2.id in result_ids | ||||||
|  |     assert unit_3.id not in result_ids | ||||||
|  |  | ||||||
|  |     qf = f'(name LIKE %_1 OR name IN ["{unit_2.name}"]) AND createdAt IS NOT NONE' | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|  |     unit_results = units_repo.page_all(query).items | ||||||
|  |  | ||||||
|  |     assert len(unit_results) == 2 | ||||||
|  |     result_ids = {unit.id for unit in unit_results} | ||||||
|  |     assert unit_1.id in result_ids | ||||||
|  |     assert unit_2.id in result_ids | ||||||
|  |     assert unit_3.id not in result_ids | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_pagination_filter_advanced_frontend_sort(database: AllRepositories, unique_user: TestUser): | ||||||
|  |     categories = [ | ||||||
|  |         CategorySave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |         CategorySave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |     ] | ||||||
|  |     category_1, category_2 = [database.categories.create(category) for category in categories] | ||||||
|  |  | ||||||
|  |     slug1, slug2 = (random_string(10) for _ in range(2)) | ||||||
|  |     tags = [ | ||||||
|  |         TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1), | ||||||
|  |         TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), | ||||||
|  |     ] | ||||||
|  |     tag_1, tag_2 = [database.tags.create(tag) for tag in tags] | ||||||
|  |  | ||||||
|  |     tools = [ | ||||||
|  |         RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |         RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), | ||||||
|  |     ] | ||||||
|  |     tool_1, tool_2 = [database.tools.create(tool) for tool in tools] | ||||||
|  |  | ||||||
|  |     # Bootstrap the database with recipes | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct0_tg0_tl0 = database.recipes.create( | ||||||
|  |         Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=slug, slug=slug) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct1_tg0_tl0 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct12_tg0_tl0 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1, category_2], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct1_tg1_tl0 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1], | ||||||
|  |             tags=[tag_1], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct1_tg0_tl1 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1], | ||||||
|  |             tools=[tool_1], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct0_tg2_tl2 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             tags=[tag_2], | ||||||
|  |             tools=[tool_2], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     slug = random_string() | ||||||
|  |     recipe_ct12_tg12_tl2 = database.recipes.create( | ||||||
|  |         Recipe( | ||||||
|  |             user_id=unique_user.user_id, | ||||||
|  |             group_id=unique_user.group_id, | ||||||
|  |             name=slug, | ||||||
|  |             slug=slug, | ||||||
|  |             recipe_category=[category_1, category_2], | ||||||
|  |             tags=[tag_1, tag_2], | ||||||
|  |             tools=[tool_2], | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     repo = database.recipes.by_group(unique_user.group_id)  # type: ignore | ||||||
|  |  | ||||||
|  |     qf = f'recipeCategory.id IN ["{category_1.id}"] AND tools.id IN ["{tool_1.id}"]' | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|  |     recipe_results = repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_ct0_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg1_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl1.id in recipe_ids | ||||||
|  |     assert recipe_ct0_tg2_tl2.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg12_tl2.id not in recipe_ids | ||||||
|  |  | ||||||
|  |     qf = f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"] AND tags.id IN ["{tag_1.id}"]' | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|  |     recipe_results = repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_ct0_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg1_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl1.id not in recipe_ids | ||||||
|  |     assert recipe_ct0_tg2_tl2.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg12_tl2.id in recipe_ids | ||||||
|  |  | ||||||
|  |     qf = f'tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_2.id}"]' | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|  |     recipe_results = repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 2 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_ct0_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg1_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl1.id not in recipe_ids | ||||||
|  |     assert recipe_ct0_tg2_tl2.id in recipe_ids | ||||||
|  |     assert recipe_ct12_tg12_tl2.id in recipe_ids | ||||||
|  |  | ||||||
|  |     qf = ( | ||||||
|  |         f'recipeCategory.id CONTAINS ALL ["{category_1.id}", "{category_2.id}"]' | ||||||
|  |         f'AND tags.id IN ["{tag_1.id}", "{tag_2.id}"] AND tools.id IN ["{tool_1.id}", "{tool_2.id}"]' | ||||||
|  |     ) | ||||||
|  |     query = PaginationQuery(page=1, per_page=-1, query_filter=qf) | ||||||
|  |     recipe_results = repo.page_all(query).items | ||||||
|  |     assert len(recipe_results) == 1 | ||||||
|  |     recipe_ids = {recipe.id for recipe in recipe_results} | ||||||
|  |     assert recipe_ct0_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg0_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg1_tl0.id not in recipe_ids | ||||||
|  |     assert recipe_ct1_tg0_tl1.id not in recipe_ids | ||||||
|  |     assert recipe_ct0_tg2_tl2.id not in recipe_ids | ||||||
|  |     assert recipe_ct12_tg12_tl2.id in recipe_ids | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
| @@ -214,6 +636,13 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien | |||||||
|         pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"), |         pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"), | ||||||
|         pytest.param('id="this is not a valid UUID"', id="invalid UUID"), |         pytest.param('id="this is not a valid UUID"', id="invalid UUID"), | ||||||
|         pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"), |         pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"), | ||||||
|  |         pytest.param('name IS "test name"', id="IS can only be used with NULL or NONE"), | ||||||
|  |         pytest.param('name IS NOT "test name"', id="IS NOT can only be used with NULL or NONE"), | ||||||
|  |         pytest.param('name IN "test name"', id="IN must use a list of values"), | ||||||
|  |         pytest.param('name NOT IN "test name"', id="NOT IN must use a list of values"), | ||||||
|  |         pytest.param('name CONTAINS ALL "test name"', id="CONTAINS ALL must use a list of values"), | ||||||
|  |         pytest.param('createdAt LIKE "2023-02-25"', id="LIKE is only valid for string columns"), | ||||||
|  |         pytest.param('createdAt NOT LIKE "2023-02-25"', id="NOT LIKE is only valid for string columns"), | ||||||
|         pytest.param('badAttribute="test value"', id="invalid attribute"), |         pytest.param('badAttribute="test value"', id="invalid attribute"), | ||||||
|         pytest.param('group.badAttribute="test value"', id="bad nested attribute"), |         pytest.param('group.badAttribute="test value"', id="bad nested attribute"), | ||||||
|         pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"), |         pytest.param('group.preferences.badAttribute="test value"', id="bad double nested attribute"), | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user