from __future__ import annotations

import logging
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast

from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement

from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection

log = logging.getLogger("crudkit.service")

# Page size used by seek_window() when the request carries no limit.
_DEFAULT_SEEK_LIMIT = 50


@runtime_checkable
class _HasID(Protocol):
    id: int


@runtime_checkable
class _HasTable(Protocol):
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


def _belongs_to_alias(col: Any, alias: Any) -> bool:
    """Return True when ``col.table`` is (by identity) ``alias.selectable``.

    Best-effort detection of whether a column/expression comes from this alias.
    Works for most ORM columns; complex expressions may need more.
    """
    t = getattr(col, "table", None)
    selectable = getattr(alias, "selectable", None)
    return t is not None and selectable is not None and t is selectable


def _iter_columns(expr: Any) -> Iterator[Any]:
    """Yield every ``ColumnElement`` found in an expression tree (best effort).

    ColumnElements are yielded themselves and then walked via ``get_children()``;
    other clause lists are walked via their ``clauses`` attribute.
    """
    if isinstance(expr, ColumnElement):
        yield expr
        for child in getattr(expr, "get_children", lambda: [])():
            yield from _iter_columns(child)
    elif hasattr(expr, "clauses"):
        for child in expr.clauses:
            yield from _iter_columns(child)


def _alias_is_referenced(target_alias: Any, order_by: Iterable[Any] | None, filters: Iterable[Any] | None) -> bool:
    """Return True if any ORDER BY column or any column inside a filter belongs to ``target_alias``."""
    for ob in order_by or []:
        col = getattr(ob, "element", ob)  # unwrap UnaryExpression from .asc()/.desc()
        if _belongs_to_alias(col, target_alias):
            return True
    for flt in filters or []:
        if any(_belongs_to_alias(c, target_alias) for c in _iter_columns(flt)):
            return True
    return False


def _paths_needed_for_sql(order_by: Iterable[Any] | None, filters: Iterable[Any] | None, join_paths: tuple) -> set[tuple[str, ...]]:
    """Return the join paths whose alias is referenced by the sort or the filters.

    ``join_paths`` is an iterable of ``(path, relationship_attr, target_alias)``.
    Those paths must be SQL-joined, because WHERE / ORDER BY refer to the alias.
    """
    return {
        tuple(path)
        for path, _rel_attr, target_alias in join_paths
        if _alias_is_referenced(target_alias, order_by, filters)
    }


def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
    """Return relationship paths implied by dotted field names.

    ``"owner.address.city"`` yields ``("owner", "address")``; undotted names yield nothing.
    """
    out: set[tuple[str, ...]] = set()
    for f in req_fields:
        if "." in f:
            parts = tuple(f.split(".")[:-1])
            if parts:
                out.add(parts)
    return out


def _is_truthy(val):
    """Interpret ``'1'``, ``'true'``, ``'yes'``, ``'on'`` (case-insensitive) as True."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


def _normalize_fields_param(params: dict | None) -> list[str]:
    """Return the ``fields`` request parameter as a flat list of non-empty names.

    Accepts a comma-separated string or a list/tuple of such strings; any other
    value (or missing/empty params) yields ``[]``.
    """
    if not params:
        return []
    raw = params.get("fields")
    if isinstance(raw, (list, tuple)):
        out: list[str] = []
        for item in raw:
            if isinstance(item, str):
                out.extend([p for p in (s.strip() for s in item.split(",")) if p])
        return out
    if isinstance(raw, str):
        return [p for p in (s.strip() for s in raw.split(",")) if p]
    return []


class CRUDService(Generic[T]):
    """Generic CRUD operations for one mapped model.

    Reads (``get``, ``list``, ``seek_window``) honour soft-delete, filters, includes,
    field projection and relationship loading parsed by ``CRUDSpec``; writes
    (``create``, ``update``, ``delete``) commit and record a ``Version`` row.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')

        # Cache backend info once. Only derive it from a session's bind when the
        # caller did not supply one, so no session is opened needlessly.
        if backend is None:
            bind = session_factory().get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            backend = make_backend_info(eng)
        self.backend = backend

    @property
    def session(self) -> Session:
        """Return ``self._session_factory()`` — the factory is invoked on every access."""
        return self._session_factory()

    def get_query(self):
        """Return ``(query, root_alias)``; the root is ``with_polymorphic(model, "*")`` when polymorphic."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    def _apply_not_deleted(self, query, root_alias, params) -> Any:
        """Exclude soft-deleted rows unless ``params["include_deleted"]`` is truthy."""
        if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
            return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712 - SQL expression
        return query

    # ------------------------------------------------------------------ ordering

    @staticmethod
    def _is_desc(ob: Any) -> bool:
        """Return True if an ORDER BY item sorts descending.

        ``col.desc()`` builds a ``UnaryExpression`` whose ``modifier`` is
        ``operators.desc_op``; ``operator`` is also checked for robustness.
        Never evaluates a clause in boolean context.
        """
        candidates = [getattr(ob, "_direction", None)]
        if isinstance(ob, UnaryExpression):
            candidates.append(getattr(ob, "modifier", None))
            candidates.append(getattr(ob, "operator", None))
        for cand in candidates:
            if cand is None:
                continue
            if cand is operators.desc_op or str(getattr(cand, "name", "")).upper() == "DESC":
                return True
        return False

    def _extract_order_spec(self, root_alias, given_order_by):
        """
        SQLAlchemy 2.x only: Normalize order_by into (cols, desc_flags).
        Supports plain columns and col.asc()/col.desc() (UnaryExpression).
        PK tiebreakers are appended via ``_stable_order_by``.
        """
        cols, desc_flags = [], []
        for ob in self._stable_order_by(root_alias, given_order_by):
            # Unwrap the column if this is a UnaryExpression produced by .asc()/.desc()
            elem = getattr(ob, "element", None)
            cols.append(elem if elem is not None else ob)
            desc_flags.append(self._is_desc(ob))
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the keyset WHERE clause selecting rows strictly after (or before) ``key_vals``.

        Produces ``(c0 op k0) OR (c0 = k0 AND c1 op k1) OR ...``.

        Returns None when ``key_vals`` is empty.

        Raises:
            ValueError: if the cursor length does not match the ordering columns
                (e.g. a cursor minted under a different sort).
        """
        if not key_vals:
            return None
        if len(key_vals) != len(spec.cols):
            raise ValueError(
                f"Cursor key has {len(key_vals)} values but the ordering has {len(spec.cols)} columns."
            )
        conds = []
        for i, (col, is_desc) in enumerate(zip(spec.cols, spec.desc)):
            # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
            # sent_col = func.coalesce(col, literal("-∞"))
            sent_col = col
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            # Forward over DESC, or backward over ASC, walks toward smaller values.
            op = (sent_col < key_vals[i]) if is_desc != backward else (sent_col > key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
        """Read the ordering-column values from a loaded row to form a cursor key.

        Raises:
            ValueError: if an ordering column has no matching attribute on ``obj``.
        """
        out = []
        for c in spec.cols:
            # Only simple mapped columns supported for key pluck
            key = getattr(c, "key", None) or getattr(c, "name", None)
            if key is None or not hasattr(obj, key):
                raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
            out.append(getattr(obj, key))
        return out

    def _pk_columns(self, root_alias) -> list[Any]:
        """Return the model's primary-key columns, resolved against ``root_alias`` when possible."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Return primary-key ordering, or ``text("1")`` if the mapper exposes no PK."""
        return self._pk_columns(root_alias) or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """
        Ensure deterministic ordering by appending PK columns as tiebreakers.
        If no order is provided, fall back to default primary-key order.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return self._default_order_by(root_alias)
        return [*order_by, *self._pk_columns(root_alias)]

    # ------------------------------------------------------ loading / projection

    @staticmethod
    def _apply_root_load_only(query, root_alias, root_fields):
        """Restrict root-entity columns with ``load_only`` for the InstrumentedAttributes in ``root_fields``."""
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        return query

    @staticmethod
    def _apply_relationship_loading(query, join_paths, sql_paths, proj_paths):
        """Choose a loader per join path, avoiding loader-strategy conflicts.

        Paths in ``sql_paths`` are outer-joined and hydrated with ``contains_eager``;
        paths only in ``proj_paths`` get ``selectinload``; others are left alone.

        Returns ``(query, used_contains_eager)``.
        """
        used_contains_eager = False
        for path, relationship_attr, target_alias in join_paths:
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            ptuple = tuple(path)
            if ptuple in sql_paths:
                # Needed for WHERE/ORDER BY: join + hydrate from the join
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr, alias=target_alias))
                used_contains_eager = True
            elif ptuple in proj_paths:
                # Display-only: no join, bulk-load efficiently
                query = query.options(selectinload(rel_attr))
        return query, used_contains_eager

    def _projection_for(self, req_fields, root_fields, root_field_names, rel_field_names) -> list[str]:
        """Return the field names to tag on results.

        Exactly what the client requested (deduped, order kept, ``id`` prepended
        when the model has it), or else a fallback built from the parsed root
        fields and ``prefix.name`` relationship fields.
        """
        has_id = hasattr(self.model, "id")
        if req_fields:
            proj = list(dict.fromkeys(req_fields))  # dedupe, preserve order
            if "id" not in proj and has_id:
                proj.insert(0, "id")
            return proj

        proj: list[str] = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields if hasattr(c, "key"))
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and has_id:
            proj.insert(0, "id")
        return proj

    @staticmethod
    def _tag_projection(objs: Iterable[Any], proj: list[str]) -> None:
        """Best-effort: set ``__crudkit_projection__`` on each object so renderers know the requested fields."""
        if not proj:
            return
        tagged = tuple(proj)
        for obj in objs:
            try:
                setattr(obj, "__crudkit_projection__", tagged)
            except Exception:
                # Deliberately non-fatal; some objects reject ad-hoc attributes.
                log.debug("Could not tag %s with projection", type(obj).__name__, exc_info=True)

    # --------------------------------------------------------------------- reads

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """
        Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
          - filters, includes, joins, field projection, eager loading, soft-delete
          - deterministic ordering (user sort + PK tiebreakers)
          - forward/backward seek via `key` and `backward`

        A missing limit means 50 rows; a limit of 0 means unlimited.

        Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.

        Raises:
            ValueError: if ``key`` does not match the ordering, or the ordering
                includes a column that cannot be read back from a row.
        """
        session = self.session
        query, root_alias = self.get_query()

        # Normalize requested fields and compile projection (may skip later to avoid conflicts)
        fields = _normalize_fields_param(params)
        _expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()

        # Field parsing for root load_only fallback
        root_fields, rel_field_names, root_field_names = spec.parse_fields()

        query = self._apply_not_deleted(query, root_alias, params)
        if filters:
            query = query.filter(*filters)

        # Includes + join paths: iterable of (path, relationship_attr, target_alias)
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())

        # Relationship *names* required by SQL (sort/filter) vs display-only (dotted fields)
        sql_hops: set[str] = {
            relationship_attr.key
            for _path, relationship_attr, target_alias in join_paths
            if _alias_is_referenced(target_alias, order_by, filters)
        }
        proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}

        query = self._apply_root_load_only(query, root_alias, root_fields)

        # Relationship handling per path (avoid loader strategy conflicts)
        used_contains_eager = False
        for _path, relationship_attr, target_alias in join_paths:
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            name = relationship_attr.key
            if name in sql_hops:
                # Needed for WHERE/ORDER BY: join + hydrate from that join
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr, alias=target_alias))
                used_contains_eager = True
            elif name in proj_hops:
                # Display-only: bulk-load efficiently, no join
                query = query.options(selectinload(rel_attr))

        # Apply projection loader options only if they won't conflict with contains_eager
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)

        order_spec = self._extract_order_spec(root_alias, order_by)
        limit, _ = spec.parse_pagination()
        if limit is None:
            effective_limit = _DEFAULT_SEEK_LIMIT
        elif limit == 0:
            effective_limit = None  # unlimited
        else:
            effective_limit = limit

        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # For backward, invert the SQL order, then reverse in memory for display.
        directions = tuple(not d for d in order_spec.desc) if backward else order_spec.desc
        clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, directions)]
        query = query.order_by(*clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        items = query.all()
        if backward:
            items.reverse()

        # Tag projection so the renderer knows what fields were requested
        proj = self._projection_for(fields, root_fields, root_field_names, rel_field_names)
        self._tag_projection(items, proj)

        # Boundary keys for cursor encoding in the API layer
        first_key = self._pluck_key(items[0], order_spec) if items else None
        last_key = self._pluck_key(items[-1], order_spec) if items else None

        # Optional total that's safe under JOINs (COUNT DISTINCT ids)
        total = None
        if include_total:
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            if filters:
                base = base.filter(*filters)
            # Mirror join structure for any SQL-needed relationships
            for _path, relationship_attr, target_alias in join_paths:
                if relationship_attr.key in sql_hops:
                    rel_attr = cast(InstrumentedAttribute, relationship_attr)
                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        # 0 signals "unlimited" to the caller; otherwise report the applied limit.
        window_limit_for_body = 0 if effective_limit is None else effective_limit

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    def get(self, id: int, params=None) -> T | None:
        """Fetch a single row by id with conflict-free eager loading and clean projection.

        Soft-deleted rows are excluded unless ``include_deleted`` is truthy; extra
        filters from ``params`` are applied in addition to the id match.
        Returns None when no row matches.
        """
        query, root_alias = self.get_query()

        # Defaults so we can build a projection even if params is None
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}
        req_fields: list[str] = _normalize_fields_param(params)

        query = self._apply_not_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)

        # Optional extra filters (in addition to id); keep parity with list()
        filters = spec.parse_filters()
        if filters:
            query = query.filter(*filters)
        query = query.filter(getattr(root_alias, "id") == id)

        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()

        # For get(), there is no ORDER BY; only filters might force SQL use.
        sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
        proj_paths = _paths_from_fields(req_fields)

        query = self._apply_root_load_only(query, root_alias, root_fields)
        query, used_contains_eager = self._apply_relationship_loading(query, join_paths, sql_paths, proj_paths)

        # Projection loader options; skipped after contains_eager to avoid strategy conflicts.
        _expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)

        obj = query.first()

        proj = self._projection_for(req_fields, root_fields, root_field_names, rel_field_names)
        if obj is not None:
            self._tag_projection([obj], proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return obj or None

    def list(self, params=None) -> list[T]:
        """Offset/limit listing with smart relationship loading and clean projection.

        Soft-deleted rows are always excluded unless ``include_deleted`` is truthy,
        whether or not other params are given.
        """
        query, root_alias = self.get_query()

        # Defaults so we can reference them later even if params is None
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}
        req_fields: list[str] = _normalize_fields_param(params)

        # Soft-delete guard applies with or without params (parity with get()/seek_window()).
        query = self._apply_not_deleted(query, root_alias, params)

        # Without params there are no filters, sorts, limits or requested fields.
        if params:
            spec = CRUDSpec(self.model, params or {}, root_alias)
            filters = spec.parse_filters()
            order_by = spec.parse_sort()
            limit, offset = spec.parse_pagination()

            spec.parse_includes()
            join_paths = tuple(spec.get_join_paths())

            root_fields, rel_field_names, root_field_names = spec.parse_fields()

            if filters:
                query = query.filter(*filters)

            sql_paths = _paths_needed_for_sql(order_by, filters, join_paths)
            proj_paths = _paths_from_fields(req_fields)

            query = self._apply_root_load_only(query, root_alias, root_fields)
            query, used_contains_eager = self._apply_relationship_loading(query, join_paths, sql_paths, proj_paths)

            # MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
            paginating = (limit is not None) or (offset is not None and offset != 0)
            if paginating and not order_by and self.backend.requires_order_by_for_offset:
                order_by = self._default_order_by(root_alias)

            if order_by:
                query = query.order_by(*order_by)

            # Only apply offset/limit when not None and not zero
            if offset is not None and offset != 0:
                query = query.offset(offset)
            if limit is not None and limit > 0:
                query = query.limit(limit)

            # Projection loaders; skipped after contains_eager to avoid strategy conflicts.
            _expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
            if proj_opts and not used_contains_eager:
                query = query.options(*proj_opts)

        rows = query.all()

        proj = self._projection_for(req_fields, root_fields, root_field_names, rel_field_names)
        self._tag_projection(rows, proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return rows

    # -------------------------------------------------------------------- writes

    def create(self, data: dict, actor=None) -> T:
        """Construct ``model(**data)``, commit it, record a "create" version and return it."""
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Set table columns from ``data`` on the row ``id``, commit, and record an "update" version.

        Raises:
            ValueError: if the row is not found (soft-deleted rows count as not found)
                or ``data`` contains names that are not table columns.
        """
        session = self.session
        # NOTE(review): self.get() queries through self.session, which calls the
        # session factory again. Unless the factory returns the same Session each
        # time (e.g. a scoped session — TODO confirm), `obj` is attached to a
        # different session than the one committed below.
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        valid_fields = {c.name for c in self.model.__table__.columns}
        unknown = set(data) - valid_fields
        if unknown:
            raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor = None):
        """Delete row ``id`` and record a "delete" version.

        Hard-deletes when ``hard`` is True or the model has no ``is_deleted``;
        otherwise sets ``is_deleted = True``. Returns the object, or None if not found.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            soft = cast(_SoftDeletable, obj)
            soft.is_deleted = True
        session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
        """Add and commit a ``Version`` row holding ``obj.as_dict()`` (or an error stub if that fails)."""
        # NOTE(review): uses a fresh self.session access; see the note in update().
        session = self.session
        try:
            data = obj.as_dict()
        except Exception:
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata
        )
        session.add(version)
        session.commit()