Finally got that bug!
This commit is contained in:
parent
dc1d4111e2
commit
87cb686c64
2 changed files with 112 additions and 90 deletions
|
|
@ -4,16 +4,16 @@ from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||||
from sqlalchemy import and_, func, inspect, or_, text
|
from sqlalchemy import and_, func, inspect, or_, text, select, literal
|
||||||
from sqlalchemy.engine import Engine, Connection
|
from sqlalchemy.engine import Engine, Connection
|
||||||
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria
|
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from sqlalchemy.sql import operators, visitors
|
from sqlalchemy.sql import operators, visitors
|
||||||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||||||
|
|
||||||
from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
|
from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
|
||||||
from crudkit.core.base import Version
|
from crudkit.core.base import Version
|
||||||
from crudkit.core.spec import CRUDSpec
|
from crudkit.core.spec import CRUDSpec, CollPred
|
||||||
from crudkit.core.types import OrderSpec, SeekWindow
|
from crudkit.core.types import OrderSpec, SeekWindow
|
||||||
from crudkit.backend import BackendInfo, make_backend_info
|
from crudkit.backend import BackendInfo, make_backend_info
|
||||||
from crudkit.projection import compile_projection
|
from crudkit.projection import compile_projection
|
||||||
|
|
@ -516,107 +516,89 @@ class CRUDService(Generic[T]):
|
||||||
return [visitors.replacement_traverse(f, {}, replace) for f in filters]
|
return [visitors.replacement_traverse(f, {}, replace) for f in filters]
|
||||||
|
|
||||||
def _final_filters(self, root_alias, plan):
|
def _final_filters(self, root_alias, plan):
|
||||||
"""Return filters rebounded to our first-hop aliases, with first-hop collection
|
"""
|
||||||
predicates rewritten to EXISTS via rel.any(...)."""
|
Return filters where:
|
||||||
|
- root/to-one predicates are kept as SQLAlchemy expressions.
|
||||||
|
- first-hop collection predicates (CollPred) are rebuilt into a single
|
||||||
|
EXISTS via rel.any(...) with one alias per collection table.
|
||||||
|
"""
|
||||||
filters = list(plan.filters or [])
|
filters = list(plan.filters or [])
|
||||||
if not filters:
|
if not filters:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 1) Build alias map for first-hop targets we joined (to-one)
|
# 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls)
|
||||||
alias_map = {}
|
coll_map = {}
|
||||||
coll_map = {} # KEY CHANGE: table -> (rel_attr, target_cls)
|
|
||||||
for base_alias, rel_attr, target_alias in plan.join_paths:
|
for base_alias, rel_attr, target_alias in plan.join_paths:
|
||||||
if base_alias is not root_alias:
|
if base_alias is not root_alias:
|
||||||
continue
|
continue
|
||||||
prop = getattr(rel_attr, "property", None)
|
prop = getattr(rel_attr, "property", None)
|
||||||
if prop is None:
|
if not prop or not getattr(prop, "uselist", False):
|
||||||
continue
|
continue
|
||||||
|
target_cls = prop.mapper.class_
|
||||||
|
tbl = getattr(target_cls, "__table__", None)
|
||||||
|
if tbl is not None:
|
||||||
|
coll_map[tbl] = (rel_attr, target_cls)
|
||||||
|
|
||||||
# Try to capture a selectable for to-one rebinds (nice-to-have)
|
# 2) Split raw filters into normal SQLA and CollPreds (by target table)
|
||||||
sel = getattr(target_alias, "selectable", None)
|
normal_filters = []
|
||||||
if sel is not None:
|
by_table: dict[Any, list[CollPred]] = {}
|
||||||
alias_map[frozenset(_selectable_keys(sel))] = sel
|
|
||||||
|
|
||||||
# Always build a collection map keyed by the mapped table (no alias needed)
|
|
||||||
if bool(getattr(prop, "uselist", False)):
|
|
||||||
target_cls = prop.mapper.class_
|
|
||||||
tbl = getattr(target_cls, "__table__", None)
|
|
||||||
if tbl is not None:
|
|
||||||
coll_map[tbl] = (rel_attr, target_cls)
|
|
||||||
print(f"STAGE 1 - alias_map = {alias_map}, coll_map={coll_map}")
|
|
||||||
|
|
||||||
# 2) Rebind to-one columns to the exact alias objects we JOINed (if we have them)
|
|
||||||
if alias_map:
|
|
||||||
def _rebind(elem):
|
|
||||||
tbl = getattr(elem, "table", None)
|
|
||||||
if tbl is None:
|
|
||||||
return elem
|
|
||||||
keyset = frozenset(_selectable_keys(tbl))
|
|
||||||
new_sel = alias_map.get(keyset)
|
|
||||||
if new_sel is None or new_sel is tbl:
|
|
||||||
return elem
|
|
||||||
colkey = getattr(elem, "key", None) or getattr(elem, "name", None)
|
|
||||||
if not colkey:
|
|
||||||
return elem
|
|
||||||
try:
|
|
||||||
return getattr(new_sel.c, colkey)
|
|
||||||
except Exception:
|
|
||||||
return elem
|
|
||||||
filters = [visitors.replacement_traverse(f, {}, _rebind) for f in filters]
|
|
||||||
print(f"STAGE 2 - filters = {filters}")
|
|
||||||
|
|
||||||
# 3) If there are no collection filters, we’re done
|
|
||||||
if not coll_map:
|
|
||||||
print("STAGE 3 - No, I have determined there are no collections to handle like a bad girl.")
|
|
||||||
return filters
|
|
||||||
print("STAGE 3 - Yes, I have determined there are collections to handle like a good boy.")
|
|
||||||
|
|
||||||
# 4) Group any filters that reference a first-hop collection TABLE
|
|
||||||
keep = []
|
|
||||||
per_coll = {} # table -> [expr, ...]
|
|
||||||
for f in filters:
|
for f in filters:
|
||||||
touched_tbl = None
|
if isinstance(f, CollPred):
|
||||||
def _find(elem):
|
by_table.setdefault(f.table, []).append(f)
|
||||||
nonlocal touched_tbl
|
|
||||||
tbl = getattr(elem, "table", None)
|
|
||||||
if tbl is None:
|
|
||||||
return
|
|
||||||
# normalize alias -> base table
|
|
||||||
base_tbl = tbl
|
|
||||||
while getattr(base_tbl, "element", None) is not None:
|
|
||||||
base_tbl = getattr(base_tbl, "element")
|
|
||||||
if base_tbl in coll_map and touched_tbl is None:
|
|
||||||
touched_tbl = base_tbl
|
|
||||||
visitors.traverse(f, {}, {'column': _find})
|
|
||||||
|
|
||||||
if touched_tbl is None:
|
|
||||||
keep.append(f)
|
|
||||||
else:
|
else:
|
||||||
per_coll.setdefault(touched_tbl, []).append(f)
|
normal_filters.append(f)
|
||||||
print(f"STAGE 4 - keep = {keep}, per_coll = {per_coll}")
|
|
||||||
|
|
||||||
# 5) For each collection, remap columns to mapped class attrs and wrap with .any(and_(...))
|
# 3) Rebuild each table group into ONE .any(...) using one alias
|
||||||
for tbl, exprs in per_coll.items():
|
from sqlalchemy.orm import aliased
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
exists_filters = []
|
||||||
|
for tbl, preds in by_table.items():
|
||||||
|
if tbl not in coll_map:
|
||||||
|
# Safety: if it's not a first-hop collection, ignore or raise
|
||||||
|
continue
|
||||||
rel_attr, target_cls = coll_map[tbl]
|
rel_attr, target_cls = coll_map[tbl]
|
||||||
|
ta = aliased(target_cls)
|
||||||
|
|
||||||
def _to_model(elem):
|
built = []
|
||||||
etbl = getattr(elem, "table", None)
|
for p in preds:
|
||||||
if etbl is not None:
|
col = getattr(ta, p.col_key)
|
||||||
# normalize alias -> base table
|
op = p.op
|
||||||
etbl_base = etbl
|
val = p.value
|
||||||
while getattr(etbl_base, "element", None) is not None:
|
if op == 'icontains':
|
||||||
etbl_base = getattr(etbl_base, "element")
|
built.append(col.ilike(f"%{val}%"))
|
||||||
if etbl_base is tbl:
|
elif op == 'eq':
|
||||||
key = getattr(elem, "key", None) or getattr(elem, "name", None)
|
built.append(col == val)
|
||||||
if key and hasattr(target_cls, key):
|
elif op == 'ne':
|
||||||
return getattr(target_cls, key)
|
built.append(col != val)
|
||||||
return elem
|
elif op == 'in':
|
||||||
|
vs = val if isinstance(val, (list, tuple, set)) else [val]
|
||||||
|
built.append(col.in_(vs))
|
||||||
|
elif op == 'nin':
|
||||||
|
vs = val if isinstance(val, (list, tuple, set)) else [val]
|
||||||
|
built.append(~col.in_(vs))
|
||||||
|
elif op == 'lt':
|
||||||
|
built.append(col < val)
|
||||||
|
elif op == 'lte':
|
||||||
|
built.append(col <= val)
|
||||||
|
elif op == 'gt':
|
||||||
|
built.append(col > val)
|
||||||
|
elif op == 'gte':
|
||||||
|
built.append(col >= val)
|
||||||
|
else:
|
||||||
|
# unknown op — skip or raise
|
||||||
|
continue
|
||||||
|
|
||||||
remapped = [visitors.replacement_traverse(e, {}, _to_model) for e in exprs]
|
# enforce child soft delete inside the EXISTS
|
||||||
keep.append(rel_attr.any(and_(*remapped)))
|
if hasattr(target_cls, "is_deleted"):
|
||||||
print(f"STAGE 5 - keep={keep}")
|
built.append(ta.is_deleted == False)
|
||||||
|
|
||||||
return keep
|
crit = and_(*built) if built else None
|
||||||
|
exists_filters.append(rel_attr.of_type(ta).any(crit) if crit is not None
|
||||||
|
else rel_attr.of_type(ta).any())
|
||||||
|
|
||||||
|
# 4) Final filter list = normal SQLA filters + all EXISTS filters
|
||||||
|
return normal_filters + exists_filters
|
||||||
|
|
||||||
# ---- public read ops
|
# ---- public read ops
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,17 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
|
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
|
||||||
from sqlalchemy import and_, asc, desc, or_
|
from sqlalchemy import and_, asc, desc, or_
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import aliased, selectinload
|
from sqlalchemy.orm import aliased, selectinload
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CollPred:
|
||||||
|
table: Any
|
||||||
|
col_key: str
|
||||||
|
op: str
|
||||||
|
value: Any
|
||||||
|
|
||||||
OPERATORS = {
|
OPERATORS = {
|
||||||
'eq': lambda col, val: col == val,
|
'eq': lambda col, val: col == val,
|
||||||
'lt': lambda col, val: col < val,
|
'lt': lambda col, val: col < val,
|
||||||
|
|
@ -68,16 +76,48 @@ class CRUDSpec:
|
||||||
|
|
||||||
exprs = []
|
exprs = []
|
||||||
for col, join_path in pairs:
|
for col, join_path in pairs:
|
||||||
# Track eager path for each involved relationship chain
|
|
||||||
if join_path:
|
if join_path:
|
||||||
self.eager_paths.add(join_path)
|
self.eager_paths.add(join_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
cur_cls = self.model
|
||||||
|
names = list(join_path)
|
||||||
|
last_name = names[-1]
|
||||||
|
is_collection = False
|
||||||
|
for nm in names:
|
||||||
|
rel_attr = getattr(cur_cls, nm)
|
||||||
|
prop = rel_attr.property
|
||||||
|
cur_cls = prop.mapper.class_
|
||||||
|
is_collection = bool(getattr(getattr(self.model, last_name), "property", None)
|
||||||
|
and getattr(getattr(self.model, last_name).property, "uselist", False))
|
||||||
|
except Exception:
|
||||||
|
is_collection = False
|
||||||
|
|
||||||
|
if is_collection:
|
||||||
|
target_cls = cur_cls
|
||||||
|
key = getattr(col, "key", None) or getattr(col, "name", None)
|
||||||
|
if key and hasattr(target_cls, key):
|
||||||
|
target_tbl = getattr(target_cls, "__table__", None)
|
||||||
|
if target_tbl is not None:
|
||||||
|
exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
|
||||||
|
continue
|
||||||
|
|
||||||
exprs.append(OPERATORS[op](col, value))
|
exprs.append(OPERATORS[op](col, value))
|
||||||
|
|
||||||
if not exprs:
|
if not exprs:
|
||||||
return None
|
return None
|
||||||
if len(exprs) == 1:
|
|
||||||
|
# If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
|
||||||
|
if any(isinstance(x, CollPred) for x in exprs):
|
||||||
|
# If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds,
|
||||||
|
# keep the first or raise for now (your choice).
|
||||||
|
if len(exprs) > 1:
|
||||||
|
# raise NotImplementedError("OR across collection paths not supported yet")
|
||||||
|
exprs = [next(x for x in exprs if isinstance(x, CollPred))]
|
||||||
return exprs[0]
|
return exprs[0]
|
||||||
return or_(*exprs)
|
|
||||||
|
# Otherwise, standard SQLA clause(s)
|
||||||
|
return exprs[0] if len(exprs) == 1 else or_(*exprs)
|
||||||
|
|
||||||
def _collect_filters(self, params: dict) -> list:
|
def _collect_filters(self, params: dict) -> list:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue