from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info from crudkit.projection import compile_projection import logging log = logging.getLogger("crudkit.service") @runtime_checkable class _HasID(Protocol): id: int @runtime_checkable class _HasTable(Protocol): __table__: Any @runtime_checkable class _HasADict(Protocol): def as_dict(self) -> dict: ... 
@runtime_checkable
class _SoftDeletable(Protocol):
    """Structural type: a model exposing a boolean ``is_deleted`` flag."""

    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Surface expected by CRUDService."""


T = TypeVar("T", bound=_CRUDModelProto)


# ---------------------------- utilities ----------------------------

def _collect_tables_from_filters(filters) -> set:
    """Walk SQLAlchemy expressions and collect the Table/Alias objects they reference.

    For every node exposing a ``.table`` attribute, that selectable and its
    chain of ``.element`` wrappers (e.g. an Alias and what it aliases) are
    added to the result. Children are discovered both through
    ``get_children()`` and through the ``left``/``right``/``element``/
    ``clause``/``clauses`` attributes.

    Those two discovery routes usually return the same child objects, so each
    node is visited at most once. Without that guard, shared subtrees are
    re-walked and the cost doubles with every nesting level. The walk uses an
    explicit stack, so deeply nested filters cannot hit the recursion limit.
    """
    seen: set = set()
    # id -> node; holding the node keeps it alive so its id cannot be recycled
    # by a temporary object created later in the walk.
    visited: dict[int, Any] = {}
    stack = list(filters or [])
    while stack:
        node = stack.pop()
        if node is None or id(node) in visited:
            continue
        visited[id(node)] = node

        cur = getattr(node, "table", None)
        while cur is not None:
            seen.add(cur)
            cur = getattr(cur, "element", None)

        get_children = getattr(node, "get_children", None)
        if get_children:
            stack.extend(get_children())
        for attr in ("left", "right", "element", "clause", "clauses"):
            val = getattr(node, attr, None)
            if val is None:
                continue
            if isinstance(val, (list, tuple)):
                stack.extend(val)
            else:
                stack.append(val)
    return seen


def _selectable_keys(sel) -> set[str]:
    """
    Return a set of stable string keys for a selectable/alias and its base,
    so we can match when different alias objects are used.

    Walks ``sel`` and its ``.element`` chain, collecting each non-empty
    string ``key`` (or ``name`` when there is no key).
    """
    keys: set[str] = set()
    cur = sel
    while cur is not None:
        k = getattr(cur, "key", None) or getattr(cur, "name", None)
        if isinstance(k, str) and k:
            keys.add(k)
        cur = getattr(cur, "element", None)
    return keys


def _unwrap_ob(ob):
    """Split an ORDER BY item into ``(column, is_descending)``.

    The column is ``ob.element`` when present, else ``ob`` itself. The
    direction comes from a ``_direction`` attribute if the object has one.
    Otherwise it comes from the operator of a ``UnaryExpression``. Anything
    else counts as ascending.
    """
    elem = getattr(ob, "element", None)
    col = elem if elem is not None else ob
    d = getattr(ob, "_direction", None)
    if d is not None:
        is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC")
    elif isinstance(ob, UnaryExpression):
        op = getattr(ob, "operator", None)
        is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
    else:
        is_desc = False
    return col, bool(is_desc)


def _order_identity(col: ColumnElement):
    """Identity of an ordering column: ``(table key or id(table), column key or name)``."""
    table = getattr(col, "table", None)
    table_key = getattr(table, "key", None) or id(table)
    col_key = getattr(col, "key", None) or getattr(col, "name", None)
    return (table_key, col_key)


def _dedupe_order_by(order_by):
    """Drop ORDER BY items whose column identity already appeared, keeping the first."""
    if not order_by:
        return []
    seen, out = set(), []
    for ob in order_by:
        col, _ = _unwrap_ob(ob)
        ident = _order_identity(col)
        if ident in seen:
            continue
        seen.add(ident)
        out.append(ob)
    return out


def _is_truthy(val):
    """True for '1', 'true', 'yes' or 'on' (case-insensitive, after ``str()``)."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


def _normalize_fields_param(params: dict | None) -> list[str]:
    """Flatten ``params["fields"]`` into a list of stripped, non-empty field names.

    Accepts a comma-separated string or a list/tuple of such strings
    (non-string items are ignored). Anything else yields ``[]``.
    """
    if not params:
        return []
    raw = params.get("fields")
    if isinstance(raw, (list, tuple)):
        out: list[str] = []
        for item in raw:
            if isinstance(item, str):
                out.extend([p for p in (s.strip() for s in item.split(",")) if p])
        return out
    if isinstance(raw, str):
        return [p for p in (s.strip() for s in raw.split(",")) if p]
    return []


# ---------------------------- CRUD service ----------------------------

class CRUDService(Generic[T]):
    """CRUD operations and seek (keyset) pagination for a single mapped model.

    Read operations build queries from a request-style ``params`` dict parsed
    by ``CRUDSpec``. Write operations record a ``Version`` row per change.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """
        Args:
            model: Mapped model class this service operates on.
            session_factory: Fallback used by ``session`` when no app-registered
                session can be read.
            polymorphic: Query through ``with_polymorphic(model, "*")``.
            backend: Pre-computed backend info. When omitted it is derived
                lazily from the session's bind.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        # Soft delete is enabled simply by the model having an ``is_deleted`` attribute.
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        self._backend: Optional[BackendInfo] = backend

    # ---- infra

    @property
    def session(self) -> Session:
        """Return ``current_app.extensions["crudkit"]["Session"]``, else a factory session.

        The factory is used whenever the registered session cannot be read
        (any exception). Each fallback access calls ``session_factory()``
        again. Whether that returns the same session depends on the factory
        (NOTE(review): presumably a scoped session; confirm). Write operations
        therefore pass one session down to version logging explicitly.
        """
        try:
            return current_app.extensions["crudkit"]["Session"]
        except Exception:
            return self._session_factory()

    @property
    def backend(self) -> BackendInfo:
        """Backend info for the session's bind, computed on first use and cached."""
        if self._backend is None:
            bind = self.session.get_bind()
            # get_bind() may return a Connection; backend info is built from its Engine.
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            self._backend = make_backend_info(eng)
        return self._backend

    def get_query(self):
        """Return ``(query, root_entity)``; the root is a polymorphic entity when enabled."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # ---- common building blocks

    def _apply_not_deleted(self, query, root_alias, params):
        """Filter out soft-deleted rows unless ``params["include_deleted"]`` is truthy."""
        if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
            # ``== False`` builds a SQL comparison; ``is False`` would not.
            return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712
        return query

    def _default_order_by(self, root_alias):
        """Primary-key columns resolved on ``root_alias`` (falling back to the mapper column).

        Returns ``[text("1")]`` when the mapper reports no primary-key columns.
        """
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """Deduplicate ``given_order_by`` and append missing PK columns (ascending).

        The appended primary-key tiebreakers make the ordering total. An empty
        input yields the default primary-key ordering.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return _dedupe_order_by(self._default_order_by(root_alias))

        order_by = _dedupe_order_by(order_by)
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
        for pk in mapper.primary_key:
            try:
                pk_col = getattr(root_alias, pk.key)
            except AttributeError:
                pk_col = pk
            ident = _order_identity(pk_col)
            if ident not in present:
                order_by.append(pk_col.asc())
                present.add(ident)
        return order_by

    def _extract_order_spec(self, root_alias, given_order_by):
        """Return the stabilized ordering as an ``OrderSpec`` of columns and DESC flags."""
        cols, desc_flags = [], []
        for ob in self._stable_order_by(root_alias, given_order_by):
            col, is_desc = _unwrap_ob(ob)
            cols.append(col)
            desc_flags.append(is_desc)
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the keyset predicate for rows strictly past ``key_vals`` in ``spec`` order.

        The predicate is an OR over ``i`` of: the first ``i`` columns equal
        their key values AND column ``i`` lies beyond its key value. "Beyond"
        means ``>`` for ascending columns and ``<`` for descending ones, and is
        flipped when ``backward``. Returns None for an empty key. ``key_vals``
        needs at least ``len(spec.cols)`` entries (IndexError otherwise).
        """
        if not key_vals:
            return None
        conds = []
        for i, col in enumerate(spec.cols):
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            is_desc = spec.desc[i]
            op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)

    # ---- planning and application

    @dataclass(slots=True)
    class _Plan:
        """Everything ``_plan`` derives from one request's ``params``."""

        spec: Any
        filters: Any
        order_by: Any
        limit: Any
        offset: Any
        root_fields: Any
        rel_field_names: Any
        root_field_names: Any
        collection_field_names: Any
        join_paths: Any
        filter_tables: Any
        filter_table_keys: Any
        req_fields: Any
        proj_opts: Any

    def _plan(self, params, root_alias) -> _Plan:
        """Parse ``params`` once (filters, sort, pagination, fields, joins, projection)."""
        req_fields = _normalize_fields_param(params)
        spec = CRUDSpec(self.model, params or {}, root_alias)

        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        root_fields, rel_field_names, root_field_names, collection_field_names = (
            spec.parse_fields() if params else ([], {}, {}, {})
        )
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())
        filter_tables = _collect_tables_from_filters(filters)
        _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])

        # Precompute a string-key set for quick/stable membership tests (best-effort).
        fkeys: set[str] = set()
        for t in filter_tables:
            try:
                fkeys |= _selectable_keys(t)
            except Exception:
                pass

        return self._Plan(
            spec=spec,
            filters=filters,
            order_by=order_by,
            limit=limit,
            offset=offset,
            root_fields=root_fields,
            rel_field_names=rel_field_names,
            root_field_names=root_field_names,
            collection_field_names=collection_field_names,
            join_paths=join_paths,
            filter_tables=filter_tables,
            filter_table_keys=fkeys,
            req_fields=req_fields,
            proj_opts=proj_opts,
        )

    def _apply_projection_load_only(self, query, root_alias, plan: _Plan):
        """Restrict root columns with ``load_only`` when mapped root fields were requested."""
        only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)]
        return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query

    def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
        """Attach loading for relationships whose join path starts at the root.

        Non-collection relationships are LEFT OUTER JOINed to their target
        alias. Collections get ``selectinload``, narrowed with ``load_only`` to
        the requested child fields that resolve to mapped attributes. Join
        paths starting from any other alias are skipped.
        """
        for base_alias, rel_attr, target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            prop = getattr(rel_attr, "property", None)
            if not getattr(prop, "uselist", False):
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                continue

            opt = selectinload(rel_attr)
            child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
            if child_names:
                target_cls = prop.mapper.class_
                cols = [getattr(target_cls, n, None) for n in child_names]
                cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                if cols:
                    opt = opt.load_only(*cols)
            query = query.options(opt)
        return query

    def _apply_proj_opts(self, query, plan: _Plan):
        """Apply loader options compiled from the explicit ``fields`` parameter, if any."""
        return query.options(*plan.proj_opts) if plan.proj_opts else query

    def _projection_meta(self, plan: _Plan):
        """Return the field names describing what the query projected.

        When explicit ``fields`` were requested they win (deduplicated, order
        kept). Otherwise names come from the parsed root fields plus dotted
        relationship paths. ``"id"`` is prepended when the model has it and it
        is missing; for the derived list this happens only if the list is
        non-empty.
        """
        if plan.req_fields:
            requested = list(dict.fromkeys(plan.req_fields))
            return ["id"] + requested if "id" not in requested and hasattr(self.model, "id") else requested

        proj: list[str] = []
        if plan.root_field_names:
            proj.extend(plan.root_field_names)
        if plan.root_fields:
            proj.extend(c.key for c in plan.root_fields if hasattr(c, "key"))
        for path, names in (plan.rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return proj

    def _tag_projection(self, items, proj):
        """Best-effort: stamp ``tuple(proj)`` on each object as ``__crudkit_projection__``."""
        if not proj:
            return
        tag = tuple(proj)
        for obj in items if isinstance(items, list) else [items]:
            try:
                setattr(obj, "__crudkit_projection__", tag)
            except Exception:
                pass

    # ---- public read ops

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """Return one keyset-paginated window of rows.

        Args:
            params: Request-style parameters (filters, sort, fields, limit, ...).
            key: Cursor values (in order-spec order) to start strictly after,
                or before when ``backward``.
            backward: Page toward the start. Rows are fetched in reverse order
                and flipped back, so ``items`` is always in forward order.
            include_total: Also count all matching rows (ignoring the cursor).

        Limit: a missing limit defaults to 50; an explicit 0 means unlimited.
        ``first_key``/``last_key`` are None when the window is empty or a
        row's sort values cannot be read.
        """
        session = self.session
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)

        plan = self._plan(params, root_alias)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        if plan.filters:
            query = query.filter(*plan.filters)

        order_spec = self._extract_order_spec(root_alias, plan.order_by)
        limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)

        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Walking backward reverses every column's direction; rows are re-reversed below.
        clauses = []
        for c, d in zip(order_spec.cols, order_spec.desc):
            descending = (not d) if backward else d
            clauses.append(c.desc() if descending else c.asc())
        query = query.order_by(*clauses)
        if limit is not None:
            query = query.limit(limit)
        query = self._apply_proj_opts(query, plan)

        rows = query.all()
        items = list(reversed(rows)) if backward else rows

        proj = self._projection_meta(plan)
        self._tag_projection(items, proj)

        def pluck(obj, alias_to_rel):
            """Read the order-spec values of ``obj``; raise ValueError if one is unreachable."""
            vals = []
            for col in order_spec.cols:
                keyname = getattr(col, "key", None) or getattr(col, "name", None)
                relname = alias_to_rel.get(getattr(col, "table", None))
                # A column from a joined alias must be read off the related object
                # first: a same-named root attribute (e.g. ``name``) would otherwise
                # shadow it and produce a wrong cursor.
                if relname and keyname:
                    relobj = getattr(obj, relname, None)
                    if relobj is not None and hasattr(relobj, keyname):
                        vals.append(getattr(relobj, keyname))
                        continue
                if keyname and hasattr(obj, keyname):
                    vals.append(getattr(obj, keyname))
                    continue
                raise ValueError("unpluckable")
            return vals

        first_key = last_key = None
        if items:
            try:
                # selectable of each joined alias -> relationship attribute name,
                # built once per window rather than once per plucked row.
                alias_to_rel = {}
                for _p, rel_attr, target_alias in plan.join_paths:
                    sel = getattr(target_alias, "selectable", None)
                    if sel is not None:
                        alias_to_rel[sel] = rel_attr.key
                first_key = pluck(items[0], alias_to_rel)
                last_key = pluck(items[-1], alias_to_rel)
            except Exception:
                # Best-effort: a window whose rows cannot be keyed has no cursors.
                first_key = last_key = None

        total = None
        if include_total:
            # Distinct root ids under the soft-delete filter and plan filters.
            # Unlike the page query, this joins every non-collection join path
            # (not only first-hop ones) and ignores the cursor/limit.
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            for _b, rel_attr, target_alias in plan.join_paths:
                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            if plan.filters:
                base = base.filter(*plan.filters)
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50)
        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    def get(self, id: int, params=None) -> T | None:
        """Return the object with this ``id`` (subject to soft-delete and plan filters), or None."""
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)

        plan = self._plan(params, root_alias)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        if plan.filters:
            query = query.filter(*plan.filters)
        query = query.filter(getattr(root_alias, "id") == id)
        query = self._apply_proj_opts(query, plan)

        obj = query.first()

        proj = self._projection_meta(plan)
        if obj:
            self._tag_projection(obj, proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return obj or None

    def list(self, params=None) -> list[T]:
        """Return all matching rows with offset/limit pagination.

        When paginating without an explicit sort, primary-key ordering is added
        if the backend requires ORDER BY for OFFSET. A limit is applied only
        when it is positive.
        """
        query, root_alias = self.get_query()
        plan = self._plan(params, root_alias)
        query = self._apply_not_deleted(query, root_alias, params)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        if plan.filters:
            query = query.filter(*plan.filters)

        order_by = plan.order_by
        paginating = (plan.limit is not None) or (plan.offset not in (None, 0))
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        order_by = _dedupe_order_by(order_by)
        if order_by:
            query = query.order_by(*order_by)
        if plan.offset:
            query = query.offset(plan.offset)
        if plan.limit and plan.limit > 0:
            query = query.limit(plan.limit)

        query = self._apply_proj_opts(query, plan)
        rows = query.all()

        proj = self._projection_meta(plan)
        self._tag_projection(rows, proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return rows

    # ---- write ops

    def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
        """Instantiate ``self.model(**data)``, persist it and record a "create" version.

        The object is flushed (so it has an id) and, when ``commit`` is true,
        committed before the version row is written. Version logging is
        best-effort, so committing first means a logging failure can no longer
        roll back the new object while it is still returned as created.
        """
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.flush()
        if commit:
            session.commit()
        self._log_version("create", obj, actor, commit=commit, session=session)
        return obj

    def update(self, id: int, data: dict, actor=None, *, commit: bool = True) -> T:
        """Apply ``data`` to the object with primary key ``id`` and record an "update" version.

        ``data`` is normalized, filtered to column keys, merged over the
        current ``as_dict()`` state and diffed (ignoring id/created_at/
        updated_at). Returns early without committing or logging when the patch
        is empty or no patched attribute actually changed. When ``commit`` is
        true the session is committed before the version is logged. A version
        is logged only when the before/after snapshots differ.

        Raises:
            ValueError: if no object with that id exists.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        before = obj.as_dict()
        norm = normalize_payload(data, self.model)
        incoming = filter_to_columns(norm, self.model)
        desired = {**before, **incoming}
        proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        patch = diff_to_patch(proposed)
        if not patch:
            return obj

        for k, v in patch.items():
            setattr(obj, k, v)

        state = inspect(obj)
        dirty = any(state.attrs[k].history.has_changes() for k in patch)
        if not dirty:
            return obj

        if commit:
            session.commit()

        after = obj.as_dict()
        actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        if not (actual["added"] or actual["removed"] or actual["changed"]):
            return obj

        self._log_version(
            "update", obj, actor,
            metadata={"diff": actual, "patch": patch},
            commit=commit,
            session=session,
        )
        return obj

    def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True):
        """Delete the object with this ``id`` and record a "delete" version.

        Deletion is a hard delete when ``hard`` is true or the model has no
        soft-delete flag; otherwise ``is_deleted`` is set. Returns the object,
        or None if it does not exist.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            cast(_SoftDeletable, obj).is_deleted = True
        if commit:
            session.commit()
        self._log_version("delete", obj, actor, commit=commit, session=session)
        return obj

    # ---- audit

    def _log_version(
        self,
        change_type: str,
        obj: T,
        actor=None,
        metadata: dict | None = None,
        *,
        commit: bool = True,
        session: Session | None = None,
    ):
        """Best-effort: add a ``Version`` row snapshotting ``obj``.

        Args:
            change_type: Label stored on the version ("create", "update", "delete").
            obj: Changed object. Its ``as_dict()`` output is the snapshot
                (``{"error": "serialize failed"}`` if that raises).
            actor: Stored as ``str(actor)``; falsy actors are stored as None.
            metadata: Optional extra data (e.g. a diff), stored via ``to_jsonable``.
            commit: Commit the session after adding the version.
            session: Session to write to; defaults to ``self.session``. Write
                ops pass their own session so the version lands beside the
                change even if ``self.session`` would return a different one.

        Exceptions are logged as warnings, not raised. The session is rolled
        back only when a commit was attempted here, since a failed commit
        leaves the session needing a rollback. With ``commit=False`` a rollback
        would mainly discard the caller's still-uncommitted changes.
        """
        session = session if session is not None else self.session
        object_id: Any = "?"
        try:
            # Captured up front so the failure log never has to touch ``obj``
            # again on a possibly broken session.
            object_id = obj.id
            try:
                snapshot = obj.as_dict()
            except Exception:
                snapshot = {"error": "serialize failed"}

            version = Version(
                model_name=self.model.__name__,
                object_id=object_id,
                change_type=change_type,
                data=to_jsonable(snapshot),
                actor=str(actor) if actor else None,
                meta=to_jsonable(metadata) if metadata else None,
            )
            session.add(version)
            if commit:
                session.commit()
        except Exception as e:
            log.warning(
                "Version logging failed for %s id=%s: %s",
                self.model.__name__, object_id, e,
            )
            if commit:
                session.rollback()