Add support for filtering on relationships.

This commit is contained in:
Yaro Kasear 2025-08-27 10:26:18 -05:00
parent 35068618c4
commit 07db7466cc

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from sqlalchemy import asc, desc, select, false()
from sqlalchemy import asc, desc, select, false
from sqlalchemy.inspection import inspect
@dataclass
class QuerySpec:
@ -23,6 +24,56 @@ FILTER_OPS = {
"__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None))
}
def _split_filter_key(raw_key: str):
for op in sorted(FILTER_OPS.keys(), key=len, reverse=True):
if raw_key.endswith(op):
return raw_key[: -len(op)], op
return raw_key, None
def _ensure_wildcards(op_key, value):
if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value:
return f"%{value}%"
return value
def _related_predicate(Model, path_parts, op_key, value):
"""
Build EXISTS subqueries for dotted filters:
- scalar rels -> attr.has(inner_predicate)
- collection -> attr.any(inner_predicate)
"""
head, *rest = path_parts
# class-bound relationship attribute (InstrumentedAttribute)
attr = getattr(Model, head, None)
if attr is None:
return None
# relationship metadata if you need uselist + target model
rel = inspect(Model).relationships.get(head)
if rel is None:
return None
Target = rel.mapper.class_
if not rest:
# filtering directly on a relationship without a leaf column isn't supported
return None
if len(rest) == 1:
# final hop is a column on the related model
leaf = rest[0]
col = getattr(Target, leaf, None)
if col is None:
return None
pred = FILTER_OPS[op_key](col, value) if op_key else (col == value)
else:
# recurse deeper: owner.room.area.name__ilike=...
pred = _related_predicate(Target, rest, op_key, value)
if pred is None:
return None
# wrap at this hop using the *attribute*, not the RelationshipProperty
return attr.any(pred) if rel.uselist else attr.has(pred)
def build_query(Model, spec: QuerySpec, eager_policy=None):
stmt = select(Model)
@ -37,22 +88,30 @@ def build_query(Model, spec: QuerySpec, eager_policy=None):
# filters
for raw_key, val in spec.filters.items():
for op in FILTER_OPS:
if raw_key.endswith(op):
colname = raw_key[: -len(op)]
col = getattr(Model, colname)
stmt = stmt.where(FILTER_OPS[op](col, val))
break
else:
stmt = stmt.where(getattr(Model, raw_key) == val)
path, op_key = _split_filter_key(raw_key)
val = _ensure_wildcards(op_key, val)
if "." in path:
pred = _related_predicate(Model, path.split("."), op_key, val)
if pred is not None:
stmt = stmt.where(pred)
continue
col = getattr(Model, path, None)
if col is None:
continue
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
# order_by
for key in spec.order_by:
desc_ = key.startswith("-")
col = getattr(Model, key[1:] if desc_ else key)
stmt = stmt.order_by(desc(col) if desc_ else asc(col))
# eager loading
if eager_policy:
opts = eager_policy(Model, spec.expand)
if opts:
stmt = stmt.options(*opts)
return stmt