Add support for filtering on relationships.

This commit is contained in:
Yaro Kasear 2025-08-27 10:26:18 -05:00
parent 35068618c4
commit 07db7466cc

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from sqlalchemy import asc, desc, select, false() from sqlalchemy import asc, desc, select, false
from sqlalchemy.inspection import inspect
@dataclass @dataclass
class QuerySpec: class QuerySpec:
@ -23,6 +24,56 @@ FILTER_OPS = {
"__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None)) "__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None))
} }
def _split_filter_key(raw_key: str):
for op in sorted(FILTER_OPS.keys(), key=len, reverse=True):
if raw_key.endswith(op):
return raw_key[: -len(op)], op
return raw_key, None
def _ensure_wildcards(op_key, value):
if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value:
return f"%{value}%"
return value
def _related_predicate(Model, path_parts, op_key, value):
"""
Build EXISTS subqueries for dotted filters:
- scalar rels -> attr.has(inner_predicate)
- collection -> attr.any(inner_predicate)
"""
head, *rest = path_parts
# class-bound relationship attribute (InstrumentedAttribute)
attr = getattr(Model, head, None)
if attr is None:
return None
# relationship metadata if you need uselist + target model
rel = inspect(Model).relationships.get(head)
if rel is None:
return None
Target = rel.mapper.class_
if not rest:
# filtering directly on a relationship without a leaf column isn't supported
return None
if len(rest) == 1:
# final hop is a column on the related model
leaf = rest[0]
col = getattr(Target, leaf, None)
if col is None:
return None
pred = FILTER_OPS[op_key](col, value) if op_key else (col == value)
else:
# recurse deeper: owner.room.area.name__ilike=...
pred = _related_predicate(Target, rest, op_key, value)
if pred is None:
return None
# wrap at this hop using the *attribute*, not the RelationshipProperty
return attr.any(pred) if rel.uselist else attr.has(pred)
def build_query(Model, spec: QuerySpec, eager_policy=None): def build_query(Model, spec: QuerySpec, eager_policy=None):
stmt = select(Model) stmt = select(Model)
@ -37,22 +88,30 @@ def build_query(Model, spec: QuerySpec, eager_policy=None):
# filters # filters
for raw_key, val in spec.filters.items(): for raw_key, val in spec.filters.items():
for op in FILTER_OPS: path, op_key = _split_filter_key(raw_key)
if raw_key.endswith(op): val = _ensure_wildcards(op_key, val)
colname = raw_key[: -len(op)]
col = getattr(Model, colname) if "." in path:
stmt = stmt.where(FILTER_OPS[op](col, val)) pred = _related_predicate(Model, path.split("."), op_key, val)
break if pred is not None:
else: stmt = stmt.where(pred)
stmt = stmt.where(getattr(Model, raw_key) == val) continue
col = getattr(Model, path, None)
if col is None:
continue
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
# order_by # order_by
for key in spec.order_by: for key in spec.order_by:
desc_ = key.startswith("-") desc_ = key.startswith("-")
col = getattr(Model, key[1:] if desc_ else key) col = getattr(Model, key[1:] if desc_ else key)
stmt = stmt.order_by(desc(col) if desc_ else asc(col)) stmt = stmt.order_by(desc(col) if desc_ else asc(col))
# eager loading # eager loading
if eager_policy: if eager_policy:
opts = eager_policy(Model, spec.expand) opts = eager_policy(Model, spec.expand)
if opts: if opts:
stmt = stmt.options(*opts) stmt = stmt.options(*opts)
return stmt return stmt