252 lines
8.7 KiB
Python
252 lines
8.7 KiB
Python
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
|
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
|
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
from sqlalchemy.orm.util import AliasedClass
|
|
from sqlalchemy.engine import Engine, Connection
|
|
from sqlalchemy import inspect, text
|
|
from crudkit.core.base import Version
|
|
from crudkit.core.spec import CRUDSpec
|
|
from crudkit.backend import BackendInfo, make_backend_info
|
|
|
|
@runtime_checkable
|
|
class _HasID(Protocol):
|
|
id: int
|
|
|
|
@runtime_checkable
|
|
class _HasTable(Protocol):
|
|
__table__: Any
|
|
|
|
@runtime_checkable
|
|
class _HasADict(Protocol):
|
|
def as_dict(self) -> dict: ...
|
|
|
|
@runtime_checkable
|
|
class _SoftDeletable(Protocol):
|
|
is_deleted: bool
|
|
|
|
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Structural contract for models served by the CRUD service.

    A conforming model has an ``id``, a ``__table__`` and an ``as_dict()``
    method. Soft deletion (an ``is_deleted`` flag) is optional and is not
    part of this contract.
    """
|
|
|
|
# Model type parameter for CRUDService; bound to the _CRUDModelProto structural contract.
T = TypeVar("T", bound=_CRUDModelProto)
|
|
|
|
def _is_truthy(val):
|
|
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
|
|
|
class CRUDService(Generic[T]):
    """Generic get/list/create/update/delete operations for one SQLAlchemy model.

    Every successful create/update/delete is followed by a ``Version`` row
    recording the change type, a serialized snapshot and the actor.

    NOTE(review): ``session_factory`` is invoked on every ``self.session``
    access. The class appears to assume the factory returns the same session
    within a unit of work (e.g. a ``scoped_session``) — confirm against
    callers. Each operation captures the session once so that add/commit
    pairs always target the same session object.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None,
    ):
        """Configure the service.

        Args:
            model: Mapped model class this service operates on.
            session_factory: Zero-argument callable returning the Session to use.
            polymorphic: When true, queries select ``with_polymorphic(model, "*")``.
            backend: Precomputed backend info. When omitted it is derived once
                from the session's bind.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, "is_deleted")
        # Only touch the session bind when backend info actually has to be derived.
        self.backend = backend if backend is not None else self._detect_backend()

    def _detect_backend(self) -> BackendInfo:
        """Build BackendInfo from the engine behind the current session's bind."""
        bind = self.session.get_bind()
        # get_bind() may return a Connection; make_backend_info is given the Engine.
        eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
        return make_backend_info(eng)

    @property
    def session(self) -> Session:
        """Return the Session produced by ``session_factory`` (called on each access)."""
        return self._session_factory()

    def get_query(self):
        """Return ``(query, root_alias)`` for the model.

        ``root_alias`` is the ``with_polymorphic(model, "*")`` entity when the
        service is polymorphic, otherwise the model class itself.
        """
        session = self.session
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return session.query(poly), poly
        return session.query(self.model), self.model

    def _default_order_by(self, root_alias):
        """Return primary-key ORDER BY terms for backends that need one to paginate.

        Falls back to ``text("1")`` when the mapper has no primary-key columns.
        """
        from sqlalchemy.orm.exc import UnmappedColumnError

        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols = []
        for col in mapper.primary_key:
            # The mapped attribute key can differ from Column.key,
            # e.g. Column("ID", ...) mapped as ``id``.
            try:
                key = mapper.get_property_by_column(col).key
            except UnmappedColumnError:
                key = col.key
            try:
                cols.append(getattr(root_alias, key))
            except AttributeError:
                cols.append(col)
        return cols or [text("1")]

    def _filter_soft_deleted(self, query, root_alias, params):
        """Exclude soft-deleted rows unless unsupported or ``params['include_deleted']`` is truthy."""
        if not self.supports_soft_delete:
            return query
        if params and _is_truthy(params.get("include_deleted")):
            return query
        # SQL comparison on the column expression, not a Python identity check.
        return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

    def _apply_includes(self, query, spec, root_alias, params):
        """Join requested relationships and attach load_only / eager-load options.

        Returns ``(query, root_fields, rel_field_names, root_field_names)``. The
        field collections come from ``spec.parse_fields()`` and are empty
        unless ``params`` was supplied.
        """
        root_fields = []
        root_field_names = {}
        rel_field_names = {}

        spec.parse_includes()
        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            target = cast(Any, target_alias)
            query = query.join(target, rel_attr.of_type(target), isouter=True)

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
            query = query.options(eager)

        return query, root_fields, rel_field_names, root_field_names

    def _projection(self, root_fields, root_field_names, rel_field_names) -> tuple:
        """Build the tuple of requested field names; relationship fields are dotted paths.

        ``"id"`` is prepended when any field was requested and the model has an ``id``.
        """
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    @staticmethod
    def _tag_projection(objs, proj: tuple) -> None:
        """Best-effort: store ``proj`` on each object as ``__crudkit_projection__``."""
        if not proj:
            return
        for obj in objs:
            try:
                setattr(obj, "__crudkit_projection__", proj)
            except Exception:
                # Deliberately best-effort: an object may refuse new attributes.
                pass

    @staticmethod
    def _commit(session: Session) -> None:
        """Commit ``session``; on failure roll back (so the session stays usable) and re-raise."""
        try:
            session.commit()
        except Exception:
            session.rollback()
            raise

    def get(self, id: int, params=None) -> T | None:
        """Return the object whose ``id`` equals ``id``, or ``None``.

        ``params`` (optional mapping) may set ``include_deleted``; the rest is
        interpreted by ``CRUDSpec`` (includes / fields).
        """
        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params or {}, root_alias)

        query = self._filter_soft_deleted(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)

        query, root_fields, rel_field_names, root_field_names = self._apply_includes(
            query, spec, root_alias, params
        )

        obj = query.first()
        if obj is not None:
            self._tag_projection(
                [obj], self._projection(root_fields, root_field_names, rel_field_names)
            )
        return obj

    def list(self, params=None) -> list[T]:
        """Return the objects matching ``params`` (filters, sort, pagination, includes, fields).

        Soft-deleted rows are excluded unless ``params['include_deleted']`` is
        truthy. This now also applies when ``params`` is empty, matching ``get``.
        """
        query, root_alias = self.get_query()
        query = self._filter_soft_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()

        query, root_fields, rel_field_names, root_field_names = self._apply_includes(
            query, spec, root_alias, params
        )

        if filters:
            query = query.filter(*filters)

        # Some backends (e.g. MSSQL) require ORDER BY when OFFSET is emitted.
        paginating = limit is not None or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        if order_by:
            query = query.order_by(*order_by)

        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()
        self._tag_projection(
            rows, self._projection(root_fields, root_field_names, rel_field_names)
        )
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Construct ``model(**data)``, commit it, record a "create" version and return it."""
        obj = self.model(**data)
        session = self.session
        session.add(obj)
        self._commit(session)
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Set the mapped column attributes present in ``data`` on object ``id`` and commit.

        Keys that are not mapped column attributes are ignored.

        Raises:
            ValueError: if no (non-soft-deleted) object with that id exists.
        """
        obj = self.get(id)
        if obj is None:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        # Use the instance's own mapper so inherited (joined-table) and subclass
        # columns are accepted, keyed by mapped attribute name.
        valid_fields = {prop.key for prop in inspect(obj).mapper.column_attrs}
        for key, value in data.items():
            if key in valid_fields:
                setattr(obj, key, value)
        self._commit(self.session)
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """Delete object ``id`` and return it, or return ``None`` if it does not exist.

        Soft-deletes (sets ``is_deleted``) when supported and ``hard`` is false;
        otherwise removes the row.
        """
        session = self.session
        obj = session.get(self.model, id)
        if obj is None:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            cast(_SoftDeletable, obj).is_deleted = True
        self._commit(session)
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(
        self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None
    ):
        """Persist a ``Version`` row describing ``change_type`` applied to ``obj``.

        ``metadata`` defaults to a fresh empty dict per call (previously one
        shared mutable default).
        """
        try:
            data = obj.as_dict()
        except Exception:
            # Best-effort snapshot: serialization problems must not block the audit row.
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata if metadata is not None else {},
        )
        session = self.session
        session.add(version)
        self._commit(session)
|