diff --git a/crudkit/core/service.py b/crudkit/core/service.py
index 753d90f..60e3aef 100644
--- a/crudkit/core/service.py
+++ b/crudkit/core/service.py
@@ -1,12 +1,11 @@
 from __future__ import annotations
 from collections.abc import Iterable
-from dataclasses import dataclass
 from flask import current_app
 from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
 from sqlalchemy import and_, func, inspect, or_, text
 from sqlalchemy.engine import Engine, Connection
-from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload
+from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
 from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
@@ -38,68 +37,41 @@ class _SoftDeletable(Protocol):
     is_deleted: bool
 
 class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
-    """Surface expected by CRUDService."""
+    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
     pass
 
 T = TypeVar("T", bound=_CRUDModelProto)
 
-# ---------------------------- utilities ----------------------------
-
-def _collect_tables_from_filters(filters) -> set:
-    """Walk SQLA expressions to collect Table/Alias objects that appear in filters."""
-    seen = set()
-
-    def visit(node):
-        if node is None:
-            return
-        tbl = getattr(node, "table", None)
-        if tbl is not None:
-            cur = tbl
-            while cur is not None:
-                seen.add(cur)
-                cur = getattr(cur, "element", None)
-        for attr in ("get_children",):
-            fn = getattr(node, attr, None)
-            if fn:
-                for ch in fn():
-                    visit(ch)
-        for attr in ("left", "right", "element", "clause", "clauses"):
-            val = getattr(node, attr, None)
-            if val is None:
-                continue
-            if isinstance(val, (list, tuple)):
-                for v in val: visit(v)
-            else:
-                visit(val)
-
-    for f in (filters or []):
-        visit(f)
-    return seen
-
 def _unwrap_ob(ob):
-    elem = getattr(ob, "element", None)
-    col = elem if elem is not None else ob
-
-    d = getattr(ob, "_direction", None)
-    if d is not None:
-        is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC")
+    """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc())."""
+    col = getattr(ob, "element", None)
+    if col is None:
+        col = ob
+    is_desc = False
+    dir_attr = getattr(ob, "_direction", None)
+    if dir_attr is not None:
+        is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
     elif isinstance(ob, UnaryExpression):
         op = getattr(ob, "operator", None)
         is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
-    else:
-        is_desc = False
     return col, bool(is_desc)
 
 def _order_identity(col: ColumnElement):
+    """
+    Build a stable identity for a column suitable for deduping.
+    We ignore direction here. Duplicates are duplicates regardless of ASC/DESC.
+ """ table = getattr(col, "table", None) table_key = getattr(table, "key", None) or id(table) col_key = getattr(col, "key", None) or getattr(col, "name", None) return (table_key, col_key) def _dedupe_order_by(order_by): + """Remove duplicate ORDER BY entries (by column identity, ignoring direction).""" if not order_by: return [] - seen, out = set(), [] + seen = set() + out = [] for ob in order_by: col, _ = _unwrap_ob(ob) ident = _order_identity(col) @@ -126,8 +98,6 @@ def _normalize_fields_param(params: dict | None) -> list[str]: return [p for p in (s.strip() for s in raw.split(",")) if p] return [] -# ---------------------------- CRUD service ---------------------------- - class CRUDService(Generic[T]): def __init__( self, @@ -141,19 +111,21 @@ class CRUDService(Generic[T]): self._session_factory = session_factory self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') - self._backend: Optional[BackendInfo] = backend - # ---- infra + self._backend: Optional[BackendInfo] = backend @property def session(self) -> Session: + """Always return the Flask-scoped Session if available; otherwise the provided factory.""" try: - return current_app.extensions["crudkit"]["Session"] + sess = current_app.extensions["crudkit"]["Session"] + return sess except Exception: return self._session_factory() @property def backend(self) -> BackendInfo: + """Resolve backend info lazily against the active session's engine.""" if self._backend is None: bind = self.session.get_bind() eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) @@ -166,13 +138,265 @@ class CRUDService(Generic[T]): return self.session.query(poly), poly return self.session.query(self.model), self.model - # ---- common building blocks + def _debug_bind(self, where: str): + try: + bind = self.session.get_bind() + eng = getattr(bind, "engine", bind) + print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}") + except Exception as e: + print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}") - def _apply_not_deleted(self, query, root_alias, params): + def _apply_not_deleted(self, query, root_alias, params) -> Any: if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): return query.filter(getattr(root_alias, "is_deleted") == False) return query + def _extract_order_spec(self, root_alias, given_order_by): + """ + SQLAlchemy 2.x only: + Normalize order_by into (cols, desc_flags). Supports plain columns and + col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. 
+ """ + + given = self._stable_order_by(root_alias, given_order_by) + + cols, desc_flags = [], [] + for ob in given: + # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob + + # Detect direction in SA 2.x + is_desc = False + dir_attr = getattr(ob, "_direction", None) + if dir_attr is not None: + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elif isinstance(ob, UnaryExpression): + op = getattr(ob, "operator", None) + is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + + cols.append(col) + desc_flags.append(bool(is_desc)) + + return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) + + def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): + if not key_vals: + return None + + conds = [] + for i, col in enumerate(spec.cols): + # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel: + # sent_col = func.coalesce(col, literal("-∞")) + sent_col = col + ties = [spec.cols[j] == key_vals[j] for j in range(i)] + is_desc = spec.desc[i] + if not backward: + op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) + else: + op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) + conds.append(and_(*ties, op)) + return or_(*conds) + + def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: + out = [] + for c in spec.cols: + # Only simple mapped columns supported for key pluck + key = getattr(c, "key", None) or getattr(c, "name", None) + if key is None or not hasattr(obj, key): + raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.") + out.append(getattr(obj, key)) + return out + + def seek_window( + self, + params: dict | None = None, + *, + key: list[Any] | None = None, + backward: bool = False, + include_total: bool = True, + ) -> "SeekWindow[T]": + """ + Keyset pagination with relationship-safe filtering/sorting. + Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek. + """ + session = self.session + query, root_alias = self.get_query() + + # Requested fields → projection + optional loaders + fields = _normalize_fields_param(params) + expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) + + spec = CRUDSpec(self.model, params or {}, root_alias) + + # Parse all inputs so join_paths are populated + filters = spec.parse_filters() + order_by = spec.parse_sort() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() + spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) + + # Soft delete + query = self._apply_not_deleted(query, root_alias, params) + + # Root column projection (load_only) + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) + + # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor") + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + + # IMPORTANT: + # - Only attach loader options for first-hop relations from the root. + # - Always use selectinload here (avoid contains_eager joins). + # - Let compile_projections() supply deep chained options. 
+        for base_alias, rel_attr, target_alias in join_paths:
+            is_firsthop_from_root = (base_alias is root_alias)
+            if not is_firsthop_from_root:
+                # Deeper hops are handled by proj_opts below
+                continue
+            prop = getattr(rel_attr, "property", None)
+            is_collection = bool(getattr(prop, "uselist", False))
+            is_nested_firsthop = rel_attr.key in nested_first_hops
+
+            opt = selectinload(rel_attr)
+            # Optional narrowing for collections
+            if is_collection:
+                child_names = (collection_field_names or {}).get(rel_attr.key, [])
+                if child_names:
+                    target_cls = prop.mapper.class_
+                    cols = [getattr(target_cls, n, None) for n in child_names]
+                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
+                    if cols:
+                        opt = opt.load_only(*cols)
+            query = query.options(opt)
+
+        # Apply deep projection loader options before executing (safe: we avoided contains_eager)
+        if proj_opts:
+            query = query.options(*proj_opts)
+
+        # Filters AFTER joins → no cartesian products
+        if filters:
+            query = query.filter(*filters)
+
+        # Order spec (with PK tie-breakers for stability)
+        order_spec = self._extract_order_spec(root_alias, order_by)
+        limit, _ = spec.parse_pagination()
+        if limit is None:
+            effective_limit = 50
+        elif limit == 0:
+            effective_limit = None  # unlimited
+        else:
+            effective_limit = limit
+
+        # Seek predicate from cursor key (if any)
+        if key:
+            pred = self._key_predicate(order_spec, key, backward)
+            if pred is not None:
+                query = query.filter(pred)
+
+        # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
+        if not backward:
+            clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
+            query = query.order_by(*clauses)
+            if effective_limit is not None:
+                query = query.limit(effective_limit)
+            items = query.all()
+        else:
+            inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
+            query = query.order_by(*inv_clauses)
+            if effective_limit is not None:
+                query = query.limit(effective_limit)
+            items = list(reversed(query.all()))
+
+        # Projection meta tag for renderers
+        if fields:
+            proj = list(dict.fromkeys(fields))
+            if "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+        else:
+            proj = []
+            if root_field_names:
+                proj.extend(root_field_names)
+            if root_fields:
+                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
+            for path, names in (rel_field_names or {}).items():
+                prefix = ".".join(path)
+                for n in names:
+                    proj.append(f"{prefix}.{n}")
+            if proj and "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+
+        if proj:
+            for obj in items:
+                try:
+                    setattr(obj, "__crudkit_projection__", tuple(proj))
+                except Exception:
+                    pass
+
+        # Cursor key pluck: support related columns available on the loaded objects
+        def _pluck_key_from_obj(obj: Any) -> list[Any]:
+            vals: list[Any] = []
+            alias_to_rel: dict[Any, str] = {}
+            for _p, relationship_attr, target_alias in join_paths:
+                sel = getattr(target_alias, "selectable", None)
+                if sel is not None:
+                    alias_to_rel[sel] = relationship_attr.key
+
+            for col in order_spec.cols:
+                keyname = getattr(col, "key", None) or getattr(col, "name", None)
+                if keyname and hasattr(obj, keyname):
+                    vals.append(getattr(obj, keyname))
+                    continue
+                table = getattr(col, "table", None)
+                relname = alias_to_rel.get(table)
+                if relname and keyname:
+                    relobj = getattr(obj, relname, None)
+                    if relobj is not None and hasattr(relobj, keyname):
+                        vals.append(getattr(relobj, keyname))
+                        continue
+                raise ValueError("unpluckable")
+            return vals
+
+        try:
+            first_key = _pluck_key_from_obj(items[0]) if items else None
+            last_key = _pluck_key_from_obj(items[-1]) if items else None
+        except Exception:
+            first_key = None
+            last_key = None
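+
+        # first_key/last_key act as the page cursors: a follow-up call with
+        # key=last_key (backward=False) should fetch the next window, and
+        # key=first_key with backward=True the previous one.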
+
+        # Count DISTINCT ids with mirrored joins
+        total = None
+        if include_total:
+            base = session.query(getattr(root_alias, "id"))
+            base = self._apply_not_deleted(base, root_alias, params)
+            # same joins as above for correctness
+            for base_alias, rel_attr, target_alias in join_paths:
+                # do not join collections for COUNT mirror
+                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
+                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
+            if filters:
+                base = base.filter(*filters)
+            total = session.query(func.count()).select_from(
+                base.order_by(None).distinct().subquery()
+            ).scalar() or 0
+
+        window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50)
+
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug("QUERY: %s", str(query))
+
+        return SeekWindow(
+            items=items,
+            limit=window_limit_for_body,
+            first_key=first_key,
+            last_key=last_key,
+            order=order_spec,
+            total=total,
+        )
+
+    # Helper: default ORDER BY for MSSQL when paginating without explicit order
     def _default_order_by(self, root_alias):
         mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
         cols = []
@@ -184,114 +408,176 @@ class CRUDService(Generic[T]):
         return cols or [text("1")]
 
     def _stable_order_by(self, root_alias, given_order_by):
+        """
+        Ensure deterministic ordering by appending PK columns as tiebreakers.
+        If no order is provided, fall back to default primary-key order.
+        Never duplicates columns in ORDER BY (SQL Server requires uniqueness).
+        """
         order_by = list(given_order_by or [])
         if not order_by:
            return _dedupe_order_by(self._default_order_by(root_alias))
+        order_by = _dedupe_order_by(order_by)
+
         mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
         present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
+
         for pk in mapper.primary_key:
             try:
                 pk_col = getattr(root_alias, pk.key)
             except AttributeError:
                 pk_col = pk
-            ident = _order_identity(pk_col)
-            if ident not in present:
+            if _order_identity(pk_col) not in present:
                 order_by.append(pk_col.asc())
-                present.add(ident)
+                present.add(_order_identity(pk_col))
+
         return order_by
 
-    def _extract_order_spec(self, root_alias, given_order_by):
-        given = self._stable_order_by(root_alias, given_order_by)
-        cols, desc_flags = [], []
-        for ob in given:
-            elem = getattr(ob, "element", None)
-            col = elem if elem is not None else ob
-            is_desc = False
-            dir_attr = getattr(ob, "_direction", None)
-            if dir_attr is not None:
-                is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
-            elif isinstance(ob, UnaryExpression):
-                op = getattr(ob, "operator", None)
-                is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
-            cols.append(col)
-            desc_flags.append(bool(is_desc))
-        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
+    def get(self, id: int, params=None) -> T | None:
+        """
+        Fetch a single row by id with conflict-free eager loading and clean projection.
+        Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so
+        related-column filters never create cartesian products.
+        """
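+        # Example request shape (illustrative; exact parameter names are defined by
+        # CRUDSpec): params={"fields": "id,name,contact.supervisor", "include_deleted": "false"}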
+ """ + query, root_alias = self.get_query() - def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): - if not key_vals: - return None - conds = [] - for i, col in enumerate(spec.cols): - ties = [spec.cols[j] == key_vals[j] for j in range(i)] - is_desc = spec.desc[i] - op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i]) - conds.append(and_(*ties, op)) - return or_(*conds) + # Defaults so we can build a projection even if params is None + req_fields: list[str] = _normalize_fields_param(params) + root_fields: list[Any] = [] + root_field_names: dict[str, str] = {} + rel_field_names: dict[tuple[str, ...], list[str]] = {} - # ---- planning and application + # Soft-delete guard first + query = self._apply_not_deleted(query, root_alias, params) - @dataclass(slots=True) - class _Plan: - spec: Any - filters: Any - order_by: Any - limit: Any - offset: Any - root_fields: Any - rel_field_names: Any - root_field_names: Any - collection_field_names: Any - join_paths: Any - filter_tables: Any - req_fields: Any - proj_opts: Any - - def _plan(self, params, root_alias) -> _Plan: - req_fields = _normalize_fields_param(params) spec = CRUDSpec(self.model, params or {}, root_alias) + # Parse everything so CRUDSpec records any join paths it needed to resolve filters = spec.parse_filters() - order_by = spec.parse_sort() - limit, offset = spec.parse_pagination() - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() if params else ([], {}, {}, {}) + # no ORDER BY for get() + if params: + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) - filter_tables = _collect_tables_from_filters(filters) - _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - return self._Plan( - spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset, - root_fields=root_fields, rel_field_names=rel_field_names, - root_field_names=root_field_names, collection_field_names=collection_field_names, - join_paths=join_paths, filter_tables=filter_tables, - req_fields=req_fields, proj_opts=proj_opts - ) + # Root-column projection (load_only) + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) - def _apply_projection_load_only(self, query, root_alias, plan: _Plan): - only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)] - return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan): - nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 } - for base_alias, rel_attr, target_alias in plan.join_paths: - if base_alias is not root_alias: + # First-hop only; use selectinload (no contains_eager) + for base_alias, rel_attr, target_alias in join_paths: + is_firsthop_from_root = (base_alias is root_alias) + if not is_firsthop_from_root: continue prop = getattr(rel_attr, "property", None) is_collection = bool(getattr(prop, "uselist", False)) + _is_nested_firsthop = rel_attr.key in nested_first_hops - sel = getattr(target_alias, "selectable", None) - sel_elem = getattr(sel, "element", None) - base_sel = sel_elem if sel_elem is not None else sel + opt = 
+            opt = selectinload(rel_attr)
+            if is_collection:
+                child_names = (collection_field_names or {}).get(rel_attr.key, [])
+                if child_names:
+                    target_cls = prop.mapper.class_
+                    cols = [getattr(target_cls, n, None) for n in child_names]
+                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
+                    if cols:
+                        opt = opt.load_only(*cols)
+            query = query.options(opt)
 
-            needed_for_filter = (sel in plan.filter_tables) or (base_sel in plan.filter_tables)
+        # Apply filters (joins are in place → no cartesian products)
+        if filters:
+            query = query.filter(*filters)
+
+        # And the id filter
+        query = query.filter(getattr(root_alias, "id") == id)
+
+        # Projection loader options compiled from requested fields.
+        # Skip if we used contains_eager to avoid loader-strategy conflicts.
+        expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
+        if proj_opts:
+            query = query.options(*proj_opts)
+
+        obj = query.first()
+
+        # Emit exactly what the client requested (plus id), or a reasonable fallback
+        if req_fields:
+            proj = list(dict.fromkeys(req_fields))
+            if "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+        else:
+            proj = []
+            if root_field_names:
+                proj.extend(root_field_names)
+            if root_fields:
+                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
+            for path, names in (rel_field_names or {}).items():
+                prefix = ".".join(path)
+                for n in names:
+                    proj.append(f"{prefix}.{n}")
+            if proj and "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+
+        if proj and obj is not None:
+            try:
+                setattr(obj, "__crudkit_projection__", tuple(proj))
+            except Exception:
+                pass
+
+        if log.isEnabledFor(logging.DEBUG):
+            log.debug("QUERY: %s", str(query))
+
+        return obj or None
+
+    def list(self, params=None) -> list[T]:
+        """
+        Offset/limit listing with relationship-safe filtering.
+        We always JOIN every CRUDSpec-discovered path before applying filters/sorts.
+ """ + query, root_alias = self.get_query() + + # Defaults so we can reference them later even if params is None + req_fields: list[str] = _normalize_fields_param(params) + root_fields: list[Any] = [] + root_field_names: dict[str, str] = {} + rel_field_names: dict[tuple[str, ...], list[str]] = {} + + if params: + # Soft delete + query = self._apply_not_deleted(query, root_alias, params) + + spec = CRUDSpec(self.model, params or {}, root_alias) + filters = spec.parse_filters() + order_by = spec.parse_sort() + limit, offset = spec.parse_pagination() + + # Includes / fields (populates join_paths) + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() + spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) + + # Root column projection (load_only) + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) + + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + + # First-hop only; use selectinload + for base_alias, rel_attr, target_alias in join_paths: + is_firsthop_from_root = (base_alias is root_alias) + if not is_firsthop_from_root: + continue + prop = getattr(rel_attr, "property", None) + is_collection = bool(getattr(prop, "uselist", False)) + _is_nested_firsthop = rel_attr.key in nested_first_hops - if needed_for_filter and not is_collection: - query = query.join(rel_attr, isouter=True) - else: opt = selectinload(rel_attr) if is_collection: - child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) + child_names = (collection_field_names or {}).get(rel_attr.key, []) if child_names: target_cls = prop.mapper.class_ cols = [getattr(target_cls, n, None) for n in child_names] @@ -299,194 +585,79 @@ class CRUDService(Generic[T]): if cols: opt = opt.load_only(*cols) query = query.options(opt) - return query - def _apply_proj_opts(self, query, plan: _Plan): - return query.options(*plan.proj_opts) if plan.proj_opts else query + # Filters AFTER joins → no cartesian products + if filters: + query = query.filter(*filters) - def _projection_meta(self, plan: _Plan): - if plan.req_fields: - proj = list(dict.fromkeys(plan.req_fields)) - return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj + # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers + paginating = (limit is not None) or (offset is not None and offset != 0) + if paginating and not order_by and self.backend.requires_order_by_for_offset: + order_by = self._default_order_by(root_alias) - proj: list[str] = [] - if plan.root_field_names: - proj.extend(plan.root_field_names) - if plan.root_fields: - proj.extend(c.key for c in plan.root_fields if hasattr(c, "key")) - for path, names in (plan.rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - return proj + order_by = _dedupe_order_by(order_by) - def _tag_projection(self, items, proj): - if not proj: - return - for obj in items if isinstance(items, list) else [items]: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass + if order_by: + query = query.order_by(*order_by) - # ---- public read ops + # Offset/limit + if offset is not None and offset != 0: + query = query.offset(offset) + if limit is not None and limit > 0: + query = query.limit(limit) - def seek_window( - self, 
-        params: dict | None = None,
-        *,
-        key: list[Any] | None = None,
-        backward: bool = False,
-        include_total: bool = True,
-    ) -> "SeekWindow[T]":
-        session = self.session
-        query, root_alias = self.get_query()
-        query = self._apply_not_deleted(query, root_alias, params)
+        else:
+            # No params; still honor projection loaders if any
+            expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
+            if proj_opts:
+                query = query.options(*proj_opts)
 
-        plan = self._plan(params, root_alias)
-        query = self._apply_projection_load_only(query, root_alias, plan)
-        query = self._apply_firsthop_strategies(query, root_alias, plan)
-        if plan.filters:
-            query = query.filter(*plan.filters)
-
-        order_spec = self._extract_order_spec(root_alias, plan.order_by)
-        limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)
-
-        if key:
-            pred = self._key_predicate(order_spec, key, backward)
-            if pred is not None:
-                query = query.filter(pred)
-
-        clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)]
-        if backward:
-            clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)]
-        query = query.order_by(*clauses)
-        if limit is not None:
-            query = query.limit(limit)
-
-        query = self._apply_proj_opts(query, plan)
-        rows = query.all()
-        items = list(reversed(rows)) if backward else rows
-
-        proj = self._projection_meta(plan)
-        self._tag_projection(items, proj)
-
-        # cursor keys
-        def pluck(obj):
-            vals = []
-            alias_to_rel = {}
-            for _p, rel_attr, target_alias in plan.join_paths:
-                sel = getattr(target_alias, "selectable", None)
-                if sel is not None:
-                    alias_to_rel[sel] = rel_attr.key
-
-            for col in order_spec.cols:
-                keyname = getattr(col, "key", None) or getattr(col, "name", None)
-                if keyname and hasattr(obj, keyname):
-                    vals.append(getattr(obj, keyname)); continue
-                table = getattr(col, "table", None)
-                relname = alias_to_rel.get(table)
-                if relname and keyname:
-                    relobj = getattr(obj, relname, None)
-                    if relobj is not None and hasattr(relobj, keyname):
-                        vals.append(getattr(relobj, keyname)); continue
-                raise ValueError("unpluckable")
-            return vals
-
-        try:
-            first_key = pluck(items[0]) if items else None
-            last_key = pluck(items[-1]) if items else None
-        except Exception:
-            first_key = last_key = None
-
-        total = None
-        if include_total:
-            base = session.query(getattr(root_alias, "id"))
-            base = self._apply_not_deleted(base, root_alias, params)
-            for _b, rel_attr, target_alias in plan.join_paths:
-                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
-                    base = base.join(rel_attr, isouter=True)
-            if plan.filters:
-                base = base.filter(*plan.filters)
-            total = session.query(func.count()).select_from(
-                base.order_by(None).distinct().subquery()
-            ).scalar() or 0
-
-        if log.isEnabledFor(logging.DEBUG):
-            log.debug("QUERY: %s", str(query))
-
-        window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50)
-        return SeekWindow(
-            items=items,
-            limit=window_limit_for_body,
-            first_key=first_key,
-            last_key=last_key,
-            order=order_spec,
-            total=total,
-        )
-
-    def get(self, id: int, params=None) -> T | None:
-        query, root_alias = self.get_query()
-        query = self._apply_not_deleted(query, root_alias, params)
-
-        plan = self._plan(params, root_alias)
-        query = self._apply_projection_load_only(query, root_alias, plan)
-        query = self._apply_firsthop_strategies(query, root_alias, plan)
-        if plan.filters:
-            query = query.filter(*plan.filters)
-        query = query.filter(getattr(root_alias, "id") == id)
-        query = self._apply_proj_opts(query, plan)
-
-        obj = query.first()
-        proj = self._projection_meta(plan)
-        if obj:
-            self._tag_projection(obj, proj)
-
-        if log.isEnabledFor(logging.DEBUG):
-            log.debug("QUERY: %s", str(query))
-        return obj or None
-
-    def list(self, params=None) -> list[T]:
-        query, root_alias = self.get_query()
-
-        plan = self._plan(params, root_alias)
-        query = self._apply_not_deleted(query, root_alias, params)
-        query = self._apply_projection_load_only(query, root_alias, plan)
-        query = self._apply_firsthop_strategies(query, root_alias, plan)
-        if plan.filters:
-            query = query.filter(*plan.filters)
-
-        order_by = plan.order_by
-        paginating = (plan.limit is not None) or (plan.offset not in (None, 0))
-        if paginating and not order_by and self.backend.requires_order_by_for_offset:
-            order_by = self._default_order_by(root_alias)
-        order_by = _dedupe_order_by(order_by)
-        if order_by:
-            query = query.order_by(*order_by)
-
-        if plan.offset: query = query.offset(plan.offset)
-        if plan.limit and plan.limit > 0: query = query.limit(plan.limit)
-
-        query = self._apply_proj_opts(query, plan)
         rows = query.all()
-        proj = self._projection_meta(plan)
-        self._tag_projection(rows, proj)
+
+        # Build projection meta for renderers
+        if req_fields:
+            proj = list(dict.fromkeys(req_fields))
+            if "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+        else:
+            proj = []
+            if root_field_names:
+                proj.extend(root_field_names)
+            if root_fields:
+                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
+            for path, names in (rel_field_names or {}).items():
+                prefix = ".".join(path)
+                for n in names:
+                    proj.append(f"{prefix}.{n}")
+            if proj and "id" not in proj and hasattr(self.model, "id"):
+                proj.insert(0, "id")
+
+        if proj:
+            for obj in rows:
+                try:
+                    setattr(obj, "__crudkit_projection__", tuple(proj))
+                except Exception:
+                    pass
 
         if log.isEnabledFor(logging.DEBUG):
             log.debug("QUERY: %s", str(query))
+
         return rows
 
-    # ---- write ops
-
     def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
         session = self.session
         obj = self.model(**data)
         session.add(obj)
+        session.flush()
+        self._log_version("create", obj, actor, commit=commit)
+
         if commit:
             session.commit()
         return obj
@@ -498,34 +669,59 @@ class CRUDService(Generic[T]):
             raise ValueError(f"{self.model.__name__} with ID {id} not found.")
 
         before = obj.as_dict()
+
+        # Normalize and restrict payload to real columns
         norm = normalize_payload(data, self.model)
         incoming = filter_to_columns(norm, self.model)
+
+        # Build a synthetic "desired" state for top-level columns
         desired = {**before, **incoming}
 
-        proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
+        # Compute intended change set (before vs intended)
+        proposed = deep_diff(
+            before, desired,
+            ignore_keys={"id", "created_at", "updated_at"},
+            list_mode="index",
+        )
         patch = diff_to_patch(proposed)
+
+        # Nothing to do
         if not patch:
             return obj
 
+        # Apply only what actually changes
         for k, v in patch.items():
             setattr(obj, k, v)
 
+        # Optional: skip commit if ORM says no real change (paranoid check)
+        # Note: is_modified can lie if attrs are expired; use history for certainty.
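+        # (Illustrative: inspect(obj).attrs["name"].history returns a History tuple
+        # of (added, unchanged, deleted) values for that attribute.)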
         dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
         if not dirty:
             return obj
 
+        # Commit atomically
         if commit:
             session.commit()
 
+        # AFTER snapshot for audit
         after = obj.as_dict()
 
-        actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
+        # Actual diff (captures triggers/defaults, still ignoring noisy keys)
+        actual = deep_diff(
+            before, after,
+            ignore_keys={"id", "created_at", "updated_at"},
+            list_mode="index",
+        )
+
+        # If truly nothing changed post-commit (rare), skip version spam
         if not (actual["added"] or actual["removed"] or actual["changed"]):
             return obj
 
+        # Log both what we *intended* and what *actually* happened
         self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit)
         return obj
 
-    def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True):
+    def delete(self, id: int, hard: bool = False, actor = None, *, commit: bool = True):
         session = self.session
         obj = session.get(self.model, id)
         if not obj:
@@ -533,21 +729,22 @@ class CRUDService(Generic[T]):
         if hard or not self.supports_soft_delete:
             session.delete(obj)
         else:
-            cast(_SoftDeletable, obj).is_deleted = True
+            soft = cast(_SoftDeletable, obj)
+            soft.is_deleted = True
         if commit:
             session.commit()
         self._log_version("delete", obj, actor, commit=commit)
         return obj
 
-    # ---- audit
-
     def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True):
         session = self.session
         try:
+            snapshot = {}
             try:
                 snapshot = obj.as_dict()
            except Exception:
                 snapshot = {"error": "serialize failed"}
+
             version = Version(
                 model_name=self.model.__name__,
                 object_id=obj.id,
diff --git a/inventory/templates/update_list.html b/inventory/templates/update_list.html
index e0e2848..92f04b8 100644
--- a/inventory/templates/update_list.html
+++ b/inventory/templates/update_list.html
@@ -25,15 +25,6 @@
         {% endfor %}
-
-
-
-
-
-
-