147 lines
4.6 KiB
Python
147 lines
4.6 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import List, Dict, Any, Optional
|
|
from sqlalchemy import asc, desc, select, false
|
|
from sqlalchemy.inspection import inspect
|
|
|
|
@dataclass
class QuerySpec:
    """Declarative description of a list query.

    Attributes:
        filters: Filter keys mapped to values. A key is an attribute name,
            optionally followed by an operator suffix (``age__gte``), and may
            be a dotted relationship path (``owner.name__ilike``).
        order_by: Sort tokens such as ``name``, ``-name`` or ``name:desc``.
        page: Page index; ``None`` when no page was requested.
        per_page: Page size; ``None`` when no page size was requested.
        expand: Names passed to the eager-loading policy.
        fields: Optional field selection; ``None`` means "not specified".
            NOTE(review): not consumed by the query builder in this module —
            presumably used by a serializer; confirm.
    """

    # Mutable containers use default_factory so instances never share state.
    filters: Dict[str, Any] = field(default_factory=dict)
    order_by: List[str] = field(default_factory=list)
    page: Optional[int] = field(default=None)
    per_page: Optional[int] = field(default=None)
    expand: List[str] = field(default_factory=list)
    fields: Optional[List[str]] = field(default=None)
|
|
|
|
# Spellings treated as "false" when an ``__isnull`` flag arrives as a string
# (e.g. straight from a query string, where ``bool("false")`` is True).
_FALSY_FLAG_STRINGS = frozenset({"", "0", "false", "f", "no", "n", "off"})


def _parse_flag(value):
    """Interpret *value* as a boolean flag.

    Strings are matched case-insensitively (after stripping whitespace)
    against common "false" spellings; every other value uses normal
    truthiness, exactly as before.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSY_FLAG_STRINGS
    return bool(value)


def _split_in_values(value):
    """Return *value* as an ``IN`` list.

    A plain string is split on commas (whitespace around each item is
    stripped). SQLAlchemy's ``in_()`` rejects a bare string, so this only
    turns a former error into a usable filter. Any other value is passed
    through untouched.
    """
    if isinstance(value, str):
        return [part.strip() for part in value.split(",")]
    return value


# Operator suffix -> predicate factory taking (column, value).
# Keys are matched as suffixes of filter keys, e.g. "age__gte".
FILTER_OPS = {
    "__eq": lambda c, v: c == v,
    "__ne": lambda c, v: c != v,
    "__lt": lambda c, v: c < v,
    "__lte": lambda c, v: c <= v,
    "__gt": lambda c, v: c > v,
    "__gte": lambda c, v: c >= v,
    "__ilike": lambda c, v: c.ilike(v),
    "__in": lambda c, v: c.in_(_split_in_values(v)),
    # A truthy flag means "IS NULL"; "false"/"0"/... strings mean "IS NOT NULL".
    "__isnull": lambda c, v: (c.is_(None) if _parse_flag(v) else c.is_not(None)),
}
|
|
|
|
def _split_filter_key(raw_key: str):
    """Split a filter key into ``(field_path, operator_suffix)``.

    The operator is the longest key of ``FILTER_OPS`` that *raw_key* ends
    with, e.g. ``"age__gte"`` -> ``("age", "__gte")``. When no operator
    suffix matches, ``(raw_key, None)`` is returned.
    """
    # Try longer suffixes first so a longer operator wins over any shorter
    # operator that is also a suffix. The mapping is re-read on every call.
    by_length = sorted(FILTER_OPS, key=len, reverse=True)
    suffix = next((op for op in by_length if raw_key.endswith(op)), None)
    if suffix is None:
        return raw_key, None
    return raw_key[: -len(suffix)], suffix
|
|
|
|
def _ensure_wildcards(op_key, value):
|
|
if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value:
|
|
return f"%{value}%"
|
|
return value
|
|
|
|
def _related_predicate(Model, path_parts, op_key, value):
    """Build an EXISTS predicate for a dotted relationship filter.

    Each relationship hop is wrapped using the class-bound attribute:

    - scalar relationship -> ``attr.has(inner_predicate)``
    - collection          -> ``attr.any(inner_predicate)``

    The final path segment must be an attribute of the last related class.
    The predicate on it is ``FILTER_OPS[op_key](col, value)``, or
    ``col == value`` when *op_key* is falsy.

    Args:
        Model: Mapped class the path starts from.
        path_parts: Path segments, e.g. ``["owner", "room", "name"]``.
        op_key: Operator suffix from ``FILTER_OPS`` or ``None``.
        value: Comparison value.

    Returns:
        The wrapped predicate, or ``None`` when a segment is not a
        relationship/attribute or the path has no leaf attribute.

    Raises:
        ValueError: if *path_parts* is empty (tuple unpacking).
        KeyError: if *op_key* is truthy but not a ``FILTER_OPS`` key.
    """
    head, *tail = path_parts

    # Class-bound attribute (InstrumentedAttribute); it provides .any()/.has().
    rel_attr = getattr(Model, head, None)
    if rel_attr is None:
        return None

    # Mapper-level relationship metadata: target class and uselist flag.
    rel_prop = inspect(Model).relationships.get(head)
    if rel_prop is None:
        return None
    related_cls = rel_prop.mapper.class_

    # A bare relationship with no leaf attribute cannot be filtered on.
    if not tail:
        return None

    if len(tail) == 1:
        leaf_col = getattr(related_cls, tail[0], None)
        if leaf_col is None:
            return None
        inner = FILTER_OPS[op_key](leaf_col, value) if op_key else (leaf_col == value)
    else:
        # More hops remain, e.g. owner.room.area.name__ilike=...
        inner = _related_predicate(related_cls, tail, op_key, value)
        if inner is None:
            return None

    wrap = rel_attr.any if rel_prop.uselist else rel_attr.has
    return wrap(inner)
|
|
|
|
def split_sort_tokens(tokens):
    """Partition sort tokens into ``(simple, dotted)`` lists.

    Empty/``None`` tokens are dropped. A token is "dotted" when its field
    part contains a ``.``. The field part is the token with leading ``-``
    characters removed, cut at the first ``:``. Original token strings and
    their relative order are preserved in both lists. ``tokens`` may be
    ``None``.
    """

    def _is_dotted(token):
        field_part = token.lstrip("-").partition(":")[0]
        return "." in field_part

    present = [token for token in (tokens or []) if token]
    simple = [token for token in present if not _is_dotted(token)]
    dotted = [token for token in present if _is_dotted(token)]
    return simple, dotted
|
|
|
|
def _exclude_soft_deleted(Model, stmt):
    """Add ``<flag> == false()``, using ``deleted`` if present, else ``is_deleted``."""
    for flag_name in ("deleted", "is_deleted"):
        flag_col = getattr(Model, flag_name, None)
        if flag_col is not None:
            return stmt.where(flag_col == false())
    return stmt


def _apply_filters(Model, stmt, filters):
    """Add one WHERE clause per resolvable filter key; unknown keys are skipped."""
    for raw_key, raw_value in (filters or {}).items():
        path, op_key = _split_filter_key(raw_key)
        value = _ensure_wildcards(op_key, raw_value)

        if "." in path:
            # Dotted paths become EXISTS subqueries over relationships.
            pred = _related_predicate(Model, path.split("."), op_key, value)
        else:
            col = getattr(Model, path, None)
            if col is None:
                pred = None
            elif op_key:
                pred = FILTER_OPS[op_key](col, value)
            else:
                pred = col == value

        if pred is not None:
            stmt = stmt.where(pred)
    return stmt


def _parse_sort_token(token):
    """Parse ``name``, ``-name`` or ``name:asc|desc`` into ``(key, descending)``.

    A ``:suffix`` overrides a leading ``-``; any suffix starting with "d"
    (case-insensitive) means descending, anything else ascending.
    """
    descending = token.startswith("-")
    key = token[1:] if descending else token
    if ":" in key:
        key, suffix = key.rsplit(":", 1)
        descending = suffix.lower().startswith("d")
    return key, descending


def _apply_simple_sorts(Model, stmt, tokens):
    """Apply non-dotted sort tokens; return ``(stmt, applied_any)``."""
    applied_any = False
    for token in tokens:
        key, descending = _parse_sort_token(token)
        if "." in key:
            continue
        col = getattr(Model, key, None)
        if col is None:
            continue
        stmt = stmt.order_by(desc(col) if descending else asc(col))
        applied_any = True
    return stmt, applied_any


def build_query(Model, spec: QuerySpec, eager_policy=None):
    """Build a ``select(Model)`` statement from *spec*.

    Steps, in order:

    1. Exclude soft-deleted rows: ``deleted == false()`` if the model has a
       ``deleted`` attribute, otherwise ``is_deleted == false()`` if present.
    2. Apply ``spec.filters`` (keys like ``age__gte`` or dotted
       ``owner.name__ilike``). Keys naming no attribute are silently ignored.
    3. Apply the non-dotted tokens of ``spec.order_by``. Dotted sort tokens
       are not applied here (presumably the caller handles them — confirm).
    4. If a page was requested, no ORDER BY was applied in step 3, and no
       dotted sorts are pending, order by the primary key ascending.
    5. Add the loader options returned by ``eager_policy(Model, spec.expand)``.

    This function does not apply LIMIT/OFFSET itself.

    Args:
        Model: Mapped SQLAlchemy class to select from.
        spec: The query description.
        eager_policy: Optional callable ``(Model, expand) -> options``.

    Returns:
        The composed ``Select`` statement.
    """
    stmt = _exclude_soft_deleted(Model, select(Model))
    stmt = _apply_filters(Model, stmt, spec.filters)

    simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
    stmt, ordered = _apply_simple_sorts(Model, stmt, simple_sorts)

    # Without an ORDER BY the database may return rows in any order, so
    # paged results could overlap or skip rows. Key the fallback on whether
    # an ordering was actually applied (invalid tokens apply none) and test
    # `page is not None` so page 0 is handled like any other page. Skip it
    # when dotted sorts are pending: an ORDER BY added here would take
    # precedence over ones appended to the statement later.
    paginating = spec.page is not None and spec.per_page is not None
    if paginating and not ordered and not dotted_sorts:
        pk_cols = inspect(Model).primary_key
        if pk_cols:
            stmt = stmt.order_by(*(asc(c) for c in pk_cols))

    # eager loading
    if eager_policy:
        opts = eager_policy(Model, spec.expand)
        if opts:
            stmt = stmt.options(*opts)

    return stmt
|