inventory/crudkit/core/service.py
2025-09-11 10:29:47 -05:00

172 lines
5.9 KiB
Python

from typing import Type, TypeVar, Generic, Optional
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
from sqlalchemy import inspect, text
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.backend import BackendInfo, make_backend_info
# Type variable for the model class handled by a CRUDService instance
# (used as Generic[T] and in the Type[T] / T return annotations below).
T = TypeVar("T")
def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
class CRUDService(Generic[T]):
    """Generic CRUD service bound to one mapped model and one session.

    Provides ``get``, ``list``, ``create``, ``update`` and ``delete``.
    A model that exposes an ``is_deleted`` attribute is treated as
    soft-deletable. Every create/update/delete is recorded by adding a
    ``Version`` row to the same session.
    """

    def __init__(
        self,
        model: Type[T],
        session: Session,
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """Bind the service to *model* and *session*.

        Args:
            model: Mapped class this service operates on.
            session: Session used for every query and commit.
            polymorphic: When True, queries are built against
                ``with_polymorphic(model, "*")``.
            backend: Pre-computed backend info; when omitted it is derived
                from the session's bind.
        """
        self.model = model
        self.session = session
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        # Cache backend info once. If not provided, derive from session bind.
        self.backend = backend or make_backend_info(self.session.get_bind())

    def get_query(self):
        """Return ``(query, root_alias)`` for the service's model.

        ``root_alias`` is the polymorphic entity when ``polymorphic`` is set,
        otherwise the model class itself.
        """
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # Helper: default ORDER BY for MSSQL when paginating without explicit order
    def _default_order_by(self, root_alias):
        """Return primary-key columns to order by, or ``text("1")`` if none."""
        mapper = inspect(self.model)
        cols = []
        for col in mapper.primary_key:
            try:
                # Prefer the attribute on the alias actually being queried.
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols or [text("1")]

    def _build_projection(self, root_fields, rel_field_names):
        """Return the tuple of projected field names requested by the caller.

        Root fields contribute their ``key``; related fields contribute
        ``"<dotted.path>.<name>"``. When anything is projected and the model
        has an ``id`` attribute, ``"id"`` is placed first if not already
        present. Returns an empty tuple when nothing was requested.
        """
        proj = [c.key for c in root_fields or ()]
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    def get(self, id: int, include_deleted: bool = False) -> T | None:
        """Fetch one object by ``id``.

        Soft-deleted rows are excluded unless *include_deleted* is True.
        Returns None when no row matches.
        """
        query, root_alias = self.get_query()
        if self.supports_soft_delete and not include_deleted:
            # `== False` is required: it builds a SQL expression, not a Python bool.
            query = query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712
        query = query.filter(getattr(root_alias, "id") == id)
        # first() already yields None for "no row"; returning it directly
        # avoids discarding a model instance that happens to be falsy.
        return query.first()

    # NOTE: this method shadows the builtin ``list`` inside the class body, so
    # annotations written *after* it in this class must not use ``list[...]``.
    def list(self, params=None) -> list[T]:
        """Return objects matching *params*.

        *params* is a mapping handed to ``CRUDSpec`` for filters, sorting,
        pagination, includes and field selection. The ``include_deleted`` key
        controls the soft-delete filter: unless it is truthy, soft-deleted
        rows are excluded. That also applies when *params* is ``None`` or
        empty, matching the default of ``get``.

        When specific fields are requested, each returned object is tagged
        (best effort) with a ``__crudkit_projection__`` tuple of field names.
        """
        params = params or {}
        query, root_alias = self.get_query()

        root_fields = []
        rel_field_names = {}

        if self.supports_soft_delete and not _is_truthy(params.get('include_deleted')):
            query = query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

        spec = CRUDSpec(self.model, params, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        spec.parse_includes()

        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            query = query.join(
                target_alias,
                relationship_attr.of_type(target_alias),
                isouter=True
            )

        if params:
            root_fields, rel_field_names = spec.parse_fields()

        if root_fields:
            query = query.options(Load(root_alias).load_only(*root_fields))

        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
            query = query.options(eager)

        if filters:
            query = query.filter(*filters)

        # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
        paginating = (limit is not None) or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        if order_by:
            query = query.order_by(*order_by)

        # Only apply offset/limit when not None.
        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()

        projection = self._build_projection(root_fields, rel_field_names)
        if projection:
            for obj in rows:
                try:
                    setattr(obj, "__crudkit_projection__", projection)
                except Exception:
                    # Deliberately best-effort: an object that refuses the
                    # attribute is returned untagged.
                    pass

        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from *data*, commit it and log a version."""
        obj = self.model(**data)
        self.session.add(obj)
        self.session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Apply *data* to the object with the given ``id`` and commit.

        Only keys that match a column name of the model's table are applied;
        other keys are ignored. The lookup goes through ``get``, so a
        soft-deleted object is treated as not found.

        Raises:
            ValueError: if no object with that id is found.
        """
        obj = self.get(id)
        if obj is None:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)

        self.session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """Delete the object with the given ``id``.

        The object is removed with ``session.delete`` when *hard* is True or
        the model has no ``is_deleted`` attribute; otherwise ``is_deleted`` is
        set to True. Returns the object, or None when the id does not exist.
        """
        obj = self.session.get(self.model, id)
        if obj is None:
            return None

        if hard or not self.supports_soft_delete:
            self.session.delete(obj)
        else:
            obj.is_deleted = True

        self.session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None):
        """Add and commit a ``Version`` row describing a change to *obj*.

        Args:
            change_type: Label stored on the version (callers in this class
                pass "create", "update" or "delete").
            obj: The affected object; ``obj.as_dict()`` supplies the snapshot.
            actor: Stored as ``str(actor)`` when truthy, else None.
            metadata: Extra data stored as ``meta``; a fresh empty dict is
                used when omitted so versions never share one mutable default.
        """
        try:
            data = obj.as_dict()
        except Exception:
            # Serialization problems must not block the audit record itself.
            data = {"error": "Failed to serialize object."}

        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata if metadata is not None else {}
        )

        self.session.add(version)
        self.session.commit()