from typing import Any, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast

from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.engine import Engine, Connection
from sqlalchemy import inspect, text

from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.backend import BackendInfo, make_backend_info


@runtime_checkable
class _HasID(Protocol):
    id: int


@runtime_checkable
class _HasTable(Protocol):
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
    pass


T = TypeVar("T", bound=_CRUDModelProto)


def _is_truthy(val: Any) -> bool:
    """Return True if ``str(val)`` lower-cases to one of '1', 'true', 'yes', 'on'."""
    return str(val).lower() in ('1', 'true', 'yes', 'on')


class CRUDService(Generic[T]):
    """Create/read/update/delete service bound to one SQLAlchemy model and session.

    Reads (``get`` / ``list``) build a ``CRUDSpec`` from an optional ``params``
    mapping and use it for joins, field selection, eager loading and (for
    ``list``) filtering, sorting and pagination. Returned objects may be tagged
    with a ``__crudkit_projection__`` tuple naming the requested fields.
    Writes commit immediately and then record a ``Version`` row.
    Models that have an ``is_deleted`` attribute get soft-delete semantics.
    """

    def __init__(
        self,
        model: Type[T],
        session: Session,
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None,
    ):
        """Bind the service to ``model`` and ``session``.

        Args:
            model: Mapped model class this service operates on.
            session: SQLAlchemy session used for all queries and commits.
            polymorphic: If True, queries use ``with_polymorphic(model, "*")``.
            backend: Pre-computed backend info. If omitted, it is derived from
                the session's bind.
        """
        self.model = model
        self.session = session
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')

        # Cache backend info once. If not provided, derive it from the session
        # bind, which may be either an Engine or a Connection.
        bind = self.session.get_bind()
        eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
        self.backend = backend or make_backend_info(eng)

    def get_query(self):
        """Return ``(query, root_alias)`` for the model.

        The root is polymorphic-aware when ``self.polymorphic`` is set.
        """
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    def _default_order_by(self, root_alias):
        """Return primary-key columns (or ``text("1")``) as a fallback ORDER BY.

        Used when paginating on backends that require ORDER BY with OFFSET
        (e.g. MSSQL) and the caller supplied no sort.
        """
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols or [text("1")]

    def _exclude_soft_deleted(self, query, root_alias, params):
        """Filter out soft-deleted rows unless ``params['include_deleted']`` is truthy.

        Applied whenever the model supports soft delete, including when
        ``params`` is None or empty, so ``get`` and ``list`` behave the same.
        """
        if not self.supports_soft_delete:
            return query
        include_deleted = bool(params) and _is_truthy(params.get('include_deleted'))
        if include_deleted:
            return query
        # `== False` (not `is_(False)`) builds a SQL comparison that MSSQL accepts.
        return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

    def _build_projection(self, root_fields, rel_field_names, root_field_names) -> tuple:
        """Return the ordered tuple of field names requested via the spec.

        Nested relationship fields are rendered as dotted paths
        (``"rel.sub.field"``). ``"id"`` is prepended when any fields were
        requested and the model has an ``id`` attribute.
        """
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    @staticmethod
    def _tag_projection(objs, proj: tuple) -> None:
        """Best-effort: attach ``proj`` as ``__crudkit_projection__`` on each object.

        Does nothing when ``proj`` is empty. ``None`` entries are skipped, and
        objects that refuse the attribute are left untagged.
        """
        if not proj:
            return
        for obj in objs:
            if obj is None:
                continue
            try:
                setattr(obj, "__crudkit_projection__", proj)
            except (AttributeError, TypeError):
                # Deliberately best-effort: tagging is a hint for serializers.
                pass

    def _apply_spec_loading(self, query, spec, root_alias, params) -> tuple:
        """Apply the spec's joins, column restriction and eager loads to ``query``.

        Returns ``(query, projection_tuple)``. Field selection is parsed only
        when ``params`` is truthy, so without params every column is loaded and
        the projection is empty.
        """
        root_fields = []
        rel_field_names = {}
        root_field_names = {}

        spec.parse_includes()
        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            target = cast(Any, target_alias)
            query = query.join(target, rel_attr.of_type(target), isouter=True)

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
            if only_cols:
                query = query.options(Load(root_alias).load_only(*only_cols))

        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
            query = query.options(eager)

        return query, self._build_projection(root_fields, rel_field_names, root_field_names)

    def get(self, id: int, params=None) -> T | None:
        """Fetch a single object by primary-key ``id``, or None if not found.

        Soft-deleted rows are excluded unless ``params['include_deleted']`` is
        truthy. ``params`` is also passed to ``CRUDSpec`` for includes, field
        selection and eager loading.
        """
        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params or {}, root_alias)

        query = self._exclude_soft_deleted(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)
        query, proj = self._apply_spec_loading(query, spec, root_alias, params)

        obj = query.first()
        if obj is not None:
            self._tag_projection((obj,), proj)
        return obj or None

    def list(self, params=None) -> list[T]:
        """Return all objects matching the filters/sort/pagination in ``params``.

        Soft-deleted rows are excluded unless ``params['include_deleted']`` is
        truthy, including when ``params`` is None, consistent with ``get``.
        """
        query, root_alias = self.get_query()
        query = self._exclude_soft_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        # Parse filters/sort/pagination before includes/joins, preserving the
        # original call order (the spec may accumulate join paths while
        # parsing; TODO confirm against CRUDSpec).
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()

        query, proj = self._apply_spec_loading(query, spec, root_alias, params)

        if filters:
            query = query.filter(*filters)

        apply_offset = offset is not None and offset != 0
        apply_limit = limit is not None and limit > 0

        # MSSQL requires ORDER BY when using OFFSET (and SQLAlchemy will use
        # OFFSET for limit+offset), so supply a deterministic default.
        if (apply_offset or apply_limit) and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        if order_by:
            query = query.order_by(*order_by)
        if apply_offset:
            query = query.offset(offset)
        if apply_limit:
            query = query.limit(limit)

        rows = query.all()
        self._tag_projection(rows, proj)
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from ``data``, commit it, and log a "create" version."""
        obj = self.model(**data)
        self.session.add(obj)
        self.session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Set column values from ``data`` on object ``id``, commit, and log a version.

        Keys that are not columns of the model's table are ignored.

        Raises:
            ValueError: if no (non-soft-deleted) object with ``id`` exists.
        """
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        self.session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """Delete object ``id`` and log a "delete" version.

        The row is removed when ``hard`` is True or the model lacks
        ``is_deleted``. Otherwise ``is_deleted`` is set to True.

        Returns:
            The deleted object, or None if no object with ``id`` exists.
        """
        obj = self.session.get(self.model, id)
        if not obj:
            return None

        if hard or not self.supports_soft_delete:
            self.session.delete(obj)
        else:
            soft = cast(_SoftDeletable, obj)
            soft.is_deleted = True
        self.session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None):
        """Add and commit a ``Version`` row snapshotting ``obj`` after a change.

        Args:
            change_type: Label such as "create", "update" or "delete".
            obj: The changed object. ``obj.as_dict()`` supplies the snapshot.
            actor: Stored as ``str(actor)`` when truthy, otherwise None.
            metadata: Extra data stored in ``Version.meta``. Defaults to a
                fresh empty dict per call (never a shared default).
        """
        try:
            data = obj.as_dict()
        except Exception:
            # Best-effort: a serialization failure must not block the write.
            data = {"error": "Failed to serialize object."}

        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata if metadata is not None else {},
        )
        self.session.add(version)
        self.session.commit()