from __future__ import annotations from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, RelationshipProperty from sqlalchemy.orm.attributes import InstrumentedAttribute from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info from crudkit.projection import compile_projection import logging log = logging.getLogger("crudkit.service") def _is_rel(model_cls, name: str) -> bool: try: prop = model_cls.__mapper__.relationships.get(name) return isinstance(prop, RelationshipProperty) except Exception: return False @runtime_checkable class _HasID(Protocol): id: int @runtime_checkable class _HasTable(Protocol): __table__: Any @runtime_checkable class _HasADict(Protocol): def as_dict(self) -> dict: ... @runtime_checkable class _SoftDeletable(Protocol): is_deleted: bool class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): """Minimal surface that our CRUD service relies on. Soft-delete is optional.""" pass T = TypeVar("T", bound=_CRUDModelProto) def _is_truthy(val): return str(val).lower() in ('1', 'true', 'yes', 'on') class CRUDService(Generic[T]): def __init__( self, model: Type[T], session_factory: Callable[[], Session], polymorphic: bool = False, *, backend: Optional[BackendInfo] = None ): self.model = model self._session_factory = session_factory self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') # Cache backend info once. If not provided, derive from session bind. bind = session_factory().get_bind() eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) self.backend = backend or make_backend_info(eng) @property def session(self) -> Session: return self._session_factory() def get_query(self): if self.polymorphic: poly = with_polymorphic(self.model, "*") return self.session.query(poly), poly return self.session.query(self.model), self.model def _apply_not_deleted(self, query, root_alias, params) -> Any: if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): return query.filter(getattr(root_alias, "is_deleted") == False) return query def _extract_order_spec(self, root_alias, given_order_by): """ SQLAlchemy 2.x only: Normalize order_by into (cols, desc_flags). Supports plain columns and col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. """ from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression given = self._stable_order_by(root_alias, given_order_by) cols, desc_flags = [], [] for ob in given: # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() elem = getattr(ob, "element", None) col = elem if elem is not None else ob # Detect direction in SA 2.x is_desc = False dir_attr = getattr(ob, "_direction", None) if dir_attr is not None: is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") elif isinstance(ob, UnaryExpression): op = getattr(ob, "operator", None) is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") cols.append(col) desc_flags.append(bool(is_desc)) from crudkit.core.types import OrderSpec return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): if not key_vals: return None conds = [] for i, col in enumerate(spec.cols): # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel: # sent_col = func.coalesce(col, literal("-∞")) sent_col = col ties = [spec.cols[j] == key_vals[j] for j in range(i)] is_desc = spec.desc[i] if not backward: op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) else: op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) conds.append(and_(*ties, op)) return or_(*conds) def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: out = [] for c in spec.cols: # Only simple mapped columns supported for key pluck key = getattr(c, "key", None) or getattr(c, "name", None) if key is None or not hasattr(obj, key): raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.") out.append(getattr(obj, key)) return out def seek_window( self, params: dict | None = None, *, key: list[Any] | None = None, backward: bool = False, include_total: bool = True, ) -> "SeekWindow[T]": """ Transport-agnostic keyset pagination that preserves all the goodies from `list()`: - filters, includes, joins, field projection, eager loading, soft-delete - deterministic ordering (user sort + PK tiebreakers) - forward/backward seek via `key` and `backward` Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. """ session = self.session fields = list((params or {}).get("fields", [])) expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) query, root_alias = self.get_query() if proj_opts: query = query.options(*proj_opts) spec = CRUDSpec(self.model, params or {}, root_alias) filters = spec.parse_filters() order_by = spec.parse_sort() root_fields, rel_field_names, root_field_names = spec.parse_fields() # Soft delete filter # if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): # query = query.filter(getattr(root_alias, "is_deleted") == False) query = self._apply_not_deleted(query, root_alias, params) # Parse filters first if filters: query = query.filter(*filters) # Includes + joins (so relationship fields like brand.name, location.label work) spec.parse_includes() join_paths = tuple(spec.get_join_paths()) for _, relationship_attr, target_alias in spec.get_join_paths(): rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) query = query.join(target, rel_attr.of_type(target), isouter=True) # Fields/projection: load_only for root columns, eager loads for relationships only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) # Order + limit order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper limit, _ = spec.parse_pagination() if limit is None: effective_limit = 50 elif limit == 0: effective_limit = None else: effective_limit = limit # Keyset predicate if key: pred = self._key_predicate(order_spec, key, backward) if pred is not None: query = query.filter(pred) # Apply ordering. For backward, invert SQL order then reverse in-memory for display. if not backward: clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] query = query.order_by(*clauses) if effective_limit is not None: query = query.limit(effective_limit) items = query.all() else: inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] query = query.order_by(*inv_clauses) if effective_limit is not None: query = query.limit(effective_limit) items = list(reversed(query.all())) # Tag projection so your renderer knows what fields were requested if expanded_fields: proj = list(expanded_fields) else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: proj.extend(c.key for c in root_fields) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: proj.append(f"{prefix}.{n}") if proj and "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") if proj: for obj in items: try: setattr(obj, "__crudkit_projection__", tuple(proj)) except Exception: pass # Boundary keys for cursor encoding in the API layer first_key = self._pluck_key(items[0], order_spec) if items else None last_key = self._pluck_key(items[-1], order_spec) if items else None # Optional total that’s safe under JOINs (COUNT DISTINCT ids) total = None if include_total: base = session.query(getattr(root_alias, "id")) base = self._apply_not_deleted(base, root_alias, params) if filters: base = base.filter(*filters) for _, relationship_attr, target_alias in join_paths: # reuse rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) base = base.join(target, rel_attr.of_type(target), isouter=True) total = session.query(func.count()).select_from( base.order_by(None).distinct().subquery() ).scalar() or 0 window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50) if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) from crudkit.core.types import SeekWindow # avoid circulars at module top return SeekWindow( items=items, limit=window_limit_for_body, first_key=first_key, last_key=last_key, order=order_spec, total=total, ) # Helper: default ORDER BY for MSSQL when paginating without explicit order def _default_order_by(self, root_alias): mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) cols = [] for col in mapper.primary_key: try: cols.append(getattr(root_alias, col.key)) except AttributeError: cols.append(col) return cols or [text("1")] def _stable_order_by(self, root_alias, given_order_by): """ Ensure deterministic ordering by appending PK columns as tiebreakers. If no order is provided, fall back to default primary-key order. """ order_by = list(given_order_by or []) if not order_by: return self._default_order_by(root_alias) mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) pk_cols = [] for col in mapper.primary_key: try: pk_cols.append(getattr(root_alias, col.key)) except AttributeError: pk_cols.append(col) return [*order_by, *pk_cols] def get(self, id: int, params=None) -> T | None: query, root_alias = self.get_query() include_deleted = False root_fields = [] root_field_names = {} rel_field_names = {} spec = CRUDSpec(self.model, params or {}, root_alias) if params: if self.supports_soft_delete: include_deleted = _is_truthy(params.get('include_deleted')) if self.supports_soft_delete and not include_deleted: query = query.filter(getattr(root_alias, "is_deleted") == False) query = query.filter(getattr(root_alias, "id") == id) spec.parse_includes() for _, relationship_attr, target_alias in spec.get_join_paths(): rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) query = query.join(target, rel_attr.of_type(target), isouter=True) if params: root_fields, rel_field_names, root_field_names = spec.parse_fields() req_fields = list((params or {}).get("fields", [])) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) if proj_opts: query = query.options(*proj_opts) only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) obj = query.first() if expanded_fields: proj = list(expanded_fields) else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: proj.extend(c.key for c in root_fields) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: proj.append(f"{prefix}.{n}") if proj and "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") if proj and obj is not None: try: setattr(obj, "__crudkit_projection__", tuple(proj)) except Exception: pass if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) return obj or None def list(self, params=None) -> list[T]: query, root_alias = self.get_query() root_fields = [] root_field_names = {} rel_field_names = {} if params: if self.supports_soft_delete: include_deleted = _is_truthy(params.get('include_deleted')) if not include_deleted: query = query.filter(getattr(root_alias, "is_deleted") == False) spec = CRUDSpec(self.model, params or {}, root_alias) filters = spec.parse_filters() order_by = spec.parse_sort() limit, offset = spec.parse_pagination() spec.parse_includes() for _, relationship_attr, target_alias in spec.get_join_paths(): rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) query = query.join(target, rel_attr.of_type(target), isouter=True) if params: root_fields, rel_field_names, root_field_names = spec.parse_fields() only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) if filters: query = query.filter(*filters) # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset). paginating = (limit is not None) or (offset is not None and offset != 0) if paginating and not order_by and self.backend.requires_order_by_for_offset: order_by = self._default_order_by(root_alias) if order_by: query = query.order_by(*order_by) # Only apply offset/limit when not None. if offset is not None and offset != 0: query = query.offset(offset) if limit is not None and limit > 0: query = query.limit(limit) req_fields = list((params or {}).get("fields", [])) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) if proj_opts: query = query.options(*proj_opts) rows = query.all() if expanded_fields: proj = list(expanded_fields) else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: proj.extend(c.key for c in root_fields) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: proj.append(f"{prefix}.{n}") if proj and "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") if proj: for obj in rows: try: setattr(obj, "__crudkit_projection__", tuple(proj)) except Exception: pass if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) return rows def create(self, data: dict, actor=None) -> T: session = self.session obj = self.model(**data) session.add(obj) session.commit() self._log_version("create", obj, actor) return obj def update(self, id: int, data: dict, actor=None) -> T: session = self.session obj = self.get(id) if not obj: raise ValueError(f"{self.model.__name__} with ID {id} not found.") valid_fields = {c.name for c in self.model.__table__.columns} unknown = set(data) - valid_fields if unknown: raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}") for k, v in data.items(): if k in valid_fields: setattr(obj, k, v) session.commit() self._log_version("update", obj, actor) return obj def delete(self, id: int, hard: bool = False, actor = None): session = self.session obj = session.get(self.model, id) if not obj: return None if hard or not self.supports_soft_delete: session.delete(obj) else: soft = cast(_SoftDeletable, obj) soft.is_deleted = True session.commit() self._log_version("delete", obj, actor) return obj def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None): session = self.session try: data = obj.as_dict() except Exception: data = {"error": "Failed to serialize object."} version = Version( model_name=self.model.__name__, object_id=obj.id, change_type=change_type, data=data, actor=str(actor) if actor else None, meta=metadata ) session.add(version) session.commit()