From 87cb686c64637d5c83a87d21e700092098954489 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Fri, 10 Oct 2025 09:01:45 -0500 Subject: [PATCH] Finally got that bug! --- crudkit/core/service.py | 156 ++++++++++++++++++---------------------- crudkit/core/spec.py | 46 +++++++++++- 2 files changed, 112 insertions(+), 90 deletions(-) diff --git a/crudkit/core/service.py b/crudkit/core/service.py index c4a863f..fbad3a1 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -4,16 +4,16 @@ from collections.abc import Iterable from dataclasses import dataclass from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text +from sqlalchemy import and_, func, inspect, or_, text, select, literal from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators, visitors from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version -from crudkit.core.spec import CRUDSpec +from crudkit.core.spec import CRUDSpec, CollPred from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info from crudkit.projection import compile_projection @@ -516,107 +516,89 @@ class CRUDService(Generic[T]): return [visitors.replacement_traverse(f, {}, replace) for f in filters] def _final_filters(self, root_alias, plan): - """Return filters rebounded to our first-hop aliases, with first-hop collection - predicates rewritten to EXISTS via rel.any(...).""" + """ + Return filters where: + - root/to-one predicates are kept as SQLAlchemy expressions. + - first-hop collection predicates (CollPred) are rebuilt into a single + EXISTS via rel.any(...) with one alias per collection table. + """ filters = list(plan.filters or []) if not filters: return [] - # 1) Build alias map for first-hop targets we joined (to-one) - alias_map = {} - coll_map = {} # KEY CHANGE: table -> (rel_attr, target_cls) + # 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls) + coll_map = {} for base_alias, rel_attr, target_alias in plan.join_paths: if base_alias is not root_alias: continue prop = getattr(rel_attr, "property", None) - if prop is None: + if not prop or not getattr(prop, "uselist", False): continue + target_cls = prop.mapper.class_ + tbl = getattr(target_cls, "__table__", None) + if tbl is not None: + coll_map[tbl] = (rel_attr, target_cls) - # Try to capture a selectable for to-one rebinds (nice-to-have) - sel = getattr(target_alias, "selectable", None) - if sel is not None: - alias_map[frozenset(_selectable_keys(sel))] = sel - - # Always build a collection map keyed by the mapped table (no alias needed) - if bool(getattr(prop, "uselist", False)): - target_cls = prop.mapper.class_ - tbl = getattr(target_cls, "__table__", None) - if tbl is not None: - coll_map[tbl] = (rel_attr, target_cls) - print(f"STAGE 1 - alias_map = {alias_map}, coll_map={coll_map}") - - # 2) Rebind to-one columns to the exact alias objects we JOINed (if we have them) - if alias_map: - def _rebind(elem): - tbl = getattr(elem, "table", None) - if tbl is None: - return elem - keyset = frozenset(_selectable_keys(tbl)) - new_sel = alias_map.get(keyset) - if new_sel is None or new_sel is tbl: - return elem - colkey = getattr(elem, "key", None) or getattr(elem, "name", None) - if not colkey: - return elem - try: - return getattr(new_sel.c, colkey) - except Exception: - return elem - filters = [visitors.replacement_traverse(f, {}, _rebind) for f in filters] - print(f"STAGE 2 - filters = {filters}") - - # 3) If there are no collection filters, we’re done - if not coll_map: - print("STAGE 3 - No, I have determined there are no collections to handle like a bad girl.") - return filters - print("STAGE 3 - Yes, I have determined there are collections to handle like a good boy.") - - # 4) Group any filters that reference a first-hop collection TABLE - keep = [] - per_coll = {} # table -> [expr, ...] + # 2) Split raw filters into normal SQLA and CollPreds (by target table) + normal_filters = [] + by_table: dict[Any, list[CollPred]] = {} for f in filters: - touched_tbl = None - def _find(elem): - nonlocal touched_tbl - tbl = getattr(elem, "table", None) - if tbl is None: - return - # normalize alias -> base table - base_tbl = tbl - while getattr(base_tbl, "element", None) is not None: - base_tbl = getattr(base_tbl, "element") - if base_tbl in coll_map and touched_tbl is None: - touched_tbl = base_tbl - visitors.traverse(f, {}, {'column': _find}) - - if touched_tbl is None: - keep.append(f) + if isinstance(f, CollPred): + by_table.setdefault(f.table, []).append(f) else: - per_coll.setdefault(touched_tbl, []).append(f) - print(f"STAGE 4 - keep = {keep}, per_coll = {per_coll}") + normal_filters.append(f) - # 5) For each collection, remap columns to mapped class attrs and wrap with .any(and_(...)) - for tbl, exprs in per_coll.items(): + # 3) Rebuild each table group into ONE .any(...) using one alias + from sqlalchemy.orm import aliased + from sqlalchemy import and_ + + exists_filters = [] + for tbl, preds in by_table.items(): + if tbl not in coll_map: + # Safety: if it's not a first-hop collection, ignore or raise + continue rel_attr, target_cls = coll_map[tbl] + ta = aliased(target_cls) - def _to_model(elem): - etbl = getattr(elem, "table", None) - if etbl is not None: - # normalize alias -> base table - etbl_base = etbl - while getattr(etbl_base, "element", None) is not None: - etbl_base = getattr(etbl_base, "element") - if etbl_base is tbl: - key = getattr(elem, "key", None) or getattr(elem, "name", None) - if key and hasattr(target_cls, key): - return getattr(target_cls, key) - return elem + built = [] + for p in preds: + col = getattr(ta, p.col_key) + op = p.op + val = p.value + if op == 'icontains': + built.append(col.ilike(f"%{val}%")) + elif op == 'eq': + built.append(col == val) + elif op == 'ne': + built.append(col != val) + elif op == 'in': + vs = val if isinstance(val, (list, tuple, set)) else [val] + built.append(col.in_(vs)) + elif op == 'nin': + vs = val if isinstance(val, (list, tuple, set)) else [val] + built.append(~col.in_(vs)) + elif op == 'lt': + built.append(col < val) + elif op == 'lte': + built.append(col <= val) + elif op == 'gt': + built.append(col > val) + elif op == 'gte': + built.append(col >= val) + else: + # unknown op — skip or raise + continue - remapped = [visitors.replacement_traverse(e, {}, _to_model) for e in exprs] - keep.append(rel_attr.any(and_(*remapped))) - print(f"STAGE 5 - keep={keep}") + # enforce child soft delete inside the EXISTS + if hasattr(target_cls, "is_deleted"): + built.append(ta.is_deleted == False) - return keep + crit = and_(*built) if built else None + exists_filters.append(rel_attr.of_type(ta).any(crit) if crit is not None + else rel_attr.of_type(ta).any()) + + # 4) Final filter list = normal SQLA filters + all EXISTS filters + return normal_filters + exists_filters # ---- public read ops diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 4ec972f..bfd5f11 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,9 +1,17 @@ +from dataclasses import dataclass from typing import Any, List, Tuple, Set, Dict, Optional, Iterable from sqlalchemy import and_, asc, desc, or_ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute +@dataclass(frozen=True) +class CollPred: + table: Any + col_key: str + op: str + value: Any + OPERATORS = { 'eq': lambda col, val: col == val, 'lt': lambda col, val: col < val, @@ -68,16 +76,48 @@ class CRUDSpec: exprs = [] for col, join_path in pairs: - # Track eager path for each involved relationship chain if join_path: self.eager_paths.add(join_path) + + try: + cur_cls = self.model + names = list(join_path) + last_name = names[-1] + is_collection = False + for nm in names: + rel_attr = getattr(cur_cls, nm) + prop = rel_attr.property + cur_cls = prop.mapper.class_ + is_collection = bool(getattr(getattr(self.model, last_name), "property", None) + and getattr(getattr(self.model, last_name).property, "uselist", False)) + except Exception: + is_collection = False + + if is_collection: + target_cls = cur_cls + key = getattr(col, "key", None) or getattr(col, "name", None) + if key and hasattr(target_cls, key): + target_tbl = getattr(target_cls, "__table__", None) + if target_tbl is not None: + exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value)) + continue + exprs.append(OPERATORS[op](col, value)) if not exprs: return None - if len(exprs) == 1: + + # If any CollPred is in exprs, do NOT or_ them. Keep it single for now. + if any(isinstance(x, CollPred) for x in exprs): + # If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds, + # keep the first or raise for now (your choice). + if len(exprs) > 1: + # raise NotImplementedError("OR across collection paths not supported yet") + exprs = [next(x for x in exprs if isinstance(x, CollPred))] return exprs[0] - return or_(*exprs) + + # Otherwise, standard SQLA clause(s) + return exprs[0] if len(exprs) == 1 else or_(*exprs) def _collect_filters(self, params: dict) -> list: """