Add support for filtering on relationships.
This commit is contained in:
parent
35068618c4
commit
07db7466cc
1 changed files with 68 additions and 9 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
from sqlalchemy import asc, desc, select, false()
|
from sqlalchemy import asc, desc, select, false
|
||||||
|
from sqlalchemy.inspection import inspect
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QuerySpec:
|
class QuerySpec:
|
||||||
|
|
@ -23,6 +24,56 @@ FILTER_OPS = {
|
||||||
"__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None))
|
"__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _split_filter_key(raw_key: str):
|
||||||
|
for op in sorted(FILTER_OPS.keys(), key=len, reverse=True):
|
||||||
|
if raw_key.endswith(op):
|
||||||
|
return raw_key[: -len(op)], op
|
||||||
|
return raw_key, None
|
||||||
|
|
||||||
|
def _ensure_wildcards(op_key, value):
|
||||||
|
if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value:
|
||||||
|
return f"%{value}%"
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _related_predicate(Model, path_parts, op_key, value):
|
||||||
|
"""
|
||||||
|
Build EXISTS subqueries for dotted filters:
|
||||||
|
- scalar rels -> attr.has(inner_predicate)
|
||||||
|
- collection -> attr.any(inner_predicate)
|
||||||
|
"""
|
||||||
|
head, *rest = path_parts
|
||||||
|
|
||||||
|
# class-bound relationship attribute (InstrumentedAttribute)
|
||||||
|
attr = getattr(Model, head, None)
|
||||||
|
if attr is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# relationship metadata if you need uselist + target model
|
||||||
|
rel = inspect(Model).relationships.get(head)
|
||||||
|
if rel is None:
|
||||||
|
return None
|
||||||
|
Target = rel.mapper.class_
|
||||||
|
|
||||||
|
if not rest:
|
||||||
|
# filtering directly on a relationship without a leaf column isn't supported
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(rest) == 1:
|
||||||
|
# final hop is a column on the related model
|
||||||
|
leaf = rest[0]
|
||||||
|
col = getattr(Target, leaf, None)
|
||||||
|
if col is None:
|
||||||
|
return None
|
||||||
|
pred = FILTER_OPS[op_key](col, value) if op_key else (col == value)
|
||||||
|
else:
|
||||||
|
# recurse deeper: owner.room.area.name__ilike=...
|
||||||
|
pred = _related_predicate(Target, rest, op_key, value)
|
||||||
|
if pred is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# wrap at this hop using the *attribute*, not the RelationshipProperty
|
||||||
|
return attr.any(pred) if rel.uselist else attr.has(pred)
|
||||||
|
|
||||||
def build_query(Model, spec: QuerySpec, eager_policy=None):
|
def build_query(Model, spec: QuerySpec, eager_policy=None):
|
||||||
stmt = select(Model)
|
stmt = select(Model)
|
||||||
|
|
||||||
|
|
@ -37,22 +88,30 @@ def build_query(Model, spec: QuerySpec, eager_policy=None):
|
||||||
|
|
||||||
# filters
|
# filters
|
||||||
for raw_key, val in spec.filters.items():
|
for raw_key, val in spec.filters.items():
|
||||||
for op in FILTER_OPS:
|
path, op_key = _split_filter_key(raw_key)
|
||||||
if raw_key.endswith(op):
|
val = _ensure_wildcards(op_key, val)
|
||||||
colname = raw_key[: -len(op)]
|
|
||||||
col = getattr(Model, colname)
|
if "." in path:
|
||||||
stmt = stmt.where(FILTER_OPS[op](col, val))
|
pred = _related_predicate(Model, path.split("."), op_key, val)
|
||||||
break
|
if pred is not None:
|
||||||
else:
|
stmt = stmt.where(pred)
|
||||||
stmt = stmt.where(getattr(Model, raw_key) == val)
|
continue
|
||||||
|
|
||||||
|
col = getattr(Model, path, None)
|
||||||
|
if col is None:
|
||||||
|
continue
|
||||||
|
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
|
||||||
|
|
||||||
# order_by
|
# order_by
|
||||||
for key in spec.order_by:
|
for key in spec.order_by:
|
||||||
desc_ = key.startswith("-")
|
desc_ = key.startswith("-")
|
||||||
col = getattr(Model, key[1:] if desc_ else key)
|
col = getattr(Model, key[1:] if desc_ else key)
|
||||||
stmt = stmt.order_by(desc(col) if desc_ else asc(col))
|
stmt = stmt.order_by(desc(col) if desc_ else asc(col))
|
||||||
|
|
||||||
# eager loading
|
# eager loading
|
||||||
if eager_policy:
|
if eager_policy:
|
||||||
opts = eager_policy(Model, spec.expand)
|
opts = eager_policy(Model, spec.expand)
|
||||||
if opts:
|
if opts:
|
||||||
stmt = stmt.options(*opts)
|
stmt = stmt.options(*opts)
|
||||||
|
|
||||||
return stmt
|
return stmt
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue