# inventory/crudkit/core/service.py
# (644 lines, 24 KiB, Python -- header captured from a repository web view.)
#
# NOTE: The web view warned that this file contains Unicode characters that
# might be confused with other characters. If that is intentional, the
# warning can be ignored.

from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
def _expand_requires(model_cls, fields):
out, seen = [], set()
def add(f):
if f not in seen:
seen.add(f); out.append(f)
for f in fields:
add(f)
parts = f.split(".")
cur_cls = model_cls
prefix = []
for p in parts[:-1]:
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
if not rel:
cur_cls = None
break
cur_cls = rel.mapper.class_
prefix.append(p)
if cur_cls is None:
continue
leaf = parts[-1]
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
if not deps:
continue
pre = ".".join(prefix)
for dep in deps:
add(f"{pre + '.' if pre else ''}{dep}")
return out
def _is_rel(model_cls, name: str) -> bool:
try:
prop = model_cls.__mapper__.relationships.get(name)
return isinstance(prop, RelationshipProperty)
except Exception:
return False
def _is_instrumented_column(attr) -> bool:
try:
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
except Exception:
return False
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
    """
    Build selectinload options for bare many-to-one relationship names.

    For each entry of *fields* that names a MANYTOONE relationship on
    *model_cls* (e.g. "location"), the relationship is selectinloaded from
    *root_alias*. When the target class has an ``id`` attribute, only that
    column is loaded (``load_only``), which is enough to preselect a
    <select> input without projecting the FK column on the root model.
    Dotted names, columns and other relationship directions are skipped.
    """
    if not fields:
        return []

    relationships = class_mapper(model_cls).relationships
    loaders: list[Load] = []
    for name in fields:
        prop = relationships.get(name)
        is_many_to_one = (
            isinstance(prop, RelationshipProperty)
            and prop.direction.name == "MANYTOONE"
        )
        if not is_many_to_one:
            continue
        loader = selectinload(getattr(root_alias, name))
        pk_attr = getattr(prop.mapper.class_, "id", None)
        loaders.append(loader if pk_attr is None else loader.load_only(pk_attr))
    return loaders
@runtime_checkable
class _HasID(Protocol):
id: int
@runtime_checkable
class _HasTable(Protocol):
__table__: Any
@runtime_checkable
class _HasADict(Protocol):
def as_dict(self) -> dict: ...
@runtime_checkable
class _SoftDeletable(Protocol):
is_deleted: bool
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
"""Minimal surface that our CRUD service relies on. Soft-delete is optional."""
pass
T = TypeVar("T", bound=_CRUDModelProto)
def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
class CRUDService(Generic[T]):
    """
    Generic create/read/update/delete service for a single mapped model.

    Read methods (get, list, seek_window) interpret a ``params`` mapping via
    CRUDSpec (filters, sort, fields, includes, pagination). They exclude
    soft-deleted rows when the model has an ``is_deleted`` attribute, unless
    ``params["include_deleted"]`` is truthy. Write methods commit immediately
    and then record a Version row through _log_version.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """
        Args:
            model: Mapped model class this service operates on.
            session_factory: Zero-argument callable returning the Session to
                use. It is invoked on every ``self.session`` access, so it
                presumably returns the same (e.g. scoped) session within a
                unit of work -- TODO confirm against callers.
            polymorphic: When true, queries go through
                ``with_polymorphic(model, "*")``.
            backend: Precomputed backend capabilities. If omitted, they are
                derived once from the session's bind.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        if backend is None:
            # Only consult the session bind when backend info must be derived,
            # so callers that pass ``backend`` don't need a bound session here.
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            backend = make_backend_info(eng)
        self.backend = backend

    @property
    def session(self) -> Session:
        """Return a session from the configured factory (the factory is called on each access)."""
        return self._session_factory()

    def get_query(self):
        """
        Return ``(query, root_alias)`` for the model.

        With ``polymorphic`` enabled, the root entity is
        ``with_polymorphic(model, "*")``; otherwise it is the model class.
        Column references for filters/options must be taken from
        ``root_alias``.
        """
        if self.polymorphic:
            entity = with_polymorphic(self.model, "*")
        else:
            entity = self.model
        return self.session.query(entity), entity

    # ------------------------------------------------------------------
    # Query-building helpers shared by get(), list() and seek_window()
    # ------------------------------------------------------------------

    def _pk_columns(self, root_alias) -> List[Any]:
        """
        Return the model's primary-key columns.

        Each column is taken from ``root_alias`` when the alias exposes it;
        otherwise the raw mapped Column is used.
        """
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols: List[Any] = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols

    def _apply_soft_delete(self, query, root_alias, params):
        """Filter out soft-deleted rows unless ``params["include_deleted"]`` is truthy."""
        if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
            query = query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712 (SQL expression)
        return query

    @staticmethod
    def _apply_join_paths(query, join_paths):
        """LEFT OUTER JOIN each ``(parent_alias, relationship_attr, target_alias)`` yielded by CRUDSpec."""
        for _parent_alias, relationship_attr, target_alias in join_paths:
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            target = cast(Any, target_alias)
            query = query.join(target, rel_attr.of_type(target), isouter=True)
        return query

    def _apply_rel_root_loads(self, query, root_alias, rel_field_names):
        """selectinload each distinct first-hop relationship referenced by ``rel_field_names``."""
        seen = set()
        for path in (rel_field_names or {}):
            if not path or path[0] in seen:
                continue
            rel_name = path[0]
            seen.add(rel_name)
            if _is_rel(self.model, rel_name):
                rel_attr = getattr(root_alias, rel_name, None)
                if rel_attr is not None:
                    query = query.options(selectinload(rel_attr))
        return query

    def _build_projection(self, root_fields, rel_field_names, root_field_names) -> Tuple[str, ...]:
        """
        Build the ``__crudkit_projection__`` tuple for rendered objects.

        The tuple lists the root field names, then the root column keys, then
        ``"<rel.path>.<field>"`` entries. ``"id"`` is prepended when the model
        has one and it is missing. Duplicates are dropped, keeping the first
        occurrence. An empty tuple means nothing was requested.
        """
        proj: List[str] = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(dict.fromkeys(proj))

    @staticmethod
    def _tag_projection(objs: Iterable[Any], proj: Tuple[str, ...]) -> None:
        """Best-effort: set ``__crudkit_projection__`` on every non-None object in ``objs``."""
        if not proj:
            return
        for obj in objs:
            if obj is None:
                continue
            try:
                setattr(obj, "__crudkit_projection__", proj)
            except Exception:
                # Tagging is only a rendering hint; objects that reject the
                # attribute are left untagged.
                pass

    @staticmethod
    def _commit(session: Session) -> None:
        """Commit ``session``; on failure roll it back (so it stays usable) and re-raise."""
        try:
            session.commit()
        except Exception:
            session.rollback()
            raise

    # ------------------------------------------------------------------
    # Ordering / keyset helpers
    # ------------------------------------------------------------------

    def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
        """
        Turn ``__crudkit_field_requires__`` declarations into eager-load options.

        For each relationship path (e.g. ``("location",) -> ["label"]``), look up
        the target model's ``__crudkit_field_requires__`` for each requested
        field. Produce selectinload options, prefixed with the relationship
        path, for the required relationships. For example:
            Room.__crudkit_field_requires__['label'] = ['room_function']
            => selectinload(root.location).selectinload(Room.room_function)
        Paths that contain a non-relationship segment are ignored, as are
        required names that are not relationships.
        """
        opts: List[Any] = []
        root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        for path, names in (rel_field_names or {}).items():
            if not path:
                continue

            # Resolve every hop of the path to its RelationshipProperty.
            rel_props: List[RelationshipProperty] = []
            mapper = root_mapper
            for step in path:
                rel = mapper.relationships.get(step)
                if not isinstance(rel, RelationshipProperty):
                    rel_props = []
                    break
                rel_props.append(rel)
                mapper = rel.mapper
            if not rel_props:
                continue

            # The first hop comes from the (possibly polymorphic) root alias;
            # later hops come from the previous hop's target class.
            loader = selectinload(getattr(root_alias, rel_props[0].key))
            for prev, nxt in zip(rel_props, rel_props[1:]):
                loader = loader.selectinload(getattr(prev.mapper.class_, nxt.key))

            target_cls = rel_props[-1].mapper.class_
            requires = getattr(target_cls, "__crudkit_field_requires__", None)
            if not isinstance(requires, dict):
                continue
            target_rels = target_cls.__mapper__.relationships
            for field_name in names:
                needed: Iterable[str] = requires.get(field_name) or []
                for rel_need in needed:
                    dep = target_rels.get(rel_need)
                    if isinstance(dep, RelationshipProperty):
                        opts.append(loader.selectinload(getattr(target_cls, dep.key)))
        return opts

    def _extract_order_spec(self, root_alias, given_order_by):
        """
        Normalize an ORDER BY list into an OrderSpec of (columns, desc flags).

        Primary-key tiebreakers are appended first (see _stable_order_by).
        Wrappers produced by ``col.asc()`` / ``col.desc()`` are unwrapped to
        their column. In SQLAlchemy 2.x these wrappers record their direction
        in ``UnaryExpression.modifier``, so that is what decides DESC
        (``operator`` is also checked, for hand-built unary expressions).
        Plain columns are ascending.
        """
        cols: List[Any] = []
        desc_flags: List[bool] = []
        for ob in self._stable_order_by(root_alias, given_order_by):
            elem = getattr(ob, "element", None)
            # Compare against None explicitly: avoid boolean evaluation of SA expressions.
            cols.append(elem if elem is not None else ob)
            is_desc = isinstance(ob, UnaryExpression) and (
                ob.modifier is operators.desc_op or ob.operator is operators.desc_op
            )
            desc_flags.append(bool(is_desc))
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """
        Build the lexicographic WHERE clause for a keyset seek.

        Produces ``(c0 > k0) OR (c0 = k0 AND c1 > k1) OR ...``, where each
        comparison flips for DESC columns. For backward traversal the
        comparisons are inverted. Returns None when ``key_vals`` is empty.

        NOTE: in SQL, comparisons against NULL are never true, so rows whose
        sort value is NULL cannot be reached by a seek.

        Raises:
            ValueError: if ``key_vals`` does not have one value per order column
                (e.g. a cursor produced under a different sort).
        """
        if not key_vals:
            return None
        if len(key_vals) != len(spec.cols):
            raise ValueError(
                f"Seek key has {len(key_vals)} value(s) but the order spec "
                f"has {len(spec.cols)} column(s)."
            )
        conds = []
        for i, (col, is_desc) in enumerate(zip(spec.cols, spec.desc)):
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            # ASC going forward and DESC going backward both continue "above"
            # the key; the other two combinations continue "below" it.
            if bool(is_desc) == bool(backward):
                op = col > key_vals[i]
            else:
                op = col < key_vals[i]
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
        """
        Read the seek-key values of ``obj``: one attribute per order column,
        named by the column's ``key`` (or its ``name`` when it has no key).

        Raises:
            ValueError: if an order column exposes neither ``key`` nor ``name``.
        """
        values = []
        for col in spec.cols:
            attr_name = getattr(col, "key", None) or getattr(col, "name", None)
            if not attr_name:
                raise ValueError(f"Cannot derive a seek key from order column {col!r}.")
            values.append(getattr(obj, attr_name))
        return values

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """
        Transport-agnostic keyset pagination with the same features as ``list()``.

        Supports filters, includes/joins, field projection, eager loading and
        the soft-delete filter. Ordering is the requested sort plus
        primary-key tiebreakers, so window boundaries are deterministic.

        Args:
            params: CRUDSpec parameters (``fields``, filters, sort, limit,
                ``include_deleted``, ...). ``None`` is treated as ``{}``.
            key: Seek-key values of the boundary row, as returned in a previous
                window's ``first_key``/``last_key``. Omit for the first page.
            backward: Fetch the rows before ``key`` instead of after it.
            include_total: Also compute the number of matching rows, as a
                COUNT of DISTINCT ids so that outer joins don't inflate it.

        Returns:
            A SeekWindow with the items (always in display order), first/last
            keys, the order spec, the limit (50 when none/non-positive is
            requested) and the optional total.
        """
        params = params or {}
        raw_fields = params.get("fields")
        # Only expand real sequences: list() on a string would split it into
        # single characters. Strings are left for CRUDSpec to interpret.
        if raw_fields and not isinstance(raw_fields, (str, bytes)):
            params = {**params, "fields": _expand_requires(self.model, raw_fields)}

        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        root_fields, rel_field_names, root_field_names = spec.parse_fields()

        query = self._apply_rel_root_loads(query, root_alias, rel_field_names)
        query = self._apply_soft_delete(query, root_alias, params)
        if filters:
            query = query.filter(*filters)

        # Includes + joins, so relationship fields like brand.name or location.label work.
        spec.parse_includes()
        join_paths = list(spec.get_join_paths())
        query = self._apply_join_paths(query, join_paths)

        # Projection: load_only for root columns, eager loads for declared requirements.
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        for opt in self._resolve_required_includes(root_alias, rel_field_names):
            query = query.options(opt)

        order_spec = self._extract_order_spec(root_alias, order_by)
        limit, _ = spec.parse_pagination()
        if not limit or limit <= 0:
            limit = 50  # sensible default

        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Backward pages are fetched in inverted SQL order (rows nearest the
        # key first, so LIMIT keeps the right ones) and then reversed into
        # display order.
        clauses = [
            col.desc() if bool(is_desc) != bool(backward) else col.asc()
            for col, is_desc in zip(order_spec.cols, order_spec.desc)
        ]
        items = query.order_by(*clauses).limit(limit).all()
        if backward:
            items.reverse()

        # Tag projection so the renderer knows which fields were requested.
        self._tag_projection(items, self._build_projection(root_fields, rel_field_names, root_field_names))

        # Boundary keys for cursor encoding in the API layer.
        first_key = self._pluck_key(items[0], order_spec) if items else None
        last_key = self._pluck_key(items[-1], order_spec) if items else None

        total = None
        if include_total:
            session = self.session
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_soft_delete(base, root_alias, params)
            if filters:
                base = base.filter(*filters)
            base = self._apply_join_paths(base, join_paths)
            # COUNT over DISTINCT ids so one-to-many outer joins don't inflate it.
            distinct_ids = base.order_by(None).distinct().subquery()
            total = session.query(func.count()).select_from(distinct_ids).scalar() or 0

        return SeekWindow(
            items=items,
            limit=limit,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    def _default_order_by(self, root_alias):
        """
        Default ORDER BY used when none is given.

        This is the primary-key columns, or ``ORDER BY 1`` if the model has
        none. It is needed e.g. on MSSQL, which requires an ORDER BY whenever
        OFFSET is used.
        """
        return self._pk_columns(root_alias) or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """
        Ensure deterministic ordering by appending PK columns as tiebreakers.

        If no order is provided, fall back to the default primary-key order.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return self._default_order_by(root_alias)
        return [*order_by, *self._pk_columns(root_alias)]

    # ------------------------------------------------------------------
    # Public CRUD API
    # ------------------------------------------------------------------

    def get(self, id: int, params=None) -> T | None:
        """
        Fetch one row by ``id``, or return None when no row matches.

        Soft-deleted rows are excluded unless ``params["include_deleted"]``
        is truthy. When ``params`` is given, its ``fields`` drive eager
        loading and ``load_only`` projection, and the result is tagged with
        ``__crudkit_projection__``.
        """
        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params or {}, root_alias)
        root_fields = []
        root_field_names = {}
        rel_field_names = {}

        query = self._apply_soft_delete(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)

        spec.parse_includes()
        query = self._apply_join_paths(query, spec.get_join_paths())

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            query = self._apply_rel_root_loads(query, root_alias, rel_field_names)
            for opt in _loader_options_for_fields(root_alias, self.model, params.get("fields") or []):
                query = query.options(opt)

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        obj = query.first()
        if obj is not None:
            self._tag_projection([obj], self._build_projection(root_fields, rel_field_names, root_field_names))
        return obj or None

    def list(self, params=None) -> list[T]:
        """
        Return all rows matching ``params``.

        Soft-deleted rows are excluded unless ``params["include_deleted"]`` is
        truthy, including when ``params`` is omitted, consistent with get()
        and seek_window(). ``params`` also drives filters, sort,
        offset/limit, joins and field projection through CRUDSpec.
        """
        query, root_alias = self.get_query()
        root_fields = []
        root_field_names = {}
        rel_field_names = {}

        query = self._apply_soft_delete(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        spec.parse_includes()
        query = self._apply_join_paths(query, spec.get_join_paths())

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            query = self._apply_rel_root_loads(query, root_alias, rel_field_names)

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        if params:
            for opt in _loader_options_for_fields(root_alias, self.model, params.get("fields") or []):
                query = query.options(opt)

        if filters:
            query = query.filter(*filters)

        # MSSQL requires ORDER BY when OFFSET is used (and SQLAlchemy emits
        # OFFSET for limit+offset), so fall back to PK order when needed.
        has_offset = offset is not None and offset != 0
        paginating = limit is not None or has_offset
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)
        if order_by:
            query = query.order_by(*order_by)
        if has_offset:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()
        self._tag_projection(rows, self._build_projection(root_fields, rel_field_names, root_field_names))
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from ``data``, commit it, and log a "create" version."""
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        self._commit(session)
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """
        Apply ``data`` to the row with ``id``, commit, and log an "update" version.

        Keys that are not column names of ``model.__table__`` are ignored.
        NOTE(review): matching is on ``Column.name``; a column mapped under a
        different attribute key would not be settable here -- confirm the
        models keep the two identical.

        Raises:
            ValueError: if no (non-soft-deleted) row has ``id``.
        """
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        self._commit(self.session)
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=None):
        """
        Delete the row with ``id`` and log a "delete" version.

        The row is soft-deleted (``is_deleted = True``) when the model supports
        it and ``hard`` is false; otherwise it is removed. Returns the object,
        or None if no row has ``id``.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            cast(_SoftDeletable, obj).is_deleted = True
        self._commit(session)
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None):
        """
        Persist a Version row describing ``change_type`` applied to ``obj``.

        ``obj.as_dict()`` is stored as the payload, or an error marker when
        serialization fails. ``actor`` is stringified (falsy actors are stored
        as None), and ``metadata`` defaults to a fresh empty dict.
        """
        try:
            data = obj.as_dict()
        except Exception:
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata if metadata is not None else {},
        )
        session = self.session
        session.add(version)
        self._commit(session)