from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast

from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression

from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info


@runtime_checkable
class _HasID(Protocol):
    id: int


@runtime_checkable
class _HasTable(Protocol):
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


def _is_truthy(val):
    """Return True when ``str(val)`` is one of '1', 'true', 'yes', 'on' (case-insensitive)."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


class CRUDService(Generic[T]):
    """
    Generic CRUD service for a single model.

    Queries are built from a request-style ``params`` mapping that is handed to
    ``CRUDSpec`` (filters, sort, pagination, includes, fields). Models exposing an
    ``is_deleted`` attribute are soft-deleted by default, and every write is
    recorded as a ``Version`` row.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """
        Args:
            model: Mapped class this service operates on.
            session_factory: Zero-argument callable returning the Session to use.
            polymorphic: When True, queries go through ``with_polymorphic(model, "*")``.
            backend: Pre-computed backend info. When omitted it is derived from
                the session's bind.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')

        # Cache backend info once. Only touch the session when the caller did not
        # supply it, so a pre-configured service never needs a bound session here.
        if backend is None:
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            backend = make_backend_info(eng)
        self.backend = backend

    @property
    def session(self) -> Session:
        """Session returned by the factory; the factory is invoked on every access."""
        return self._session_factory()

    def get_query(self):
        """Return ``(query, root_alias)``; the alias is polymorphic when so configured."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # ------------------------------------------------------------------
    # Shared query-building helpers
    # ------------------------------------------------------------------

    def _apply_soft_delete_filter(self, query, root_alias, params):
        """Hide soft-deleted rows unless ``params['include_deleted']`` is truthy."""
        if not self.supports_soft_delete:
            return query
        if _is_truthy((params or {}).get("include_deleted")):
            return query
        # "== False" builds a SQL comparison; it is not a Python truthiness test.
        return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

    def _apply_join_paths(self, query, spec):
        """Outer-join every relationship path reported by ``spec.get_join_paths()``."""
        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            target = cast(Any, target_alias)
            query = query.join(target, rel_attr.of_type(target), isouter=True)
        return query

    def _apply_field_options(self, query, spec, root_alias):
        """
        Apply ``load_only`` for requested root columns plus the spec's eager loads.

        Returns ``(query, root_fields, rel_field_names, root_field_names)`` exactly
        as produced by ``spec.parse_fields()``.
        """
        root_fields, rel_field_names, root_field_names = spec.parse_fields()
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
            query = query.options(eager)
        return query, root_fields, rel_field_names, root_field_names

    def _projection_names(self, root_fields, root_field_names, rel_field_names):
        """
        Build the tuple of requested field names (dotted for relationship fields).

        ``id`` is put first whenever a projection was requested, the model has an
        ``id`` attribute and it was not asked for explicitly. Returns an empty tuple
        when no fields were requested.
        """
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    @staticmethod
    def _tag_projection(objs, projection):
        """Best-effort: store ``projection`` on each object as ``__crudkit_projection__``."""
        if not projection:
            return
        for obj in objs:
            try:
                setattr(obj, "__crudkit_projection__", projection)
            except (AttributeError, TypeError):
                # Objects that reject new attributes are simply left untagged.
                pass

    def _pk_columns(self, root_alias):
        """Primary-key columns, taken from ``root_alias`` when it exposes them."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols

    # ------------------------------------------------------------------
    # Keyset pagination
    # ------------------------------------------------------------------

    def _extract_order_spec(self, root_alias, given_order_by):
        """
        SQLAlchemy 2.x only: Normalize order_by into (cols, desc_flags).
        Supports plain columns and col.asc()/col.desc() (UnaryExpression).
        Avoids boolean evaluation of clauses.
        """
        def _is_desc_marker(marker):
            if marker is None:
                return False
            return (marker is operators.desc_op) or (getattr(marker, "name", "").upper() == "DESC")

        given = self._stable_order_by(root_alias, given_order_by)

        cols, desc_flags = [], []
        for ob in given:
            # Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
            elem = getattr(ob, "element", None)
            col = elem if elem is not None else ob  # don't use "or" with SA expressions

            # Detect direction. .asc()/.desc() are recorded in UnaryExpression.modifier
            # (the right-hand side of the expression); `operator` is the prefix slot and
            # `_direction` is kept for objects that expose it.
            dir_attr = getattr(ob, "_direction", None)
            if dir_attr is not None:
                is_desc = _is_desc_marker(dir_attr)
            elif isinstance(ob, UnaryExpression):
                is_desc = (
                    _is_desc_marker(getattr(ob, "modifier", None))
                    or _is_desc_marker(getattr(ob, "operator", None))
                )
            else:
                is_desc = False

            cols.append(col)
            desc_flags.append(bool(is_desc))

        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """
        Build lexicographic predicate for keyset seek.
        For backward traversal, invert comparisons.

        ``key_vals`` is indexed once per ordering column, so it must hold at least
        ``len(spec.cols)`` values. Returns None when ``key_vals`` is empty.
        """
        if not key_vals:
            return None

        conds = []
        for i, col in enumerate(spec.cols):
            # All earlier columns tie with the key; column i is strictly past it.
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            is_desc = spec.desc[i]
            if not backward:
                op = col < key_vals[i] if is_desc else col > key_vals[i]
            else:
                op = col > key_vals[i] if is_desc else col < key_vals[i]
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
        """Read from ``obj`` the attribute named by each ordering column's key/name."""
        out = []
        for c in spec.cols:
            key = getattr(c, "key", None) or getattr(c, "name", None)
            out.append(getattr(obj, key))
        return out

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """
        Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
        - filters, includes, joins, field projection, eager loading, soft-delete
        - deterministic ordering (user sort + PK tiebreakers)
        - forward/backward seek via `key` and `backward`
        Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
        """
        params = params or {}
        query, root_alias = self.get_query()

        spec = CRUDSpec(self.model, params, root_alias)

        # Soft delete filter
        query = self._apply_soft_delete_filter(query, root_alias, params)

        # Parse filters first
        filters = spec.parse_filters()
        if filters:
            query = query.filter(*filters)

        # Includes + joins (so relationship fields like brand.name, location.label work)
        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        # Fields/projection: load_only for root columns, eager loads for relationships
        query, root_fields, rel_field_names, root_field_names = self._apply_field_options(
            query, spec, root_alias
        )

        # Order + limit
        order_by = spec.parse_sort()
        order_spec = self._extract_order_spec(root_alias, order_by)
        limit, _ = spec.parse_pagination()
        if not limit or limit <= 0:
            limit = 50  # sensible default

        # Keyset predicate
        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Apply ordering. For backward, invert SQL order then reverse in-memory for display.
        clauses = []
        for col, is_desc in zip(order_spec.cols, order_spec.desc):
            descending = (not is_desc) if backward else is_desc
            clauses.append(col.desc() if descending else col.asc())
        items = query.order_by(*clauses).limit(limit).all()
        if backward:
            items.reverse()

        # Tag projection so the renderer knows what fields were requested
        self._tag_projection(
            items, self._projection_names(root_fields, root_field_names, rel_field_names)
        )

        # Boundary keys for cursor encoding in the API layer
        first_key = self._pluck_key(items[0], order_spec) if items else None
        last_key = self._pluck_key(items[-1], order_spec) if items else None

        # Optional total that is safe under JOINs (COUNT over DISTINCT ids)
        total = None
        if include_total:
            base = self.session.query(getattr(root_alias, "id"))
            base = self._apply_soft_delete_filter(base, root_alias, params)
            if filters:
                base = base.filter(*filters)
            # replicate the same joins used above
            base = self._apply_join_paths(base, spec)
            total = self.session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        return SeekWindow(
            items=items,
            limit=limit,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Primary-key columns as ORDER BY, or ``text("1")`` when the mapper has none."""
        return self._pk_columns(root_alias) or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """
        Ensure deterministic ordering by appending PK columns as tiebreakers.
        If no order is provided, fall back to default primary-key order.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return self._default_order_by(root_alias)
        return [*order_by, *self._pk_columns(root_alias)]

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def get(self, id: int, params=None) -> T | None:
        """
        Fetch one row by ``id``, or None when no row matches.

        Soft-deleted rows are hidden unless ``params['include_deleted']`` is truthy.
        Includes are joined; field projection and eager loads are only applied
        when ``params`` is non-empty.
        """
        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params or {}, root_alias)

        query = self._apply_soft_delete_filter(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)

        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        root_fields, rel_field_names, root_field_names = [], {}, {}
        if params:
            query, root_fields, rel_field_names, root_field_names = self._apply_field_options(
                query, spec, root_alias
            )

        obj = query.first()
        if obj is None:
            return None

        self._tag_projection(
            [obj], self._projection_names(root_fields, root_field_names, rel_field_names)
        )
        return obj

    def list(self, params=None) -> list[T]:
        """
        Return all rows matching ``params`` (filters, sort, limit/offset, includes, fields).

        Soft-deleted rows are hidden unless ``params['include_deleted']`` is truthy;
        this also holds when ``params`` is omitted, matching ``get`` and ``seek_window``.
        """
        query, root_alias = self.get_query()
        query = self._apply_soft_delete_filter(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()

        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        root_fields, rel_field_names, root_field_names = [], {}, {}
        if params:
            query, root_fields, rel_field_names, root_field_names = self._apply_field_options(
                query, spec, root_alias
            )

        if filters:
            query = query.filter(*filters)

        # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
        paginating = (limit is not None) or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        if order_by:
            query = query.order_by(*order_by)

        # Only apply offset/limit when not None.
        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()
        self._tag_projection(
            rows, self._projection_names(root_fields, root_field_names, rel_field_names)
        )
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from ``data``, commit it, and log a "create" version."""
        obj = self.model(**data)
        session = self.session
        session.add(obj)
        session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """
        Assign the entries of ``data`` that name table columns, commit, and log an "update".

        Keys that are not column names are ignored.

        Raises:
            ValueError: If ``get(id)`` finds no row.
        """
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        self.session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """
        Delete the row with ``id`` and log a "delete" version.

        The row is flagged ``is_deleted = True`` when the model supports soft delete
        and ``hard`` is False; otherwise it is removed from the session.
        Returns the affected object, or None when ``id`` does not exist.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            soft = cast(_SoftDeletable, obj)
            soft.is_deleted = True
        session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None):
        """
        Persist a ``Version`` row describing ``change_type`` applied to ``obj``.

        The snapshot comes from ``obj.as_dict()``; if that raises, an error
        placeholder is stored instead so the write itself is still recorded.
        """
        try:
            data = obj.as_dict()
        except Exception:
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            # Fresh dict per call; a shared default would leak state between versions.
            meta=metadata if metadata is not None else {},
        )
        session = self.session
        session.add(version)
        session.commit()