"""Generic CRUD service for crudkit models, built on the SQLAlchemy ORM.

Provides ``CRUDService`` (get / list / keyset ``seek_window`` / create /
update / delete with optional soft delete and version logging) together with
small helpers for parsing ``fields`` / ``sort`` request parameters.
"""

from __future__ import annotations

import logging
from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast

from flask import current_app
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement

from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection

log = logging.getLogger("crudkit.service")


@runtime_checkable
class _HasID(Protocol):
    id: int


@runtime_checkable
class _HasTable(Protocol):
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


def _hops_from_sort(params: dict | None) -> set[str]:
    """Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'.

    ``params["sort"]`` may be a comma-separated string or a list/tuple of such
    strings; any other type yields an empty set.
    """
    if not params:
        return set()

    raw = params.get("sort")
    tokens: list[str] = []
    if isinstance(raw, str):
        tokens = [t.strip() for t in raw.split(",") if t.strip()]
    elif isinstance(raw, (list, tuple)):
        for item in raw:
            if isinstance(item, str):
                tokens.extend(t.strip() for t in item.split(",") if t.strip())

    hops: set[str] = set()
    for tok in tokens:
        tok = tok.lstrip("+-")  # drop the direction prefix
        if "." in tok:
            hops.add(tok.split(".", 1)[0])
    return hops


def _selectable_of(alias: Any) -> Any:
    """Return the selectable behind *alias*, or None when it cannot be determined.

    A plain ``.selectable`` attribute is tried first; if it is missing, the
    SQLAlchemy inspection of the object is consulted (``aliased()`` classes
    expose their selectable there rather than as a direct attribute).
    """
    selectable = getattr(alias, "selectable", None)
    if selectable is None:
        insp = inspect(alias, raiseerr=False)
        selectable = getattr(insp, "selectable", None)
    return selectable


def _belongs_to_alias(col: Any, alias: Any) -> bool:
    """Best-effort check that a column/expression comes from *alias*.

    Works for most ORM columns (those exposing ``.table``); complex
    expressions may need more.
    """
    t = getattr(col, "table", None)
    selectable = _selectable_of(alias)
    return t is not None and selectable is not None and t is selectable


def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
    """Collect relationship keys whose joined alias is referenced by ORDER BY or filters.

    *join_paths* is iterated as ``(path, rel_attr, target_alias)`` triples; the
    result contains ``rel_attr.key`` for every alias that a sort or filter
    column belongs to.
    """
    hops: set[str] = set()

    def _record(col: Any) -> None:
        for _path, rel_attr, target_alias in join_paths:
            if _belongs_to_alias(col, target_alias):
                hops.add(rel_attr.key)

    def _extract_cols(expr: Any) -> Iterable[Any]:
        # Walk simple expression trees and yield every ColumnElement found.
        if isinstance(expr, ColumnElement):
            yield expr
            for ch in getattr(expr, "get_children", lambda: [])():
                yield from _extract_cols(ch)
        elif hasattr(expr, "clauses"):
            for ch in expr.clauses:
                yield from _extract_cols(ch)

    # Sort columns
    for ob in order_by or []:
        _record(getattr(ob, "element", ob))  # unwrap UnaryExpression

    # Filter columns (best-effort)
    for flt in filters or []:
        for col in _extract_cols(flt):
            _record(col)

    return hops


def _paths_from_fields(req_fields: list[str]) -> set[str]:
    """Return the first path segment of every dotted field name in *req_fields*."""
    out: set[str] = set()
    for f in req_fields:
        if "." in f:
            parent = f.split(".", 1)[0]
            if parent:
                out.add(parent)
    return out


def _is_truthy(val):
    """Return True when ``str(val)`` is one of '1', 'true', 'yes', 'on' (case-insensitive)."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


def _normalize_fields_param(params: dict | None) -> list[str]:
    """Flatten ``params["fields"]`` into a list of non-empty, stripped field names.

    Accepts a comma-separated string or a list/tuple of such strings
    (non-string items are ignored); anything else yields ``[]``.
    """
    if not params:
        return []
    raw = params.get("fields")
    if isinstance(raw, (list, tuple)):
        out: list[str] = []
        for item in raw:
            if isinstance(item, str):
                out.extend(p for p in (s.strip() for s in item.split(",")) if p)
        return out
    if isinstance(raw, str):
        return [p for p in (s.strip() for s in raw.split(",")) if p]
    return []


def _build_projection(
    model: Any,
    req_fields: list[str],
    root_fields: Any,
    root_field_names: Any,
    rel_field_names: Any,
) -> list[str]:
    """Build the projection name list attached to result objects for renderers.

    When the client requested fields, those are used (de-duplicated, order
    preserved). Otherwise the names are derived from what the spec parsed:
    root field names, root column keys, and ``"<dotted.path>.<name>"`` for
    related fields. ``"id"`` is put first when the model has an ``id``
    attribute and the projection is non-empty (or fields were requested).
    """
    has_id = hasattr(model, "id")

    if req_fields:
        proj = list(dict.fromkeys(req_fields))
        if "id" not in proj and has_id:
            proj.insert(0, "id")
        return proj

    proj = []
    if root_field_names:
        proj.extend(root_field_names)
    if root_fields:
        proj.extend(c.key for c in root_fields if hasattr(c, "key"))
    for path, names in (rel_field_names or {}).items():
        prefix = ".".join(path)
        proj.extend(f"{prefix}.{n}" for n in names)
    if proj and "id" not in proj and has_id:
        proj.insert(0, "id")
    return proj


def _tag_projection(objs: Iterable[Any], proj: list[str]) -> None:
    """Set ``__crudkit_projection__`` on each object; no-op for an empty projection."""
    if not proj:
        return
    tagged = tuple(proj)
    for obj in objs:
        try:
            setattr(obj, "__crudkit_projection__", tagged)
        except Exception:
            # Deliberately best-effort: objects that reject new attributes
            # are returned untagged rather than failing the request.
            pass


def _apply_load_only(query: Any, root_alias: Any, root_fields: Any) -> Any:
    """Restrict root-entity loading to the mapped attributes in *root_fields*, if any."""
    only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
    if only_cols:
        query = query.options(Load(root_alias).load_only(*only_cols))
    return query


def _apply_joins(query: Any, join_paths: tuple, *, hydrate: bool) -> Any:
    """LEFT OUTER JOIN every resolved path; optionally hydrate relationships from the join."""
    for _base_alias, rel_attr, target_alias in join_paths:
        query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
        if hydrate:
            query = query.options(contains_eager(rel_attr, alias=target_alias))
    return query


def _is_descending(ob: Any) -> bool:
    """Return True when an ORDER BY element carries a DESC marker.

    ``col.desc()`` produces a ``UnaryExpression`` whose ``modifier`` is
    ``operators.desc_op``; ``_direction`` and ``operator`` are also checked
    so elements that expose the direction there keep working.
    """
    for attr in ("_direction", "modifier", "operator"):
        marker = getattr(ob, attr, None)
        if marker is None:
            continue
        if marker is operators.desc_op:
            return True
        if str(getattr(marker, "name", "")).upper() == "DESC":
            return True
    return False


class CRUDService(Generic[T]):
    """CRUD operations for a single mapped model.

    Args:
        model: The mapped model class.
        session_factory: Callable returning a Session; used when no
            Flask-scoped session is registered under
            ``current_app.extensions["crudkit"]["Session"]``.
        polymorphic: When True, queries use ``with_polymorphic(model, "*")``.
        backend: Optional pre-resolved backend info; otherwise resolved
            lazily from the session's bind.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        self._backend: Optional[BackendInfo] = backend

    @property
    def session(self) -> Session:
        """Always return the Flask-scoped Session if available; otherwise the provided factory."""
        try:
            return current_app.extensions["crudkit"]["Session"]
        except Exception:
            # Any failure to reach the app-registered session (no app
            # context, extension not registered, ...) falls back to the factory.
            return self._session_factory()

    @property
    def backend(self) -> BackendInfo:
        """Resolve backend info lazily against the active session's engine."""
        if self._backend is None:
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            self._backend = make_backend_info(eng)
        return self._backend

    def get_query(self):
        """Return ``(query, root_entity)``; the entity is polymorphic when configured so."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    def _debug_bind(self, where: str):
        """Log (at DEBUG) which engine/session the service is bound to."""
        try:
            bind = self.session.get_bind()
            eng = getattr(bind, "engine", bind)
            log.debug(
                "SERVICE BIND [%s]: engine_id=%s url=%s session=%s",
                where, id(eng), getattr(eng, 'url', '?'), type(self.session).__name__,
            )
        except Exception as e:
            log.debug("SERVICE BIND [%s]: failed to introspect bind: %s", where, e)

    def _apply_not_deleted(self, query, root_alias, params) -> Any:
        """Filter out soft-deleted rows unless ``include_deleted`` is truthy in *params*."""
        if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
            # `== False` (not `is False`) is required to build a SQL expression.
            return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712
        return query

    def _extract_order_spec(self, root_alias, given_order_by):
        """
        SQLAlchemy 2.x only: Normalize order_by into (cols, desc_flags).

        Supports plain columns and col.asc()/col.desc() (UnaryExpression).
        The ordering is first stabilised with PK tie-breakers. Avoids boolean
        evaluation of clauses.
        """
        given = self._stable_order_by(root_alias, given_order_by)

        cols, desc_flags = [], []
        for ob in given:
            # Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
            elem = getattr(ob, "element", None)
            cols.append(elem if elem is not None else ob)
            desc_flags.append(_is_descending(ob))

        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the keyset ("seek") predicate for rows strictly after/before *key_vals*.

        For each order column i the clause is
        ``cols[0..i-1] == key[0..i-1] AND cols[i] <op> key[i]``; the clauses are
        OR-ed together. *key_vals* must supply one value per order column.
        Returns None when *key_vals* is empty.
        """
        if not key_vals:
            return None

        conds = []
        for i, col in enumerate(spec.cols):
            # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
            # sent_col = func.coalesce(col, literal("-∞"))
            sent_col = col
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            # Forward on DESC (or backward on ASC) walks towards smaller values.
            want_smaller = spec.desc[i] != backward
            op = (sent_col < key_vals[i]) if want_smaller else (sent_col > key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
        """Read the cursor key values for *spec* directly off a row object.

        Raises:
            ValueError: if an order column has no key/name or *obj* lacks the attribute.
        """
        out = []
        for c in spec.cols:
            # Only simple mapped columns supported for key pluck
            key = getattr(c, "key", None) or getattr(c, "name", None)
            if key is None or not hasattr(obj, key):
                raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
            out.append(getattr(obj, key))
        return out

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """
        Keyset pagination with relationship-safe filtering/sorting.

        Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek.

        Args:
            params: Request parameters (fields, sort, filters, limit, ...).
            key: Cursor key values to seek from; falsy means start of the set.
            backward: Page backwards from *key*; items are still returned in
                forward order.
            include_total: Also compute the DISTINCT-id total for the filters.

        A missing limit defaults to 50; a limit of 0 means unlimited (reported
        as ``limit=0`` in the returned window).
        """
        session = self.session
        query, root_alias = self.get_query()

        # Requested fields → projection + optional loaders
        fields = _normalize_fields_param(params)
        _expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])

        spec = CRUDSpec(self.model, params or {}, root_alias)

        # Parse all inputs so join_paths are populated
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        root_fields, rel_field_names, root_field_names = spec.parse_fields()
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())

        # Soft delete
        query = self._apply_not_deleted(query, root_alias, params)

        # Root column projection (load_only)
        query = _apply_load_only(query, root_alias, root_fields)

        # JOIN all resolved paths, hydrate from the join
        query = _apply_joins(query, join_paths, hydrate=True)
        used_contains_eager = bool(join_paths)

        # Filters AFTER joins → no cartesian products
        if filters:
            query = query.filter(*filters)

        # Projection loaders only if we didn't use contains_eager
        # (same rule as get()/list(), avoids loader-strategy conflicts).
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)

        # Order spec (with PK tie-breakers for stability)
        order_spec = self._extract_order_spec(root_alias, order_by)

        limit, _ = spec.parse_pagination()
        if limit is None:
            effective_limit = 50
        elif limit == 0:
            effective_limit = None  # unlimited
        else:
            effective_limit = limit

        # Seek predicate from cursor key (if any)
        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
        clauses = [
            c.desc() if (is_desc != backward) else c.asc()
            for c, is_desc in zip(order_spec.cols, order_spec.desc)
        ]
        query = query.order_by(*clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        items = query.all()
        if backward:
            items = list(reversed(items))

        # Projection meta tag for renderers
        proj = _build_projection(self.model, fields, root_fields, root_field_names, rel_field_names)
        _tag_projection(items, proj)

        # Cursor key pluck: support related columns we hydrated via contains_eager.
        # Map each joined alias' selectable to its relationship name once.
        alias_to_rel: dict[Any, str] = {}
        for _p, relationship_attr, target_alias in join_paths:
            sel = _selectable_of(target_alias)
            if sel is not None:
                alias_to_rel[sel] = relationship_attr.key

        def _pluck_key_from_obj(obj: Any) -> list[Any]:
            vals: list[Any] = []
            for col in order_spec.cols:
                keyname = getattr(col, "key", None) or getattr(col, "name", None)
                relname = alias_to_rel.get(getattr(col, "table", None))
                if relname and keyname:
                    # Column lives on a joined alias: read it from the related
                    # object, never from a same-named root attribute.
                    relobj = getattr(obj, relname, None)
                    if relobj is not None and hasattr(relobj, keyname):
                        vals.append(getattr(relobj, keyname))
                        continue
                elif keyname and hasattr(obj, keyname):
                    vals.append(getattr(obj, keyname))
                    continue
                raise ValueError("unpluckable")
            return vals

        try:
            first_key = _pluck_key_from_obj(items[0]) if items else None
            last_key = _pluck_key_from_obj(items[-1]) if items else None
        except Exception:
            # Best-effort: a window without cursor keys is still returned.
            log.debug("Could not pluck cursor keys for seek window", exc_info=True)
            first_key = None
            last_key = None

        # Count DISTINCT ids with mirrored joins
        total = None
        if include_total:
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            # same joins as above for correctness
            base = _apply_joins(base, join_paths, hydrate=False)
            if filters:
                base = base.filter(*filters)
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        # Unlimited windows report a limit of 0 in the body.
        window_limit_for_body = 0 if effective_limit is None else (effective_limit or 50)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Return the primary-key columns (via *root_alias* when possible), or ``[text("1")]``."""
        return self._pk_columns(root_alias) or [text("1")]

    def _pk_columns(self, root_alias):
        """Return the model's primary-key columns, preferring attributes of *root_alias*."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols

    def _stable_order_by(self, root_alias, given_order_by):
        """
        Ensure deterministic ordering by appending PK columns as tiebreakers.
        If no order is provided, fall back to default primary-key order.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return self._default_order_by(root_alias)
        return [*order_by, *self._pk_columns(root_alias)]

    def get(self, id: int, params=None) -> T | None:
        """
        Fetch a single row by id with conflict-free eager loading and clean projection.

        Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so
        related-column filters never create cartesian products.

        Returns the object, or None when no row matches (including rows hidden
        by the soft-delete filter).
        """
        query, root_alias = self.get_query()

        # Defaults so we can build a projection even if params is None
        req_fields: list[str] = _normalize_fields_param(params)
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}

        # Soft-delete guard first
        query = self._apply_not_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)

        # Parse everything so CRUDSpec records any join paths it needed to resolve
        filters = spec.parse_filters()
        # no ORDER BY for get()
        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            spec.parse_includes()

        join_paths = tuple(spec.get_join_paths())

        # Root-column projection (load_only)
        query = _apply_load_only(query, root_alias, root_fields)

        # JOIN all discovered paths up front; hydrate via contains_eager
        query = _apply_joins(query, join_paths, hydrate=True)
        used_contains_eager = bool(join_paths)

        # Apply filters (joins are in place → no cartesian products)
        if filters:
            query = query.filter(*filters)

        # And the id filter
        query = query.filter(getattr(root_alias, "id") == id)

        # Projection loader options compiled from requested fields.
        # Skip if we used contains_eager to avoid loader-strategy conflicts.
        _expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)

        # Do not use .first(): it emits LIMIT 1, which would truncate to-many
        # collections hydrated from the join. The id filter already bounds the result.
        rows = query.all()
        obj = rows[0] if rows else None

        # Emit exactly what the client requested (plus id), or a reasonable fallback
        proj = _build_projection(self.model, req_fields, root_fields, root_field_names, rel_field_names)
        if obj is not None:
            _tag_projection([obj], proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return obj

    def list(self, params=None) -> list[T]:
        """
        Offset/limit listing with relationship-safe filtering.

        We always JOIN every CRUDSpec-discovered path before applying filters/sorts.
        Soft-deleted rows are excluded unless ``include_deleted`` is truthy,
        whether or not any params were supplied.
        """
        query, root_alias = self.get_query()

        # Defaults so we can reference them later even if params is None
        req_fields: list[str] = _normalize_fields_param(params)
        root_fields: list[Any] = []
        root_field_names: dict[str, str] = {}
        rel_field_names: dict[tuple[str, ...], list[str]] = {}

        # Soft delete: applied unconditionally so list() agrees with get()/seek_window().
        query = self._apply_not_deleted(query, root_alias, params)

        if params:
            spec = CRUDSpec(self.model, params or {}, root_alias)
            filters = spec.parse_filters()
            order_by = spec.parse_sort()
            limit, offset = spec.parse_pagination()

            # Includes / fields (populates join_paths)
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            spec.parse_includes()
            join_paths = tuple(spec.get_join_paths())

            # Root column projection (load_only)
            query = _apply_load_only(query, root_alias, root_fields)

            # JOIN all paths we resolved and hydrate them from the join
            query = _apply_joins(query, join_paths, hydrate=True)
            used_contains_eager = bool(join_paths)

            # Filters AFTER joins → no cartesian products
            if filters:
                query = query.filter(*filters)

            # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers
            paginating = (limit is not None) or (offset is not None and offset != 0)
            if paginating and not order_by and self.backend.requires_order_by_for_offset:
                order_by = self._default_order_by(root_alias)

            if order_by:
                query = query.order_by(*self._stable_order_by(root_alias, order_by))

            # Offset/limit
            if offset is not None and offset != 0:
                query = query.offset(offset)
            if limit is not None and limit > 0:
                query = query.limit(limit)

            # Projection loaders only if we didn't use contains_eager
            _expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
            if proj_opts and not used_contains_eager:
                query = query.options(*proj_opts)
        # No params: there are no requested fields either, so no loader options apply.

        rows = query.all()

        # Build projection meta for renderers
        proj = _build_projection(self.model, req_fields, root_fields, root_field_names, rel_field_names)
        _tag_projection(rows, proj)

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from *data*, commit it, and record a "create" version."""
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Set the given column values on the row with *id*, commit, and record an "update" version.

        Raises:
            ValueError: if the row does not exist or *data* contains keys that
                are not column names of the model's table.
        """
        session = self.session
        obj = session.get(self.model, id)
        if obj is None:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        valid_fields = {c.name for c in self.model.__table__.columns}
        unknown = set(data) - valid_fields
        if unknown:
            raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")

        # Every key is a known column at this point.
        for k, v in data.items():
            setattr(obj, k, v)

        session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor = None):
        """Delete the row with *id* and record a "delete" version.

        Soft-deletes (sets ``is_deleted = True``) when the model supports it and
        *hard* is False; otherwise removes the row. Returns the object, or None
        when no row with *id* exists.
        """
        session = self.session
        obj = session.get(self.model, id)
        if obj is None:
            return None

        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            soft = cast(_SoftDeletable, obj)
            soft.is_deleted = True

        session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
        """Persist a ``Version`` row describing *change_type* applied to *obj*.

        The snapshot comes from ``obj.as_dict()``; if serialisation fails, a
        placeholder error payload is stored instead. Commits the session.
        """
        session = self.session
        try:
            data = obj.as_dict()
        except Exception:
            data = {"error": "Failed to serialize object."}

        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata
        )
        session.add(version)
        session.commit()