"""Generic CRUD service layered on a SQLAlchemy session."""

from typing import Type, TypeVar, Generic, Optional

from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy import inspect, text

from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.backend import BackendInfo, make_backend_info

T = TypeVar("T")

# String forms (compared case-insensitively) that count as "true" in params.
_TRUTHY_STRINGS = frozenset({"1", "true", "yes", "on"})


def _is_truthy(val):
    """Return True if ``str(val)``, lower-cased, is '1', 'true', 'yes' or 'on'."""
    return str(val).lower() in _TRUTHY_STRINGS


class CRUDService(Generic[T]):
    """CRUD operations for a single mapped model.

    Attributes:
        model: The mapped class this service operates on.
        session: The SQLAlchemy session used for all queries and commits.
        polymorphic: When true, queries go through ``with_polymorphic(model, "*")``.
        supports_soft_delete: True when the model has an ``is_deleted`` attribute.
        backend: Backend info; derived from the session bind when not supplied.
    """

    def __init__(
        self,
        model: Type[T],
        session: Session,
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        self.model = model
        self.session = session
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        # Cache backend info once. If not provided, derive from session bind.
        self.backend = backend or make_backend_info(self.session.get_bind())

    def get_query(self):
        """Return ``(query, root_alias)`` for the model.

        ``root_alias`` is the ``with_polymorphic`` entity when the service is
        polymorphic, otherwise the model class itself.
        """
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Return primary-key columns to order by, or ``[text("1")]`` if none."""
        mapper = inspect(self.model)
        cols = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                # The alias does not expose this key; use the raw column.
                cols.append(col)
        return cols or [text("1")]

    def _apply_soft_delete_filter(self, query, root_alias, params):
        """Hide soft-deleted rows unless ``params['include_deleted']`` is truthy.

        Does nothing when the model has no ``is_deleted`` attribute.
        """
        if not self.supports_soft_delete:
            return query
        if params and _is_truthy(params.get('include_deleted')):
            return query
        # `== False` is intentional: it builds a SQL expression.
        return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

    @staticmethod
    def _apply_joins(query, spec):
        """Outer-join every relationship path reported by ``spec``."""
        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            query = query.join(
                target_alias,
                relationship_attr.of_type(target_alias),
                isouter=True
            )
        return query

    def _apply_field_selection(self, query, root_alias, spec, params):
        """Apply ``load_only`` and eager-load options; return ``(query, projection)``.

        Fields are parsed from ``spec`` only when ``params`` is truthy.
        ``projection`` is a tuple of field names (empty when no fields were
        requested).
        """
        root_fields = []
        root_field_names = {}
        rel_field_names = {}
        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
            query = query.options(eager)

        projection = self._build_projection(
            root_fields, root_field_names, rel_field_names
        )
        return query, projection

    def _build_projection(self, root_fields, root_field_names, rel_field_names):
        """Return the requested field names as a tuple, relationship fields dotted.

        ``"id"`` is inserted first when fields were requested, it is not
        among them, and the model has an ``id`` attribute.
        """
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    @staticmethod
    def _tag_projection(obj, projection):
        """Best-effort: store ``projection`` on ``obj.__crudkit_projection__``."""
        try:
            setattr(obj, "__crudkit_projection__", projection)
        except (AttributeError, TypeError):
            # The object does not accept extra attributes; the projection is
            # advisory, so skip tagging rather than fail the read.
            pass

    def get(self, id: int, params=None) -> T | None:
        """Fetch one object by ``id``, or return ``None`` when no row matches.

        Soft-deleted rows are excluded unless ``params['include_deleted']``
        is truthy. When ``params`` requests specific fields, the result is
        tagged with a ``__crudkit_projection__`` tuple.
        """
        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params or {}, root_alias)

        query = self._apply_soft_delete_filter(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)

        spec.parse_includes()
        query = self._apply_joins(query, spec)
        query, projection = self._apply_field_selection(
            query, root_alias, spec, params
        )

        obj = query.first()
        if obj is None:
            return None
        if projection:
            self._tag_projection(obj, projection)
        return obj

    def list(self, params=None) -> list[T]:
        """Return all objects matching the filters, sort and pagination in ``params``.

        Soft-deleted rows are excluded unless ``params['include_deleted']``
        is truthy; this also holds when ``params`` is omitted, matching
        ``get``. When ``params`` requests specific fields, every row is
        tagged with a ``__crudkit_projection__`` tuple.
        """
        query, root_alias = self.get_query()
        query = self._apply_soft_delete_filter(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        spec.parse_includes()

        query = self._apply_joins(query, spec)
        query, projection = self._apply_field_selection(
            query, root_alias, spec, params
        )

        if filters:
            query = query.filter(*filters)

        # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
        paginating = (limit is not None) or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)
        if order_by:
            query = query.order_by(*order_by)

        # Only apply offset/limit when not None.
        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()
        if projection:
            for obj in rows:
                self._tag_projection(obj, projection)
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Create an object from ``data``, commit it, and log a "create" version."""
        obj = self.model(**data)
        self.session.add(obj)
        self.session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Update the object with ``id``, commit, and log an "update" version.

        Keys of ``data`` that are not column names of the model's table are
        ignored.

        Raises:
            ValueError: If ``get(id)`` finds no object.
        """
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        self.session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """Delete the object with ``id`` and log a "delete" version.

        The row is removed when ``hard`` is true or the model does not
        support soft delete; otherwise ``is_deleted`` is set to True.

        Returns:
            The affected object, or ``None`` when no object has that ``id``.
        """
        obj = self.session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            self.session.delete(obj)
        else:
            obj.is_deleted = True
        self.session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(
        self,
        change_type: str,
        obj: T,
        actor=None,
        metadata: Optional[dict] = None
    ):
        """Add and commit a ``Version`` row describing a change to ``obj``.

        ``obj.as_dict()`` supplies the snapshot; if it raises, a placeholder
        error payload is stored instead so the change is still logged.
        """
        try:
            data = obj.as_dict()
        except Exception:
            # Serialization is best-effort and must not block the audit entry.
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            # A fresh dict per call avoids sharing one default across versions.
            meta={} if metadata is None else metadata
        )
        self.session.add(version)
        self.session.commit()