From 2ad327fcd94006305afd7bfd83259a8fa76e6382 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Mon, 6 Oct 2025 14:36:08 -0500 Subject: [PATCH] Refactored the service to be less painful and redundant. --- crudkit/core/service.py | 793 +++++++++++++--------------------------- 1 file changed, 264 insertions(+), 529 deletions(-) diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 48e0e56..753d90f 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,15 +1,15 @@ from __future__ import annotations from collections.abc import Iterable +from dataclasses import dataclass from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression, ColumnElement -from sqlalchemy.sql.visitors import Visitable from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version @@ -38,80 +38,68 @@ class _SoftDeletable(Protocol): is_deleted: bool class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): - """Minimal surface that our CRUD service relies on. Soft-delete is optional.""" + """Surface expected by CRUDService.""" pass T = TypeVar("T", bound=_CRUDModelProto) +# ---------------------------- utilities ---------------------------- + def _collect_tables_from_filters(filters) -> set: - """ - Walk SQLAlchemy filter expressions and collect the Table/Alias objects - that appear, so we can detect which relationships are *actually used* - by filters and must be JOINed (not just selectinloaded). - """ + """Walk SQLA expressions to collect Table/Alias objects that appear in filters.""" seen = set() + def visit(node): if node is None: return - # record table / selectable if present tbl = getattr(node, "table", None) if tbl is not None: - # include the selectable and its base element (alias -> base table) cur = tbl while cur is not None: seen.add(cur) cur = getattr(cur, "element", None) - # generic children walker - try: - children = list(node.get_children()) - except Exception: - children = [] - for ch in children: - visit(ch) - # also inspect common attributes if present + for attr in ("get_children",): + fn = getattr(node, attr, None) + if fn: + for ch in fn(): + visit(ch) for attr in ("left", "right", "element", "clause", "clauses"): val = getattr(node, attr, None) if val is None: continue if isinstance(val, (list, tuple)): - for v in val: - visit(v) + for v in val: visit(v) else: visit(val) + for f in (filters or []): visit(f) return seen def _unwrap_ob(ob): - """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).""" - col = getattr(ob, "element", None) - if col is None: - col = ob - is_desc = False - dir_attr = getattr(ob, "_direction", None) - if dir_attr is not None: - is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob + + d = getattr(ob, "_direction", None) + if d is not None: + is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC") elif isinstance(ob, UnaryExpression): op = getattr(ob, "operator", None) is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + else: + is_desc = False return col, bool(is_desc) def _order_identity(col: ColumnElement): - """ - Build a stable identity for a column suitable for deduping. - We ignore direction here. Duplicates are duplicates regardless of ASC/DESC. - """ table = getattr(col, "table", None) table_key = getattr(table, "key", None) or id(table) col_key = getattr(col, "key", None) or getattr(col, "name", None) return (table_key, col_key) def _dedupe_order_by(order_by): - """Remove duplicate ORDER BY entries (by column identity, ignoring direction).""" if not order_by: return [] - seen = set() - out = [] + seen, out = set(), [] for ob in order_by: col, _ = _unwrap_ob(ob) ident = _order_identity(col) @@ -138,6 +126,8 @@ def _normalize_fields_param(params: dict | None) -> list[str]: return [p for p in (s.strip() for s in raw.split(",")) if p] return [] +# ---------------------------- CRUD service ---------------------------- + class CRUDService(Generic[T]): def __init__( self, @@ -151,21 +141,19 @@ class CRUDService(Generic[T]): self._session_factory = session_factory self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') - self._backend: Optional[BackendInfo] = backend + # ---- infra + @property def session(self) -> Session: - """Always return the Flask-scoped Session if available; otherwise the provided factory.""" try: - sess = current_app.extensions["crudkit"]["Session"] - return sess + return current_app.extensions["crudkit"]["Session"] except Exception: return self._session_factory() @property def backend(self) -> BackendInfo: - """Resolve backend info lazily against the active session's engine.""" if self._backend is None: bind = self.session.get_bind() eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) @@ -178,276 +166,13 @@ class CRUDService(Generic[T]): return self.session.query(poly), poly return self.session.query(self.model), self.model - def _debug_bind(self, where: str): - try: - bind = self.session.get_bind() - eng = getattr(bind, "engine", bind) - print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}") - except Exception as e: - print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}") + # ---- common building blocks - def _apply_not_deleted(self, query, root_alias, params) -> Any: + def _apply_not_deleted(self, query, root_alias, params): if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): return query.filter(getattr(root_alias, "is_deleted") == False) return query - def _extract_order_spec(self, root_alias, given_order_by): - """ - SQLAlchemy 2.x only: - Normalize order_by into (cols, desc_flags). Supports plain columns and - col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. - """ - - given = self._stable_order_by(root_alias, given_order_by) - - cols, desc_flags = [], [] - for ob in given: - # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() - elem = getattr(ob, "element", None) - col = elem if elem is not None else ob - - # Detect direction in SA 2.x - is_desc = False - dir_attr = getattr(ob, "_direction", None) - if dir_attr is not None: - is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") - elif isinstance(ob, UnaryExpression): - op = getattr(ob, "operator", None) - is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") - - cols.append(col) - desc_flags.append(bool(is_desc)) - - return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) - - def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): - if not key_vals: - return None - - conds = [] - for i, col in enumerate(spec.cols): - # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel: - # sent_col = func.coalesce(col, literal("-∞")) - sent_col = col - ties = [spec.cols[j] == key_vals[j] for j in range(i)] - is_desc = spec.desc[i] - if not backward: - op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) - else: - op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) - conds.append(and_(*ties, op)) - return or_(*conds) - - def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: - out = [] - for c in spec.cols: - # Only simple mapped columns supported for key pluck - key = getattr(c, "key", None) or getattr(c, "name", None) - if key is None or not hasattr(obj, key): - raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.") - out.append(getattr(obj, key)) - return out - - def seek_window( - self, - params: dict | None = None, - *, - key: list[Any] | None = None, - backward: bool = False, - include_total: bool = True, - ) -> "SeekWindow[T]": - """ - Keyset pagination with relationship-safe filtering/sorting. - Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek. - """ - session = self.session - query, root_alias = self.get_query() - - # Requested fields → projection + optional loaders - fields = _normalize_fields_param(params) - expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) - - spec = CRUDSpec(self.model, params or {}, root_alias) - - # Parse all inputs so join_paths are populated - filters = spec.parse_filters() - order_by = spec.parse_sort() - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() - spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) - - # Soft delete - query = self._apply_not_deleted(query, root_alias, params) - - # Which related tables are referenced by filters? - filter_tables = _collect_tables_from_filters(filters) - - # Root column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) - - # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor") - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - - # First-hop strategy: - # - If a non-collection relation's target table is used in filters -> plain JOIN (no explicit alias) - # - Otherwise -> selectinload (keeps row counts sane) - used_contains_eager = False # we purposely avoid contains_eager here - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: - # deeper hops handled separately (via proj_opts) - continue - prop = getattr(rel_attr, "property", None) - is_collection = bool(getattr(prop, "uselist", False)) - _is_nested_firsthop = rel_attr.key in nested_first_hops - - target_selectable = getattr(target_alias, "selectable", None) - target_base = getattr(target_selectable, "element", None) or target_selectable - needed_for_filter = (target_selectable in filter_tables) or (target_base in filter_tables) - - if needed_for_filter and not is_collection: - # Join via the relationship attribute so SA reuses the same FROM - # that filter expressions reference (avoids cartesian products). - query = query.join(rel_attr, isouter=True) - else: - opt = selectinload(rel_attr) - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) - - # Filters AFTER relations wiring → no cartesian products - if filters: - query = query.filter(*filters) - - # Apply deep projection loader options (only when we didn't use contains_eager) - if proj_opts and not used_contains_eager: - query = query.options(*proj_opts) - - # Order spec (with PK tie-breakers for stability) - order_spec = self._extract_order_spec(root_alias, order_by) - limit, _ = spec.parse_pagination() - if limit is None: - effective_limit = 50 - elif limit == 0: - effective_limit = None # unlimited - else: - effective_limit = limit - - # Seek predicate from cursor key (if any) - if key: - pred = self._key_predicate(order_spec, key, backward) - if pred is not None: - query = query.filter(pred) - - # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory. - if not backward: - clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] - query = query.order_by(*clauses) - if effective_limit is not None: - query = query.limit(effective_limit) - items = query.all() - else: - inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] - query = query.order_by(*inv_clauses) - if effective_limit is not None: - query = query.limit(effective_limit) - items = list(reversed(query.all())) - - # Projection meta tag for renderers - if fields: - proj = list(dict.fromkeys(fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - - if proj: - for obj in items: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass - - # Cursor key pluck: support related columns we hydrated via contains_eager - def _pluck_key_from_obj(obj: Any) -> list[Any]: - vals: list[Any] = [] - alias_to_rel: dict[Any, str] = {} - for _p, relationship_attr, target_alias in join_paths: - sel = getattr(target_alias, "selectable", None) - if sel is not None: - alias_to_rel[sel] = relationship_attr.key - - for col in order_spec.cols: - keyname = getattr(col, "key", None) or getattr(col, "name", None) - if keyname and hasattr(obj, keyname): - vals.append(getattr(obj, keyname)) - continue - table = getattr(col, "table", None) - relname = alias_to_rel.get(table) - if relname and keyname: - relobj = getattr(obj, relname, None) - if relobj is not None and hasattr(relobj, keyname): - vals.append(getattr(relobj, keyname)) - continue - raise ValueError("unpluckable") - return vals - - try: - first_key = _pluck_key_from_obj(items[0]) if items else None - last_key = _pluck_key_from_obj(items[-1]) if items else None - except Exception: - first_key = None - last_key = None - - # Count DISTINCT ids with mirrored joins - total = None - if include_total: - base = session.query(getattr(root_alias, "id")) - base = self._apply_not_deleted(base, root_alias, params) - # same joins as above for correctness - for base_alias, rel_attr, target_alias in join_paths: - # do not join collections for COUNT mirror - if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): - base = base.join(rel_attr, isouter=True) - if filters: - base = base.filter(*filters) - total = session.query(func.count()).select_from( - base.order_by(None).distinct().subquery() - ).scalar() or 0 - - window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50) - - if log.isEnabledFor(logging.DEBUG): - log.debug("QUERY: %s", str(query)) - - return SeekWindow( - items=items, - limit=window_limit_for_body, - first_key=first_key, - last_key=last_key, - order=order_spec, - total=total, - ) - - # Helper: default ORDER BY for MSSQL when paginating without explicit order def _default_order_by(self, root_alias): mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) cols = [] @@ -459,86 +184,114 @@ class CRUDService(Generic[T]): return cols or [text("1")] def _stable_order_by(self, root_alias, given_order_by): - """ - Ensure deterministic ordering by appending PK columns as tiebreakers. - If no order is provided, fall back to default primary-key order. - Never duplicates columns in ORDER BY (SQL Server requires uniqueness). - """ order_by = list(given_order_by or []) if not order_by: return _dedupe_order_by(self._default_order_by(root_alias)) - order_by = _dedupe_order_by(order_by) - mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by} - for pk in mapper.primary_key: try: pk_col = getattr(root_alias, pk.key) except AttributeError: pk_col = pk - if _order_identity(pk_col) not in present: + ident = _order_identity(pk_col) + if ident not in present: order_by.append(pk_col.asc()) - present.add(_order_identity(pk_col)) - + present.add(ident) return order_by - def get(self, id: int, params=None) -> T | None: - """ - Fetch a single row by id with conflict-free eager loading and clean projection. - Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so - related-column filters never create cartesian products. - """ - query, root_alias = self.get_query() + def _extract_order_spec(self, root_alias, given_order_by): + given = self._stable_order_by(root_alias, given_order_by) + cols, desc_flags = [], [] + for ob in given: + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob + is_desc = False + dir_attr = getattr(ob, "_direction", None) + if dir_attr is not None: + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elif isinstance(ob, UnaryExpression): + op = getattr(ob, "operator", None) + is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + cols.append(col) + desc_flags.append(bool(is_desc)) + return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) - # Defaults so we can build a projection even if params is None - req_fields: list[str] = _normalize_fields_param(params) - root_fields: list[Any] = [] - root_field_names: dict[str, str] = {} - rel_field_names: dict[tuple[str, ...], list[str]] = {} + def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): + if not key_vals: + return None + conds = [] + for i, col in enumerate(spec.cols): + ties = [spec.cols[j] == key_vals[j] for j in range(i)] + is_desc = spec.desc[i] + op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i]) + conds.append(and_(*ties, op)) + return or_(*conds) - # Soft-delete guard first - query = self._apply_not_deleted(query, root_alias, params) + # ---- planning and application + @dataclass(slots=True) + class _Plan: + spec: Any + filters: Any + order_by: Any + limit: Any + offset: Any + root_fields: Any + rel_field_names: Any + root_field_names: Any + collection_field_names: Any + join_paths: Any + filter_tables: Any + req_fields: Any + proj_opts: Any + + def _plan(self, params, root_alias) -> _Plan: + req_fields = _normalize_fields_param(params) spec = CRUDSpec(self.model, params or {}, root_alias) - # Parse everything so CRUDSpec records any join paths it needed to resolve filters = spec.parse_filters() - # no ORDER BY for get() - if params: - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() + order_by = spec.parse_sort() + limit, offset = spec.parse_pagination() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() if params else ([], {}, {}, {}) spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) + filter_tables = _collect_tables_from_filters(filters) + _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - # Root-column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) + return self._Plan( + spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset, + root_fields=root_fields, rel_field_names=rel_field_names, + root_field_names=root_field_names, collection_field_names=collection_field_names, + join_paths=join_paths, filter_tables=filter_tables, + req_fields=req_fields, proj_opts=proj_opts + ) - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + def _apply_projection_load_only(self, query, root_alias, plan: _Plan): + only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)] + return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query - # First-hop strategy (same as peek_window): join if filter needs it, else selectinload - use_contains_eager = False # we avoid contains_eager here, too - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: + def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan): + nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 } + for base_alias, rel_attr, target_alias in plan.join_paths: + if base_alias is not root_alias: continue prop = getattr(rel_attr, "property", None) is_collection = bool(getattr(prop, "uselist", False)) - _is_nested_firsthop = rel_attr.key in nested_first_hops - target_selectable = getattr(target_alias, "selectable", None) - target_base = getattr(target_selectable, "element", None) or target_selectable - needed_for_filter = target_selectable in _collect_tables_from_filters(filters) or (target_base in _collect_tables_from_filters(filters)) + sel = getattr(target_alias, "selectable", None) + sel_elem = getattr(sel, "element", None) + base_sel = sel_elem if sel_elem is not None else sel + + needed_for_filter = (sel in plan.filter_tables) or (base_sel in plan.filter_tables) if needed_for_filter and not is_collection: query = query.join(rel_attr, isouter=True) else: opt = selectinload(rel_attr) if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) + child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) if child_names: target_cls = prop.mapper.class_ cols = [getattr(target_cls, n, None) for n in child_names] @@ -546,186 +299,194 @@ class CRUDService(Generic[T]): if cols: opt = opt.load_only(*cols) query = query.options(opt) + return query - # Apply filters (joins are in place → no cartesian products) - if filters: - query = query.filter(*filters) + def _apply_proj_opts(self, query, plan: _Plan): + return query.options(*plan.proj_opts) if plan.proj_opts else query - # And the id filter - query = query.filter(getattr(root_alias, "id") == id) + def _projection_meta(self, plan: _Plan): + if plan.req_fields: + proj = list(dict.fromkeys(plan.req_fields)) + return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj - # Projection loader options compiled from requested fields. - # Skip if we used contains_eager to avoid loader-strategy conflicts. - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts and not use_contains_eager: - query = query.options(*proj_opts) + proj: list[str] = [] + if plan.root_field_names: + proj.extend(plan.root_field_names) + if plan.root_fields: + proj.extend(c.key for c in plan.root_fields if hasattr(c, "key")) + for path, names in (plan.rel_field_names or {}).items(): + prefix = ".".join(path) + for n in names: + proj.append(f"{prefix}.{n}") + if proj and "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") + return proj - obj = query.first() - - # Emit exactly what the client requested (plus id), or a reasonable fallback - if req_fields: - proj = list(dict.fromkeys(req_fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - - if proj and obj is not None: + def _tag_projection(self, items, proj): + if not proj: + return + for obj in items if isinstance(items, list) else [items]: try: setattr(obj, "__crudkit_projection__", tuple(proj)) except Exception: pass + # ---- public read ops + + def seek_window( + self, + params: dict | None = None, + *, + key: list[Any] | None = None, + backward: bool = False, + include_total: bool = True, + ) -> "SeekWindow[T]": + session = self.session + query, root_alias = self.get_query() + query = self._apply_not_deleted(query, root_alias, params) + + plan = self._plan(params, root_alias) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + if plan.filters: + query = query.filter(*plan.filters) + + order_spec = self._extract_order_spec(root_alias, plan.order_by) + limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit) + + if key: + pred = self._key_predicate(order_spec, key, backward) + if pred is not None: + query = query.filter(pred) + + clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)] + if backward: + clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)] + query = query.order_by(*clauses) + if limit is not None: + query = query.limit(limit) + + query = self._apply_proj_opts(query, plan) + rows = query.all() + items = list(reversed(rows)) if backward else rows + + proj = self._projection_meta(plan) + self._tag_projection(items, proj) + + # cursor keys + def pluck(obj): + vals = [] + alias_to_rel = {} + for _p, rel_attr, target_alias in plan.join_paths: + sel = getattr(target_alias, "selectable", None) + if sel is not None: + alias_to_rel[sel] = rel_attr.key + + for col in order_spec.cols: + keyname = getattr(col, "key", None) or getattr(col, "name", None) + if keyname and hasattr(obj, keyname): + vals.append(getattr(obj, keyname)); continue + table = getattr(col, "table", None) + relname = alias_to_rel.get(table) + if relname and keyname: + relobj = getattr(obj, relname, None) + if relobj is not None and hasattr(relobj, keyname): + vals.append(getattr(relobj, keyname)); continue + raise ValueError("unpluckable") + return vals + + try: + first_key = pluck(items[0]) if items else None + last_key = pluck(items[-1]) if items else None + except Exception: + first_key = last_key = None + + total = None + if include_total: + base = session.query(getattr(root_alias, "id")) + base = self._apply_not_deleted(base, root_alias, params) + for _b, rel_attr, target_alias in plan.join_paths: + if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): + base = base.join(rel_attr, isouter=True) + if plan.filters: + base = base.filter(*plan.filters) + total = session.query(func.count()).select_from( + base.order_by(None).distinct().subquery() + ).scalar() or 0 + if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) + window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50) + return SeekWindow( + items=items, + limit=window_limit_for_body, + first_key=first_key, + last_key=last_key, + order=order_spec, + total=total, + ) + + def get(self, id: int, params=None) -> T | None: + query, root_alias = self.get_query() + query = self._apply_not_deleted(query, root_alias, params) + + plan = self._plan(params, root_alias) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + if plan.filters: + query = query.filter(*plan.filters) + query = query.filter(getattr(root_alias, "id") == id) + query = self._apply_proj_opts(query, plan) + + obj = query.first() + proj = self._projection_meta(plan) + if obj: + self._tag_projection(obj, proj) + + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) return obj or None def list(self, params=None) -> list[T]: - """ - Offset/limit listing with relationship-safe filtering. - We always JOIN every CRUDSpec-discovered path before applying filters/sorts. - """ query, root_alias = self.get_query() - # Defaults so we can reference them later even if params is None - req_fields: list[str] = _normalize_fields_param(params) - root_fields: list[Any] = [] - root_field_names: dict[str, str] = {} - rel_field_names: dict[tuple[str, ...], list[str]] = {} + plan = self._plan(params, root_alias) + query = self._apply_not_deleted(query, root_alias, params) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + if plan.filters: + query = query.filter(*plan.filters) - if params: - # Soft delete - query = self._apply_not_deleted(query, root_alias, params) + order_by = plan.order_by + paginating = (plan.limit is not None) or (plan.offset not in (None, 0)) + if paginating and not order_by and self.backend.requires_order_by_for_offset: + order_by = self._default_order_by(root_alias) + order_by = _dedupe_order_by(order_by) + if order_by: + query = query.order_by(*order_by) - spec = CRUDSpec(self.model, params or {}, root_alias) - filters = spec.parse_filters() - order_by = spec.parse_sort() - limit, offset = spec.parse_pagination() - - # Includes / fields (populates join_paths) - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() - spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) - - # Root column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) - - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - - # First-hop strategy (same as seek_window): join if filter needs it, else selectinload - used_contains_eager = False # We avoid contains_eager here as well - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: - continue - prop = getattr(rel_attr, "property", None) - is_collection = bool(getattr(prop, "uselist", False)) - _is_nested_firsthop = rel_attr.key in nested_first_hops - - target_selectable = getattr(target_alias, "selectable", None) - target_base = getattr(target_selectable, "element", None) or target_selectable - ftables = _collect_tables_from_filters(filters) - needed_for_filter = (target_selectable in ftables) or (target_base in ftables) - - if needed_for_filter and not is_collection: - query = query.join(rel_attr, isouter=True) - else: - opt = selectinload(rel_attr) - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) - - # Filters AFTER joins → no cartesian products - if filters: - query = query.filter(*filters) - - # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers - paginating = (limit is not None) or (offset is not None and offset != 0) - if paginating and not order_by and self.backend.requires_order_by_for_offset: - order_by = self._default_order_by(root_alias) - - order_by = _dedupe_order_by(order_by) - - if order_by: - query = query.order_by(*order_by) - - # Offset/limit - if offset is not None and offset != 0: - query = query.offset(offset) - if limit is not None and limit > 0: - query = query.limit(limit) - - # Projection loaders only if we didn’t use contains_eager - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts and not used_contains_eager: - query = query.options(*proj_opts) - - else: - # No params; still honor projection loaders if any - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) + if plan.offset: query = query.offset(plan.offset) + if plan.limit and plan.limit > 0: query = query.limit(plan.limit) + query = self._apply_proj_opts(query, plan) rows = query.all() - # Build projection meta for renderers - if req_fields: - proj = list(dict.fromkeys(req_fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - - if proj: - for obj in rows: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass + proj = self._projection_meta(plan) + self._tag_projection(rows, proj) if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) - return rows + # ---- write ops + def create(self, data: dict, actor=None, *, commit: bool = True) -> T: session = self.session obj = self.model(**data) session.add(obj) - session.flush() - self._log_version("create", obj, actor, commit=commit) - if commit: session.commit() return obj @@ -737,59 +498,34 @@ class CRUDService(Generic[T]): raise ValueError(f"{self.model.__name__} with ID {id} not found.") before = obj.as_dict() - - # Normalize and restrict payload to real columns norm = normalize_payload(data, self.model) incoming = filter_to_columns(norm, self.model) - - # Build a synthetic "desired" state for top-level columns desired = {**before, **incoming} - # Compute intended change set (before vs intended) - proposed = deep_diff( - before, desired, - ignore_keys={"id", "created_at", "updated_at"}, - list_mode="index", - ) + proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index") patch = diff_to_patch(proposed) - - # Nothing to do if not patch: return obj - # Apply only what actually changes for k, v in patch.items(): setattr(obj, k, v) - # Optional: skip commit if ORM says no real change (paranoid check) - # Note: is_modified can lie if attrs are expired; use history for certainty. dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys()) if not dirty: return obj - # Commit atomically if commit: session.commit() - # AFTER snapshot for audit after = obj.as_dict() - - # Actual diff (captures triggers/defaults, still ignoring noisy keys) - actual = deep_diff( - before, after, - ignore_keys={"id", "created_at", "updated_at"}, - list_mode="index", - ) - - # If truly nothing changed post-commit (rare), skip version spam + actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index") if not (actual["added"] or actual["removed"] or actual["changed"]): return obj - # Log both what we *intended* and what *actually* happened self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit) return obj - def delete(self, id: int, hard: bool = False, actor = None, *, commit: bool = True): + def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True): session = self.session obj = session.get(self.model, id) if not obj: @@ -797,22 +533,21 @@ class CRUDService(Generic[T]): if hard or not self.supports_soft_delete: session.delete(obj) else: - soft = cast(_SoftDeletable, obj) - soft.is_deleted = True + cast(_SoftDeletable, obj).is_deleted = True if commit: session.commit() self._log_version("delete", obj, actor, commit=commit) return obj + # ---- audit + def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True): session = self.session try: - snapshot = {} try: snapshot = obj.as_dict() except Exception: snapshot = {"error": "serialize failed"} - version = Version( model_name=self.model.__name__, object_id=obj.id,