from __future__ import annotations

import logging
from collections.abc import Iterable
from typing import (
    Any,
    Callable,
    Generic,
    Optional,
    Protocol,
    Type,
    TypeVar,
    cast,
    runtime_checkable,
)

from flask import current_app
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import (
    Load,
    Mapper,
    Session,
    contains_eager,
    selectinload,
    with_polymorphic,
)
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement, UnaryExpression

from crudkit.core import (
    deep_diff,
    diff_to_patch,
    filter_to_columns,
    normalize_payload,
    to_jsonable,
)
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection

log = logging.getLogger("crudkit.service")


@runtime_checkable
class _HasID(Protocol):
    # Integer surrogate primary key; the service reads/filters on it everywhere.
    id: int


@runtime_checkable
class _HasTable(Protocol):
    # SQLAlchemy declarative models expose their Table here.
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    # Present only on models that opt into soft delete.
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on.

    Soft-delete is optional."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


def _unwrap_ob(ob):
    """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).

    SQLAlchemy wraps ``col.asc()``/``col.desc()`` in a UnaryExpression whose
    ``element`` is the underlying column; plain columns pass through unchanged.
    Direction is probed via ``_direction`` first (SA 2.x), then via the
    UnaryExpression operator as a fallback.
    """
    col = getattr(ob, "element", None)
    if col is None:
        col = ob
    is_desc = False
    dir_attr = getattr(ob, "_direction", None)
    if dir_attr is not None:
        is_desc = (dir_attr is operators.desc_op) or (
            getattr(dir_attr, "name", "").upper() == "DESC"
        )
    elif isinstance(ob, UnaryExpression):
        op = getattr(ob, "operator", None)
        is_desc = (op is operators.desc_op) or (
            getattr(op, "name", "").upper() == "DESC"
        )
    return col, bool(is_desc)


def _order_identity(col: ColumnElement):
    """
    Build a stable identity for a column suitable for deduping.
    We ignore direction here. Duplicates are duplicates regardless of ASC/DESC.
    """
    table = getattr(col, "table", None)
    table_key = getattr(table, "key", None) or id(table)
    col_key = getattr(col, "key", None) or getattr(col, "name", None)
    return (table_key, col_key)


def _dedupe_order_by(order_by):
    """Remove duplicate ORDER BY entries (by column identity, ignoring direction)."""
    if not order_by:
        return []
    seen = set()
    out = []
    for ob in order_by:
        col, _ = _unwrap_ob(ob)
        ident = _order_identity(col)
        if ident in seen:
            continue
        seen.add(ident)
        out.append(ob)
    return out


def _is_truthy(val):
    """Loose query-string boolean: '1'/'true'/'yes'/'on' (any case) → True."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


def _normalize_fields_param(params: dict | None) -> list[str]:
    """Normalize the ``fields`` request param into a flat list of names.

    Accepts a comma-separated string, or a list/tuple whose string items may
    each be comma-separated. Whitespace is stripped; empty pieces dropped.
    """
    if not params:
        return []
    raw = params.get("fields")
    if isinstance(raw, (list, tuple)):
        out: list[str] = []
        for item in raw:
            if isinstance(item, str):
                out.extend([p for p in (s.strip() for s in item.split(",")) if p])
        return out
    if isinstance(raw, str):
        return [p for p in (s.strip() for s in raw.split(",")) if p]
    return []


class CRUDService(Generic[T]):
    """Generic CRUD/query service over a SQLAlchemy-mapped model.

    Provides offset and keyset (seek) pagination, relationship-safe filtering
    and projection via CRUDSpec, optional soft delete, and audit versioning.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None,
    ):
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        # Soft delete is opt-in: presence of an ``is_deleted`` attribute enables it.
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        self._backend: Optional[BackendInfo] = backend

    @property
    def session(self) -> Session:
        """Always return the Flask-scoped Session if available; otherwise the provided factory."""
        try:
            sess = current_app.extensions["crudkit"]["Session"]
            return sess
        except Exception:
            # Outside an app context (or extension missing) — fall back to the factory.
            return self._session_factory()

    @property
    def backend(self) -> BackendInfo:
        """Resolve backend info lazily against the active session's engine."""
        if self._backend is None:
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            self._backend = make_backend_info(eng)
        return self._backend

    def get_query(self):
        """Return (query, root_alias); root is a polymorphic wrapper when enabled."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    def _debug_bind(self, where: str):
        """Print which engine/session this service is currently bound to (debug aid)."""
        try:
            bind = self.session.get_bind()
            eng = getattr(bind, "engine", bind)
            print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}")
        except Exception as e:
            print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}")

    def _apply_not_deleted(self, query, root_alias, params) -> Any:
        """Filter out soft-deleted rows unless ``include_deleted`` is truthy."""
        if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
            # .is_(False) emits the same SQL as == False without the lint noise.
            return query.filter(getattr(root_alias, "is_deleted").is_(False))
        return query

    def _extract_order_spec(self, root_alias, given_order_by):
        """
        SQLAlchemy 2.x only: Normalize order_by into (cols, desc_flags).
        Supports plain columns and col.asc()/col.desc() (UnaryExpression).
        Avoids boolean evaluation of clauses.
        """
        given = self._stable_order_by(root_alias, given_order_by)
        cols, desc_flags = [], []
        for ob in given:
            # _unwrap_ob centralizes direction detection (same logic as dedupe).
            col, is_desc = _unwrap_ob(ob)
            cols.append(col)
            desc_flags.append(is_desc)
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the keyset (seek) WHERE clause for a cursor ``key_vals``.

        Produces OR-of-ANDs: for each order column i, all earlier columns tie
        and column i strictly advances in the scan direction. Returns None when
        there is no cursor.
        """
        if not key_vals:
            return None
        conds = []
        for i, col in enumerate(spec.cols):
            # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
            # sent_col = func.coalesce(col, literal("-∞"))
            sent_col = col
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            is_desc = spec.desc[i]
            if not backward:
                op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i])
            else:
                op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
        """Extract cursor key values from a row object for simple mapped columns."""
        out = []
        for c in spec.cols:
            # Only simple mapped columns supported for key pluck
            key = getattr(c, "key", None) or getattr(c, "name", None)
            if key is None or not hasattr(obj, key):
                raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
            out.append(getattr(obj, key))
        return out

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """
        Keyset pagination with relationship-safe filtering/sorting.

        Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek.
        """
        session = self.session
        query, root_alias = self.get_query()

        # Requested fields → projection + optional loaders
        # NOTE(review): proj_opts is computed but not applied here (unlike
        # get()/list()) — presumably deliberate to avoid loader conflicts with
        # contains_eager below; confirm before changing.
        fields = _normalize_fields_param(params)
        expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])

        spec = CRUDSpec(self.model, params or {}, root_alias)

        # Parse all inputs so join_paths are populated
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())

        # Soft delete
        query = self._apply_not_deleted(query, root_alias, params)

        # Root column projection (load_only)
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
        nested_first_hops = {
            path[0] for path in (rel_field_names or {}).keys() if len(path) > 1
        }

        used_contains_eager = False
        for base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            is_nested_firsthop = rel_attr.key in nested_first_hops
            if is_collection or is_nested_firsthop:
                # Use selectinload so deeper hops can chain cleanly (and to avoid
                # contains_eager/loader conflicts on nested paths).
                opt = selectinload(rel_attr)
                # Narrow columns for collections if we know child scalar names
                if is_collection:
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)
                query = query.options(opt)
            else:
                # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager.
                # SA 2.0 removed contains_eager(alias=...); of_type() is the replacement.
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr.of_type(target_alias)))
                used_contains_eager = True

        # Filters AFTER joins → no cartesian products
        if filters:
            query = query.filter(*filters)

        # Order spec (with PK tie-breakers for stability)
        order_spec = self._extract_order_spec(root_alias, order_by)

        limit, _ = spec.parse_pagination()
        if limit is None:
            effective_limit = 50
        elif limit == 0:
            effective_limit = None  # unlimited
        else:
            effective_limit = limit

        # Seek predicate from cursor key (if any)
        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
        if not backward:
            clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
            query = query.order_by(*clauses)
            if effective_limit is not None:
                query = query.limit(effective_limit)
            items = query.all()
        else:
            inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
            query = query.order_by(*inv_clauses)
            if effective_limit is not None:
                query = query.limit(effective_limit)
            items = list(reversed(query.all()))

        # Projection meta tag for renderers
        if fields:
            proj = list(dict.fromkeys(fields))
            if "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        else:
            proj = []
            if root_field_names:
                proj.extend(root_field_names)
            if root_fields:
                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
            for path, names in (rel_field_names or {}).items():
                prefix = ".".join(path)
                for n in names:
                    proj.append(f"{prefix}.{n}")
            if proj and "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        if proj:
            for obj in items:
                try:
                    setattr(obj, "__crudkit_projection__", tuple(proj))
                except Exception:
                    pass

        # Cursor key pluck: support related columns we hydrated via contains_eager
        def _pluck_key_from_obj(obj: Any) -> list[Any]:
            vals: list[Any] = []
            alias_to_rel: dict[Any, str] = {}
            for _p, relationship_attr, target_alias in join_paths:
                sel = getattr(target_alias, "selectable", None)
                if sel is not None:
                    alias_to_rel[sel] = relationship_attr.key
            for col in order_spec.cols:
                keyname = getattr(col, "key", None) or getattr(col, "name", None)
                if keyname and hasattr(obj, keyname):
                    vals.append(getattr(obj, keyname))
                    continue
                # Column lives on a joined alias → read it off the related object.
                table = getattr(col, "table", None)
                relname = alias_to_rel.get(table)
                if relname and keyname:
                    relobj = getattr(obj, relname, None)
                    if relobj is not None and hasattr(relobj, keyname):
                        vals.append(getattr(relobj, keyname))
                        continue
                raise ValueError("unpluckable")
            return vals

        try:
            first_key = _pluck_key_from_obj(items[0]) if items else None
            last_key = _pluck_key_from_obj(items[-1]) if items else None
        except Exception:
            # Unpluckable order column → no cursors for this page.
            first_key = None
            last_key = None

        # Count DISTINCT ids with mirrored joins
        total = None
        if include_total:
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            # same joins as above for correctness
            for base_alias, rel_attr, target_alias in join_paths:
                # do not join collections for COUNT mirror
                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            if filters:
                base = base.filter(*filters)
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        # "unlimited" (limit == 0) is reported to the body as 0; otherwise the real limit.
        window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Primary-key ORDER BY fallback; text('1') if the mapper has no PK."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                # Alias doesn't expose the attribute — fall back to the raw column.
                cols.append(col)
        return cols or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """
        Ensure deterministic ordering by appending PK columns as tiebreakers.
        If no order is provided, fall back to default primary-key order.
        Never duplicates columns in ORDER BY (SQL Server requires uniqueness).
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return _dedupe_order_by(self._default_order_by(root_alias))
        order_by = _dedupe_order_by(order_by)
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
        for pk in mapper.primary_key:
            try:
                pk_col = getattr(root_alias, pk.key)
            except AttributeError:
                pk_col = pk
            if _order_identity(pk_col) not in present:
                order_by.append(pk_col.asc())
                present.add(_order_identity(pk_col))
        return order_by

    def get(self, id: int, params=None) -> T | None:
        """
        Fetch a single row by id with conflict-free eager loading and clean projection.
        Always JOIN any paths that CRUDSpec resolved for filters/fields/includes
        so related-column filters never create cartesian products.
        """
        query, root_alias = self.get_query()

        # Defaults so we can build a projection even if params is None
        req_fields: list[str] = _normalize_fields_param(params)
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}
        # Guard against NameError when params is falsy but join paths exist.
        collection_field_names: dict[str, list[str]] = {}

        # Soft-delete guard first
        query = self._apply_not_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        # Parse everything so CRUDSpec records any join paths it needed to resolve
        filters = spec.parse_filters()  # no ORDER BY for get()
        if params:
            root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
            spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())

        # Root-column projection (load_only)
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        nested_first_hops = {
            path[0] for path in (rel_field_names or {}).keys() if len(path) > 1
        }

        used_contains_eager = False
        for base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            is_nested_firsthop = rel_attr.key in nested_first_hops
            if is_collection or is_nested_firsthop:
                opt = selectinload(rel_attr)
                if is_collection:
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)
                query = query.options(opt)
            else:
                # SA 2.0 form: of_type() replaces the removed contains_eager(alias=...).
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr.of_type(target_alias)))
                used_contains_eager = True

        # Apply filters (joins are in place → no cartesian products)
        if filters:
            query = query.filter(*filters)
        # And the id filter
        query = query.filter(getattr(root_alias, "id") == id)

        # Projection loader options compiled from requested fields.
        # Skip if we used contains_eager to avoid loader-strategy conflicts.
        expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)

        obj = query.first()

        # Emit exactly what the client requested (plus id), or a reasonable fallback
        if req_fields:
            proj = list(dict.fromkeys(req_fields))
            if "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        else:
            proj = []
            if root_field_names:
                proj.extend(root_field_names)
            if root_fields:
                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
            for path, names in (rel_field_names or {}).items():
                prefix = ".".join(path)
                for n in names:
                    proj.append(f"{prefix}.{n}")
            if proj and "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        if proj and obj is not None:
            try:
                setattr(obj, "__crudkit_projection__", tuple(proj))
            except Exception:
                pass

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return obj or None

    def list(self, params=None) -> list[T]:
        """
        Offset/limit listing with relationship-safe filtering.

        We always JOIN every CRUDSpec-discovered path before applying filters/sorts.
        """
        query, root_alias = self.get_query()

        # Defaults so we can reference them later even if params is None
        req_fields: list[str] = _normalize_fields_param(params)
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}

        if params:
            # Soft delete
            query = self._apply_not_deleted(query, root_alias, params)

            spec = CRUDSpec(self.model, params or {}, root_alias)
            filters = spec.parse_filters()
            order_by = spec.parse_sort()
            limit, offset = spec.parse_pagination()

            # Includes / fields (populates join_paths)
            root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
            spec.parse_includes()
            join_paths = tuple(spec.get_join_paths())

            # Root column projection (load_only)
            only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
            if only_cols:
                query = query.options(Load(root_alias).load_only(*only_cols))

            nested_first_hops = {
                path[0] for path in (rel_field_names or {}).keys() if len(path) > 1
            }

            used_contains_eager = False
            for _base_alias, rel_attr, target_alias in join_paths:
                is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
                is_nested_firsthop = rel_attr.key in nested_first_hops
                if is_collection or is_nested_firsthop:
                    opt = selectinload(rel_attr)
                    if is_collection:
                        child_names = (collection_field_names or {}).get(rel_attr.key, [])
                        if child_names:
                            target_cls = rel_attr.property.mapper.class_
                            cols = [getattr(target_cls, n, None) for n in child_names]
                            cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                            if cols:
                                opt = opt.load_only(*cols)
                    query = query.options(opt)
                else:
                    # SA 2.0 form: of_type() replaces the removed contains_eager(alias=...).
                    query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                    query = query.options(contains_eager(rel_attr.of_type(target_alias)))
                    used_contains_eager = True

            # Filters AFTER joins → no cartesian products
            if filters:
                query = query.filter(*filters)

            # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers
            paginating = (limit is not None) or (offset is not None and offset != 0)
            if paginating and not order_by and self.backend.requires_order_by_for_offset:
                order_by = self._default_order_by(root_alias)
            order_by = _dedupe_order_by(order_by)
            if order_by:
                query = query.order_by(*order_by)

            # Offset/limit
            if offset is not None and offset != 0:
                query = query.offset(offset)
            if limit is not None and limit > 0:
                query = query.limit(limit)

            # Projection loaders only if we didn’t use contains_eager
            expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
            if proj_opts and not used_contains_eager:
                query = query.options(*proj_opts)
        else:
            # No params; still honor projection loaders if any
            expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
            if proj_opts:
                query = query.options(*proj_opts)

        rows = query.all()

        # Build projection meta for renderers
        if req_fields:
            proj = list(dict.fromkeys(req_fields))
            if "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        else:
            proj = []
            if root_field_names:
                proj.extend(root_field_names)
            if root_fields:
                proj.extend(c.key for c in root_fields if hasattr(c, "key"))
            for path, names in (rel_field_names or {}).items():
                prefix = ".".join(path)
                for n in names:
                    proj.append(f"{prefix}.{n}")
            if proj and "id" not in proj and hasattr(self.model, "id"):
                proj.insert(0, "id")
        if proj:
            for obj in rows:
                try:
                    setattr(obj, "__crudkit_projection__", tuple(proj))
                except Exception:
                    pass

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Insert a new row from ``data`` kwargs, commit, and audit-log it."""
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Apply a partial update, committing and versioning only real changes.

        Raises ValueError when the row does not exist.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        before = obj.as_dict()

        # Normalize and restrict payload to real columns
        norm = normalize_payload(data, self.model)
        incoming = filter_to_columns(norm, self.model)

        # Build a synthetic "desired" state for top-level columns
        desired = {**before, **incoming}

        # Compute intended change set (before vs intended)
        proposed = deep_diff(
            before,
            desired,
            ignore_keys={"id", "created_at", "updated_at"},
            list_mode="index",
        )
        patch = diff_to_patch(proposed)

        # Nothing to do
        if not patch:
            return obj

        # Apply only what actually changes
        for k, v in patch.items():
            setattr(obj, k, v)

        # Optional: skip commit if ORM says no real change (paranoid check)
        # Note: is_modified can lie if attrs are expired; use history for certainty.
        dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
        if not dirty:
            return obj

        # Commit atomically
        session.commit()

        # AFTER snapshot for audit
        after = obj.as_dict()

        # Actual diff (captures triggers/defaults, still ignoring noisy keys)
        actual = deep_diff(
            before,
            after,
            ignore_keys={"id", "created_at", "updated_at"},
            list_mode="index",
        )

        # If truly nothing changed post-commit (rare), skip version spam
        if not (actual["added"] or actual["removed"] or actual["changed"]):
            return obj

        # Log both what we *intended* and what *actually* happened
        self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch})
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """Soft-delete when supported (unless ``hard``); returns the row or None."""
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            soft = cast(_SoftDeletable, obj)
            soft.is_deleted = True
        session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
        """Best-effort audit record; never raises into the caller."""
        session = self.session
        try:
            snapshot = {}
            try:
                snapshot = obj.as_dict()
            except Exception:
                snapshot = {"error": "serialize failed"}
            version = Version(
                model_name=self.model.__name__,
                object_id=obj.id,
                change_type=change_type,
                data=to_jsonable(snapshot),
                actor=str(actor) if actor else None,
                meta=to_jsonable(metadata) if metadata else None,
            )
            session.add(version)
            session.commit()
        except Exception as e:
            # Lazy %-args (and no nested same-quote f-string, which is a
            # SyntaxError before Python 3.12).
            log.warning(
                "Version logging failed for %s id=%s: %s",
                self.model.__name__,
                getattr(obj, "id", "?"),
                e,
            )
            session.rollback()