"""CRUD service layer: query planning, projection, keyset pagination, and audit logging."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast

from flask import current_app
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators, visitors
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement

from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
from crudkit.core.base import Version
from crudkit.core.meta import rel_map, column_names_for_model
from crudkit.core.params import is_truthy, normalize_fields_param
from crudkit.core.spec import CRUDSpec, CollPred
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info

import logging

log = logging.getLogger("crudkit.service")

from crudkit.projection import compile_projection

# Debug-logging scaffold, kept for reference while tuning query plans:
# logging.getLogger("crudkit.service").setLevel(logging.DEBUG)
# Ensure our debug actually prints even if the app/root logger is WARNING+
# if not log.handlers:
#     _h = logging.StreamHandler()
#     _h.setLevel(logging.DEBUG)
#     _h.setFormatter(logging.Formatter(
#         "%(asctime)s %(levelname)s %(name)s: %(message)s"
#     ))
#     log.addHandler(_h)
#
#     log.setLevel(logging.DEBUG)
#     log.propagate = False


@runtime_checkable
class _HasID(Protocol):
    """Structural type: object exposing an integer primary key ``id``."""
    id: int


@runtime_checkable
class _HasTable(Protocol):
    """Structural type: SQLAlchemy-mapped object exposing ``__table__``."""
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    """Structural type: object that can serialize itself via ``as_dict()``."""
    def as_dict(self) -> dict: ...
@runtime_checkable
class _SoftDeletable(Protocol):
    """Structural type: model supporting soft deletion via an ``is_deleted`` flag."""
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Surface expected by CRUDService."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


# ---------------------------- utilities ----------------------------

def _collect_tables_from_filters(filters) -> set:
    """Walk filter expression trees and collect every table/selectable they touch.

    Follows only the common SQLAlchemy expression attributes (left/right/
    element/clause/clauses) rather than a generic visitor, and unwraps
    alias chains via ``.element`` so base tables are included too.
    """
    seen = set()
    stack = list(filters or [])
    while stack:
        node = stack.pop()
        tbl = getattr(node, "table", None)
        if tbl is not None:
            # Record the table and every selectable it aliases.
            cur = tbl
            while cur is not None and cur not in seen:
                seen.add(cur)
                cur = getattr(cur, "element", None)
        # follow only the common attributes; no generic visitor
        left = getattr(node, "left", None)
        if left is not None:
            stack.append(left)
        right = getattr(node, "right", None)
        if right is not None:
            stack.append(right)
        elem = getattr(node, "element", None)
        if elem is not None:
            stack.append(elem)
        clause = getattr(node, "clause", None)
        if clause is not None:
            stack.append(clause)
        clauses = getattr(node, "clauses", None)
        if clauses is not None:
            try:
                stack.extend(list(clauses))
            except TypeError:
                pass
    return seen


def _selectable_keys(sel) -> set[str]:
    """
    Return a set of stable string keys for a selectable/alias and its base,
    so we can match when different alias objects are used.
    """
    keys: set[str] = set()
    cur = sel
    while cur is not None:
        k = getattr(cur, "key", None) or getattr(cur, "name", None)
        if isinstance(k, str) and k:
            keys.add(k)
        cur = getattr(cur, "element", None)
    return keys


def _unwrap_ob(ob):
    """Split an ORDER BY element into (bare column, is_descending)."""
    elem = getattr(ob, "element", None)
    col = elem if elem is not None else ob
    d = getattr(ob, "_direction", None)
    if d is not None:
        is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC")
    elif isinstance(ob, UnaryExpression):
        op = getattr(ob, "operator", None)
        is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
    else:
        is_desc = False
    return col, bool(is_desc)


def _order_identity(col: ColumnElement):
    """Stable (table key, column key) identity used to dedupe ORDER BY columns."""
    table = getattr(col, "table", None)
    table_key = getattr(table, "key", None) or id(table)
    col_key = getattr(col, "key", None) or getattr(col, "name", None)
    return (table_key, col_key)


def _dedupe_order_by(order_by):
    """Drop later ORDER BY entries that refer to the same (table, column) identity."""
    if not order_by:
        return []
    seen, out = set(), []
    for ob in order_by:
        col, _ = _unwrap_ob(ob)
        ident = _order_identity(col)
        if ident in seen:
            continue
        seen.add(ident)
        out.append(ob)
    return out


# ---------------------------- CRUD service ----------------------------

class CRUDService(Generic[T]):
    """Generic CRUD service around one mapped model: reads (filtering, projection,
    keyset pagination), writes (create/update/delete with diffing), and version audit."""

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None,
    ):
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        # Models with an `is_deleted` attribute get soft-delete semantics.
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        self._backend: Optional[BackendInfo] = backend

    # ---- infra

    @property
    def session(self) -> Session:
        """Prefer the app-registered session; fall back to the injected factory."""
        try:
            return current_app.extensions["crudkit"]["Session"]
        except Exception:
            return self._session_factory()

    @property
    def backend(self) -> BackendInfo:
        """Lazily derive backend capabilities from the session's bound engine."""
        if self._backend is None:
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            self._backend = make_backend_info(eng)
        return self._backend

    def get_query(self):
        """Return (query, root alias) — the alias is polymorphic when configured."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # ---- common building blocks

    def _apply_soft_delete_criteria_for_children(self, query, plan: "CRUDService._Plan", params):
        """Hide soft-deleted rows of joined/loaded child models (unless include_deleted)."""
        # Skip if caller explicitly asked for deleted
        if is_truthy((params or {}).get("include_deleted")):
            return query
        seen = set()
        for _base_alias, rel_attr, _target_alias in plan.join_paths:
            prop = getattr(rel_attr, "property", None)
            if not prop:
                continue
            target_cls = getattr(prop.mapper, "class_", None)
            if not target_cls or target_cls in seen:
                continue
            seen.add(target_cls)
            # Only apply to models that support soft delete
            if hasattr(target_cls, "is_deleted"):
                query = query.options(
                    with_loader_criteria(
                        target_cls,
                        lambda cls: cls.is_deleted == False,
                        include_aliases=True,
                        propagate_to_loaders=True,
                    )
                )
        return query

    def _order_clauses(self, order_spec, invert: bool = False):
        """Materialize asc()/desc() clauses from an OrderSpec, optionally inverted."""
        clauses = []
        for c, is_desc in zip(order_spec.cols, order_spec.desc):
            d = not is_desc if invert else is_desc
            clauses.append(c.desc() if d else c.asc())
        return clauses

    def _anchor_key_for_page(self, params, per_page: int, page: int):
        """Return the keyset tuple for the last row of the previous page, or None for page 1."""
        if page <= 1:
            return None
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)
        plan = self._plan(params, root_alias)
        # Make sure joins/filters match the real query
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)
        order_spec = self._extract_order_spec(root_alias, plan.order_by)
        # Inner subquery ordered exactly like the real query
        inner = query.order_by(*self._order_clauses(order_spec, invert=False))
        # Label each order-by column with deterministic, unique names
        labeled_cols = []
        for idx, col in enumerate(order_spec.cols):
            base = getattr(col, "key", None) or getattr(col, "name", None)
            name = f"ord_{idx}_{base}" if base else f"ord_{idx}"
            labeled_cols.append(col.label(name))
        subq = inner.with_entities(*labeled_cols).subquery()
        cols_on_subq = [getattr(subq.c, c.key) for c in labeled_cols]
        # Now the outer anchor query orders and offsets on the subquery columns
        anchor_q = (
            self.session
            .query(*cols_on_subq)
            .select_from(subq)
            .order_by(*[
                (c.desc() if is_desc else c.asc())
                for c, is_desc in zip(cols_on_subq, order_spec.desc)
            ])
        )
        # Anchor is the LAST row of the previous page, hence the -1.
        offset = max(0, (page - 1) * per_page - 1)
        row = anchor_q.offset(offset).limit(1).first()
        if not row:
            return None
        return list(row)  # tuple-like -> list for _key_predicate

    def _apply_not_deleted(self, query, root_alias, params):
        """Filter out soft-deleted root rows unless include_deleted was requested."""
        if self.supports_soft_delete and not is_truthy((params or {}).get("include_deleted")):
            return query.filter(getattr(root_alias, "is_deleted") == False)
        return query

    def _default_order_by(self, root_alias):
        """Primary-key ordering as the safe default; constant `1` if no PK resolves."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """Ensure a total, stable ordering: dedupe user ordering and append PK tie-breakers."""
        order_by = list(given_order_by or [])
        if not order_by:
            # Safe default: primary key(s) only. No unlabeled expressions.
            return _dedupe_order_by(self._default_order_by(root_alias))
        # Dedupe what the user gave us, then ensure PK tie-breakers exist
        order_by = _dedupe_order_by(order_by)
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
        for pk in mapper.primary_key:
            try:
                pk_col = getattr(root_alias, pk.key)
            except AttributeError:
                pk_col = pk
            ident = _order_identity(pk_col)
            if ident not in present:
                order_by.append(pk_col.asc())
                present.add(ident)
        return order_by

    def _extract_order_spec(self, root_alias, given_order_by):
        """Normalize ORDER BY elements into an OrderSpec of bare columns + desc flags."""
        given = self._stable_order_by(root_alias, given_order_by)
        cols, desc_flags = [], []
        for ob in given:
            elem = getattr(ob, "element", None)
            col = elem if elem is not None else ob
            is_desc = False
            dir_attr = getattr(ob, "_direction", None)
            if dir_attr is not None:
                is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
            elif isinstance(ob, UnaryExpression):
                op = getattr(ob, "operator", None)
                is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
            cols.append(col)
            desc_flags.append(bool(is_desc))
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the standard keyset ("seek") predicate: OR over prefix-equal rows.

        For each order column i: all earlier columns tie, and column i moves
        strictly in the scan direction (direction flips per-column desc flag
        XOR the backward flag).
        """
        if not key_vals:
            return None
        conds = []
        for i, col in enumerate(spec.cols):
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            is_desc = spec.desc[i]
            op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)

    # ---- planning and application

    @dataclass(slots=True)
    class _Plan:
        """Parsed request parameters: filters, ordering, pagination, field projection,
        join paths, and the tables touched by filters."""
        spec: Any
        filters: Any
        order_by: Any
        limit: Any
        offset: Any
        root_fields: Any
        rel_field_names: Any
        root_field_names: Any
        collection_field_names: Any
        join_paths: Any
        filter_tables: Any
        filter_table_keys: Any
        req_fields: Any
        proj_opts: Any

    def _plan(self, params, root_alias) -> _Plan:
        """Parse request params once into an immutable plan reused by all query builders."""
        req_fields = normalize_fields_param(params)
        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        root_fields, rel_field_names, root_field_names, collection_field_names = (
            spec.parse_fields() if params else ([], {}, {}, {})
        )
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())
        filter_tables = _collect_tables_from_filters(filters)
        fkeys = set()
        # Build projection opts only if there are true scalar columns requested.
        # Bare relationship fields like "owner" should not force root column pruning.
        column_names = set(column_names_for_model(self.model))
        has_scalar_column_tokens = any(
            (("." not in f) and (f in column_names)) for f in (req_fields or [])
        )
        _, proj_opts = (
            compile_projection(self.model, req_fields)
            if (req_fields and has_scalar_column_tokens)
            else ([], [])
        )
        return self._Plan(
            spec=spec,
            filters=filters,
            order_by=order_by,
            limit=limit,
            offset=offset,
            root_fields=root_fields,
            rel_field_names=rel_field_names,
            root_field_names=root_field_names,
            collection_field_names=collection_field_names,
            join_paths=join_paths,
            filter_tables=filter_tables,
            filter_table_keys=fkeys,
            req_fields=req_fields,
            proj_opts=proj_opts,
        )

    def _apply_projection_load_only(self, query, root_alias, plan: _Plan):
        """Restrict loaded root columns to the requested scalar fields, if any."""
        only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)]
        return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query

    def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
        """Attach join/eager-load strategies for first-hop relationships:
        outer joins for to-one rels, selectinload for collections and deep paths."""
        joined_rel_keys: set[str] = set()
        rels = rel_map(self.model)  # {name: RelInfo}

        # 1) Join to-one relationships explicitly requested as bare fields
        #    (names like "owner", "supervisor")
        requested_scalars = set(plan.root_field_names or [])
        for key in requested_scalars:
            info = rels.get(key)
            if info and not info.uselist and key not in joined_rel_keys:
                query = query.join(getattr(root_alias, key), isouter=True)
                joined_rel_keys.add(key)

        # 2) Join to-one relationships from parsed join_paths
        for base_alias, rel_attr, target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            prop = getattr(rel_attr, "property", None)
            if prop and not prop.uselist:
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                joined_rel_keys.add(prop.key)

        # 3) Ensure to-one touched by filters is joined
        if plan.filter_tables:
            for key, info in rels.items():
                if info.uselist or not info.target_cls:
                    continue
                target_tbl = getattr(info.target_cls, "__table__", None)
                if target_tbl is not None and target_tbl in plan.filter_tables and key not in joined_rel_keys:
                    query = query.join(getattr(root_alias, key), isouter=True)
                    joined_rel_keys.add(key)

        # 4) Collections via selectinload, optionally load_only for requested child columns
        for base_alias, rel_attr, _target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            prop = getattr(rel_attr, "property", None)
            if not prop or not prop.uselist:
                continue
            opt = selectinload(rel_attr)
            child_names = (plan.collection_field_names or {}).get(prop.key, [])
            if child_names:
                target_cls = prop.mapper.class_
                cols = [getattr(target_cls, n, None) for n in child_names]
                cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                if cols:
                    opt = opt.load_only(*cols)
            query = query.options(opt)

        # 5) Deep relationship paths requested via fields: chained selectinload
        for path, _names in (plan.rel_field_names or {}).items():
            if not path:
                continue
            # Build a chained selectinload for each relationship segment in the path
            first = path[0]
            info = rels.get(first)
            if not info or info.target_cls is None:
                continue
            # Start with selectinload on the first hop
            opt = selectinload(getattr(root_alias, first))
            # Walk deeper segments
            cur_cls = info.target_cls
            for seg in path[1:]:
                sub = rel_map(cur_cls).get(seg)
                if not sub or sub.target_cls is None:
                    # if segment isn't a relationship, we stop the chain
                    break
                opt = opt.selectinload(getattr(cur_cls, seg))
                cur_cls = sub.target_cls
            query = query.options(opt)

        return query

    def _apply_proj_opts(self, query, plan: _Plan):
        """Apply compiled projection options; degrade gracefully if SQLAlchemy rejects them."""
        if not plan.proj_opts:
            return query
        try:
            return query.options(*plan.proj_opts)
        except KeyError as e:
            # Seen "KeyError: 'col'" when alias-column remapping meets unlabeled exprs.
            log.debug("Projection options disabled due to %r; proceeding without them.", e)
            return query
        except Exception as e:
            log.debug("Projection options failed (%r); proceeding without them.", e)
            return query

    def _projection_meta(self, plan: _Plan):
        """Field-name list describing the effective projection; always includes "id"."""
        if plan.req_fields:
            proj = list(dict.fromkeys(plan.req_fields))
            return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj
        proj: list[str] = []
        if plan.root_field_names:
            proj.extend(plan.root_field_names)
        if plan.root_fields:
            proj.extend(c.key for c in plan.root_fields if hasattr(c, "key"))
        for path, names in (plan.rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return proj

    def _tag_projection(self, items, proj):
        """Best-effort: stamp each instance with the projection used to load it."""
        if not proj:
            return
        for obj in items if isinstance(items, list) else [items]:
            try:
                setattr(obj, "__crudkit_projection__", tuple(proj))
            except Exception:
                pass

    def _rebind_filters_to_firsthop_aliases(self, filters, root_alias, plan):
        """Make filter expressions use the exact same alias objects as our JOINs."""
        if not filters:
            return filters
        # Map first-hop target selectable keysets -> the exact selectable object we JOINed with
        alias_map = {}
        for base_alias, _rel_attr, target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            sel = getattr(target_alias, "selectable", None)
            if sel is not None:
                alias_map[frozenset(_selectable_keys(sel))] = sel
        if not alias_map:
            return filters

        def replace(elem):
            tbl = getattr(elem, "table", None)
            if tbl is None:
                return elem
            keyset = frozenset(_selectable_keys(tbl))
            new_sel = alias_map.get(keyset)
            if new_sel is None or new_sel is tbl:
                return elem
            colkey = getattr(elem, "key", None) or getattr(elem, "name", None)
            if not colkey:
                return elem
            try:
                return getattr(new_sel.c, colkey)
            except Exception:
                return elem

        return [visitors.replacement_traverse(f, {}, replace) for f in filters]

    def _final_filters(self, root_alias, plan):
        """
        Return filters where:
          - root/to-one predicates are kept as SQLAlchemy expressions.
          - first-hop collection predicates (CollPred) are rebuilt into a single
            EXISTS via rel.any(...) with one alias per collection table.
        """
        filters = list(plan.filters or [])
        if not filters:
            return []

        # 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls)
        coll_map = {}
        for base_alias, rel_attr, target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            prop = getattr(rel_attr, "property", None)
            if not prop or not getattr(prop, "uselist", False):
                continue
            target_cls = prop.mapper.class_
            tbl = getattr(target_cls, "__table__", None)
            if tbl is not None:
                coll_map[tbl] = (rel_attr, target_cls)

        # 2) Split raw filters into normal SQLA and CollPreds (by target table)
        normal_filters = []
        by_table: dict[Any, list[CollPred]] = {}
        for f in filters:
            if isinstance(f, CollPred):
                by_table.setdefault(f.table, []).append(f)
            else:
                normal_filters.append(f)

        # 3) Rebuild each table group into ONE .any(...) using one alias
        from sqlalchemy.orm import aliased

        exists_filters = []
        for tbl, preds in by_table.items():
            if tbl not in coll_map:
                # Safety: if it's not a first-hop collection, ignore or raise
                continue
            rel_attr, target_cls = coll_map[tbl]
            ta = aliased(target_cls)
            built = []
            for p in preds:
                col = getattr(ta, p.col_key)
                op = p.op
                val = p.value
                if op == 'icontains':
                    built.append(col.ilike(f"%{val}%"))
                elif op == 'eq':
                    built.append(col == val)
                elif op == 'ne':
                    built.append(col != val)
                elif op == 'in':
                    vs = val if isinstance(val, (list, tuple, set)) else [val]
                    built.append(col.in_(vs))
                elif op == 'nin':
                    vs = val if isinstance(val, (list, tuple, set)) else [val]
                    built.append(~col.in_(vs))
                elif op == 'lt':
                    built.append(col < val)
                elif op == 'lte':
                    built.append(col <= val)
                elif op == 'gt':
                    built.append(col > val)
                elif op == 'gte':
                    built.append(col >= val)
                else:
                    # unknown op — skip or raise
                    continue
            # enforce child soft delete inside the EXISTS
            if hasattr(target_cls, "is_deleted"):
                built.append(ta.is_deleted == False)
            crit = and_(*built) if built else None
            exists_filters.append(
                rel_attr.of_type(ta).any(crit) if crit is not None else rel_attr.of_type(ta).any()
            )

        # 4) Final filter list = normal SQLA filters + all EXISTS filters
        return normal_filters + exists_filters

    # ---- public read ops

    def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True):
        """Page-number pagination implemented on top of keyset windows."""
        # Ensure seek_window uses `per_page`
        params = dict(params or {})
        params["limit"] = per_page
        anchor_key = self._anchor_key_for_page(params, per_page, page)
        win = self.seek_window(params, key=anchor_key, backward=False, include_total=include_total)
        pages = None
        if include_total and win.total is not None and per_page:
            # ceil(total / per_page)
            pages = (win.total + per_page - 1) // per_page
        return {
            "items": win.items,
            "page": page,
            "per_page": per_page,
            "total": win.total,
            "pages": pages,
            "order": [str(c) for c in win.order.cols],
        }

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """Keyset-paginated window of rows anchored at `key` (exclusive)."""
        session = self.session
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)
        plan = self._plan(params, root_alias)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        query = self._apply_soft_delete_criteria_for_children(query, plan, params)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)
        order_spec = self._extract_order_spec(root_alias, plan.order_by)
        # limit=None means "no plan limit given" -> default 50; explicit 0 means unlimited.
        limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)
        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)
        clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)]
        if backward:
            # Scan in reverse order; rows are un-reversed below.
            clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*clauses)
        if limit is not None:
            query = query.limit(limit)
        query = self._apply_proj_opts(query, plan)
        rows = query.all()
        items = list(reversed(rows)) if backward else rows
        proj = self._projection_meta(plan)
        self._tag_projection(items, proj)

        # cursor keys: map joined alias selectables -> relationship names once,
        # then pluck ordering values off each boundary row.
        alias_to_rel = {}
        for _p, rel_attr, target_alias in plan.join_paths:
            sel = getattr(target_alias, "selectable", None)
            if sel is not None:
                alias_to_rel[sel] = rel_attr.key

        def pluck(obj):
            vals = []
            for col in order_spec.cols:
                keyname = getattr(col, "key", None) or getattr(col, "name", None)
                if keyname and hasattr(obj, keyname):
                    vals.append(getattr(obj, keyname))
                    continue
                table = getattr(col, "table", None)
                relname = alias_to_rel.get(table)
                if relname and keyname:
                    relobj = getattr(obj, relname, None)
                    if relobj is not None and hasattr(relobj, keyname):
                        vals.append(getattr(relobj, keyname))
                        continue
                raise ValueError("unpluckable")
            return vals

        try:
            first_key = pluck(items[0]) if items else None
            last_key = pluck(items[-1]) if items else None
        except Exception:
            first_key = last_key = None

        total = None
        if include_total:
            # Count over a distinct id subquery that repeats the to-one joins + filters.
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            for _b, rel_attr, target_alias in plan.join_paths:
                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            if plan.filters:
                filters = self._final_filters(root_alias, plan)
                if filters:
                    base = base.filter(*filters)  # <-- use base, not query
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50)
        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    def get(self, id: int, params=None) -> T | None:
        """Fetch one row by id, honoring soft delete, filters, and projection."""
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)
        plan = self._plan(params, root_alias)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        query = self._apply_soft_delete_criteria_for_children(query, plan, params)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)
        query = query.filter(getattr(root_alias, "id") == id)
        query = self._apply_proj_opts(query, plan)
        obj = query.first()
        proj = self._projection_meta(plan)
        if obj:
            self._tag_projection(obj, proj)
        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return obj or None

    def list(self, params=None) -> list[T]:
        """Offset/limit listing with filters, ordering, projection, and a default cap."""
        query, root_alias = self.get_query()
        plan = self._plan(params, root_alias)
        query = self._apply_not_deleted(query, root_alias, params)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        query = self._apply_soft_delete_criteria_for_children(query, plan, params)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)
        order_by = plan.order_by
        paginating = (plan.limit is not None) or (plan.offset not in (None, 0))
        # Some backends (e.g. SQL Server) require ORDER BY whenever OFFSET is used.
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)
        order_by = _dedupe_order_by(order_by)
        if order_by:
            query = query.order_by(*order_by)
        # BUGFIX: Flask's Config is a dict — getattr() never saw the key and
        # always returned 200, ignoring the configured cap.
        default_cap = current_app.config.get("CRUDKIT_DEFAULT_LIST_LIMIT", 200)
        if plan.offset:
            query = query.offset(plan.offset)
        if plan.limit and plan.limit > 0:
            query = query.limit(plan.limit)
        elif plan.limit is None and default_cap:
            query = query.limit(default_cap)
        query = self._apply_proj_opts(query, plan)
        rows = query.all()
        proj = self._projection_meta(plan)
        self._tag_projection(rows, proj)
        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))
        return rows

    # ---- write ops

    def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
        """Insert a new row, audit it, and (optionally) commit."""
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.flush()
        self._log_version("create", obj, actor, commit=commit)
        if commit:
            session.commit()
        return obj

    def update(self, id: int, data: dict, actor=None, *, commit: bool = True) -> T:
        """Apply a diffed patch to a row; no-op (and no audit) when nothing changes.

        Raises ValueError if the row does not exist.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        before = obj.as_dict()
        norm = normalize_payload(data, self.model)
        incoming = filter_to_columns(norm, self.model)
        desired = {**before, **incoming}
        proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        patch = diff_to_patch(proposed)
        if not patch:
            return obj
        for k, v in patch.items():
            setattr(obj, k, v)
        dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
        if not dirty:
            return obj
        if commit:
            session.commit()
        # Re-diff after flush/commit so the audit records what actually changed.
        after = obj.as_dict()
        actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        if not (actual["added"] or actual["removed"] or actual["changed"]):
            return obj
        self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True):
        """Soft-delete (default, when supported) or hard-delete a row; returns it or None."""
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            cast(_SoftDeletable, obj).is_deleted = True
        if commit:
            session.commit()
        self._log_version("delete", obj, actor, commit=commit)
        return obj

    # ---- audit

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True):
        """Best-effort audit record; failures are logged and rolled back, never raised."""
        session = self.session
        try:
            try:
                snapshot = obj.as_dict()
            except Exception:
                snapshot = {"error": "serialize failed"}
            version = Version(
                model_name=self.model.__name__,
                object_id=obj.id,
                change_type=change_type,
                data=to_jsonable(snapshot),
                actor=str(actor) if actor else None,
                meta=to_jsonable(metadata) if metadata else None,
            )
            session.add(version)
            if commit:
                session.commit()
        except Exception as e:
            log.warning(
                "Version logging failed for %s id=%s: %s",
                self.model.__name__, getattr(obj, "id", "?"), e,
            )
            session.rollback()