# inventory/crudkit/core/service.py
#
# Repository-viewer metadata: 704 lines, 28 KiB, Python (Raw / Blame / History).
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e.
# characters that might be confused with other characters. If this is
# intentional, the warning can safely be ignored; the viewer's Escape button
# reveals them.

from __future__ import annotations
from collections.abc import Iterable
from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection
import logging
log = logging.getLogger("crudkit.service")
@runtime_checkable
class _HasID(Protocol):
    """Structural type for objects that expose an integer ``id`` attribute."""

    id: int  # identifier attribute of the object
@runtime_checkable
class _HasTable(Protocol):
    """Structural type for objects that expose a ``__table__`` attribute."""

    __table__: Any  # deliberately untyped here (``Any``)
@runtime_checkable
class _HasADict(Protocol):
    """Structural type for objects that provide an ``as_dict()`` method."""

    # Returns a dict; which keys it contains is up to the implementing class.
    def as_dict(self) -> dict: ...
@runtime_checkable
class _SoftDeletable(Protocol):
    """Structural type for objects carrying a boolean ``is_deleted`` flag."""

    is_deleted: bool  # True marks the object as soft-deleted
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Combined protocol a model must satisfy to be used with the CRUD service.

    It merges the ``id``, ``__table__`` and ``as_dict()`` requirements.
    ``is_deleted`` is intentionally not part of it: soft delete is optional.
    """
# Type variable for the model class handled by a service; bound to _CRUDModelProto.
T = TypeVar("T", bound=_CRUDModelProto)
def _unwrap_ob(ob):
    """Split an ORDER BY element into its column and its sort direction.

    Args:
        ob: A plain column expression, or the wrapper produced by
            ``col.asc()`` / ``col.desc()``.

    Returns:
        tuple: ``(col, is_desc)``. ``col`` is ``ob.element`` when that
        attribute exists and is not ``None``, otherwise ``ob`` itself.
        ``is_desc`` is ``True`` only when a descending marker was found.
    """
    inner = getattr(ob, "element", None)
    col = ob if inner is None else inner

    def _marks_desc(marker) -> bool:
        # A marker counts as descending when it is SQLAlchemy's desc operator
        # or any object whose ``name`` reads "DESC" (case-insensitive).
        if marker is None:
            return False
        if marker is operators.desc_op:
            return True
        return str(getattr(marker, "name", "") or "").upper() == "DESC"

    explicit = getattr(ob, "_direction", None)
    if explicit is not None:
        # An explicit direction attribute, when present, is authoritative.
        return col, _marks_desc(explicit)

    if isinstance(ob, UnaryExpression):
        # ``col.desc()`` stores desc_op in ``modifier`` and leaves ``operator``
        # unset, so checking ``operator`` alone reports every descending
        # clause as ascending. Inspect both slots.
        descending = _marks_desc(getattr(ob, "modifier", None)) or _marks_desc(
            getattr(ob, "operator", None)
        )
        return col, descending

    return col, False
def _order_identity(col: ColumnElement):
"""
Build a stable identity for a column suitable for deduping.
We ignore direction here. Duplicates are duplicates regardless of ASC/DESC.
"""
table = getattr(col, "table", None)
table_key = getattr(table, "key", None) or id(table)
col_key = getattr(col, "key", None) or getattr(col, "name", None)
return (table_key, col_key)
def _dedupe_order_by(order_by):
    """Drop repeated ORDER BY entries, keeping the first occurrence of each column.

    Two entries are duplicates when they share the same column identity;
    their direction (ASC/DESC) is not taken into account. A falsy input
    yields an empty list.
    """
    unique = []
    if not order_by:
        return unique

    known = set()
    for entry in order_by:
        ident = _order_identity(_unwrap_ob(entry)[0])
        if ident not in known:
            known.add(ident)
            unique.append(entry)
    return unique
def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
def _normalize_fields_param(params: dict | None) -> list[str]:
if not params:
return []
raw = params.get("fields")
if isinstance(raw, (list, tuple)):
out: list[str] = []
for item in raw:
if isinstance(item, str):
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
return out
if isinstance(raw, str):
return [p for p in (s.strip() for s in raw.split(",")) if p]
return []
class CRUDService(Generic[T]):
def __init__(
self,
model: Type[T],
session_factory: Callable[[], Session],
polymorphic: bool = False,
*,
backend: Optional[BackendInfo] = None
):
self.model = model
self._session_factory = session_factory
self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted')
self._backend: Optional[BackendInfo] = backend
@property
def session(self) -> Session:
    """Return the session to use for the current operation.

    The object registered under ``current_app.extensions["crudkit"]["Session"]``
    is preferred. If that lookup fails for any reason, the session factory
    given to the constructor is called instead.
    """
    try:
        registered = current_app.extensions["crudkit"]["Session"]
    except Exception:
        # Lookup failed (any exception): fall back to the injected factory.
        return self._session_factory()
    return registered
@property
def backend(self) -> BackendInfo:
    """Return backend info, building and caching it on first use.

    When no backend was supplied up front, the engine behind the active
    session's bind is passed to ``make_backend_info`` and the result is
    stored for subsequent calls.
    """
    cached = self._backend
    if cached is not None:
        return cached

    bind = self.session.get_bind()
    if isinstance(bind, Connection):
        # A Connection exposes its owning engine via ``.engine``.
        engine: Engine = bind.engine
    else:
        engine = cast(Engine, bind)

    self._backend = make_backend_info(engine)
    return self._backend
def get_query(self):
    """Return ``(query, root_alias)`` for the service's model.

    For a non-polymorphic service the root alias is the model class itself;
    otherwise it is ``with_polymorphic(model, "*")`` and the query selects
    from that entity.
    """
    if not self.polymorphic:
        return self.session.query(self.model), self.model

    poly_entity = with_polymorphic(self.model, "*")
    return self.session.query(poly_entity), poly_entity
def _debug_bind(self, where: str):
    """Log which engine the active session is bound to (debugging aid).

    Output goes to the module logger at DEBUG level rather than stdout, in
    line with the query logging elsewhere in this module. Failures while
    introspecting the bind are logged too and never propagate.

    Args:
        where: Free-form label identifying the call site in the log line.
    """
    try:
        # Resolve the session once: the ``session`` property may call the
        # session factory on every access.
        session = self.session
        bind = session.get_bind()
        eng = getattr(bind, "engine", bind)
        log.debug(
            "SERVICE BIND [%s]: engine_id=%s url=%s session=%s",
            where,
            id(eng),
            getattr(eng, "url", "?"),
            type(session).__name__,
        )
    except Exception as e:
        log.debug("SERVICE BIND [%s]: failed to introspect bind: %s", where, e)
def _apply_not_deleted(self, query, root_alias, params) -> Any:
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
return query.filter(getattr(root_alias, "is_deleted") == False)
return query
def _extract_order_spec(self, root_alias, given_order_by):
    """
    Normalize an ORDER BY list into an ``OrderSpec`` of columns and DESC flags.

    The input is first passed through ``_stable_order_by`` and each resulting
    entry is then split into ``(column, is_desc)`` by the shared module helper
    ``_unwrap_ob``. The clauses are only inspected via attribute access and
    are never evaluated as booleans.

    Args:
        root_alias: Root entity the query selects from.
        given_order_by: ORDER BY elements as parsed from the request (may be
            empty or ``None``).

    Returns:
        OrderSpec: ``cols`` and ``desc`` tuples of equal length.
    """
    stable = self._stable_order_by(root_alias, given_order_by)
    # One implementation of "unwrap + detect direction" for the whole module:
    # the ORDER BY de-duplication uses the same helper.
    pairs = [_unwrap_ob(ob) for ob in stable]
    return OrderSpec(
        cols=tuple(col for col, _ in pairs),
        desc=tuple(is_desc for _, is_desc in pairs),
    )
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
if not key_vals:
return None
conds = []
for i, col in enumerate(spec.cols):
# If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
# sent_col = func.coalesce(col, literal("-∞"))
sent_col = col
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i]
if not backward:
op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i])
else:
op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i])
conds.append(and_(*ties, op))
return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
out = []
for c in spec.cols:
# Only simple mapped columns supported for key pluck
key = getattr(c, "key", None) or getattr(c, "name", None)
if key is None or not hasattr(obj, key):
raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
out.append(getattr(obj, key))
return out
def seek_window(
    self,
    params: dict | None = None,
    *,
    key: list[Any] | None = None,
    backward: bool = False,
    include_total: bool = True,
) -> "SeekWindow[T]":
    """
    Keyset pagination with relationship-safe filtering/sorting.
    Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek.

    Args:
        params: Request parameters handed to CRUDSpec (filters, sort, fields,
            includes, pagination) plus the optional "include_deleted" flag.
        key: Cursor key to seek from, one value per order column. A falsy
            key means "start from the beginning".
        backward: When true, the SQL ordering is inverted and the fetched
            rows are reversed in memory.
        include_total: When true, a second query counts the distinct root
            ids matching the same filters.

    Returns:
        SeekWindow with the fetched items, the effective limit (0 when the
        caller asked for limit 0, i.e. unlimited), the first/last cursor keys
        (both None when there are no items or a key could not be plucked),
        the order spec and the total (None when include_total is false).
    """
    session = self.session
    query, root_alias = self.get_query()
    # Requested fields → projection + optional loaders
    fields = _normalize_fields_param(params)
    # NOTE(review): expanded_fields/proj_opts are computed but never applied to
    # the query in this method (get()/list() do apply proj_opts) — confirm
    # whether that is intentional.
    expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
    spec = CRUDSpec(self.model, params or {}, root_alias)
    # Parse all inputs so join_paths are populated
    filters = spec.parse_filters()
    order_by = spec.parse_sort()
    root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
    spec.parse_includes()
    join_paths = tuple(spec.get_join_paths())
    # Soft delete
    query = self._apply_not_deleted(query, root_alias, params)
    # Root column projection (load_only)
    only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
    if only_cols:
        query = query.options(Load(root_alias).load_only(*only_cols))
    # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
    nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
    # NOTE(review): used_contains_eager is assigned below but never read in this method.
    used_contains_eager = False
    for base_alias, rel_attr, target_alias in join_paths:
        is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
        is_nested_firsthop = rel_attr.key in nested_first_hops
        if is_collection or is_nested_firsthop:
            # Use selectinload so deeper hops can chain cleanly (and to avoid
            # contains_eager/loader conflicts on nested paths).
            opt = selectinload(rel_attr)
            # Narrow columns for collections if we know child scalar names
            if is_collection:
                child_names = (collection_field_names or {}).get(rel_attr.key, [])
                if child_names:
                    target_cls = rel_attr.property.mapper.class_
                    cols = [getattr(target_cls, n, None) for n in child_names]
                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                    if cols:
                        opt = opt.load_only(*cols)
            query = query.options(opt)
        else:
            # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager
            query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            query = query.options(contains_eager(rel_attr, alias=target_alias))
            used_contains_eager = True
    # Filters AFTER joins → no cartesian products
    if filters:
        query = query.filter(*filters)
    # Order spec (with PK tie-breakers for stability)
    order_spec = self._extract_order_spec(root_alias, order_by)
    # Only the limit is used here; the second value of parse_pagination() is ignored.
    limit, _ = spec.parse_pagination()
    if limit is None:
        effective_limit = 50  # default page size when no limit was requested
    elif limit == 0:
        effective_limit = None  # unlimited
    else:
        effective_limit = limit
    # Seek predicate from cursor key (if any)
    if key:
        pred = self._key_predicate(order_spec, key, backward)
        if pred is not None:
            query = query.filter(pred)
    # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
    if not backward:
        clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        items = query.all()
    else:
        inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*inv_clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        # Rows were fetched in inverted order; flip them back to display order.
        items = list(reversed(query.all()))
    # Projection meta tag for renderers
    if fields:
        # Explicitly requested fields, de-duplicated in request order, "id" first.
        proj = list(dict.fromkeys(fields))
        if "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    else:
        # Fallback: assemble the projection from whatever CRUDSpec parsed.
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields if hasattr(c, "key"))
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    if proj:
        for obj in items:
            try:
                setattr(obj, "__crudkit_projection__", tuple(proj))
            except Exception:
                # Best effort: objects that refuse the attribute are left untagged.
                pass
    # Cursor key pluck: support related columns we hydrated via contains_eager
    def _pluck_key_from_obj(obj: Any) -> list[Any]:
        vals: list[Any] = []
        # Map each joined alias' selectable to the relationship name on the root object.
        alias_to_rel: dict[Any, str] = {}
        for _p, relationship_attr, target_alias in join_paths:
            sel = getattr(target_alias, "selectable", None)
            if sel is not None:
                alias_to_rel[sel] = relationship_attr.key
        for col in order_spec.cols:
            keyname = getattr(col, "key", None) or getattr(col, "name", None)
            # NOTE(review): the root object is tried first, so a related sort
            # column whose name also exists on the root would be plucked from
            # the root — confirm this cannot happen or is acceptable.
            if keyname and hasattr(obj, keyname):
                vals.append(getattr(obj, keyname))
                continue
            table = getattr(col, "table", None)
            relname = alias_to_rel.get(table)
            if relname and keyname:
                relobj = getattr(obj, relname, None)
                if relobj is not None and hasattr(relobj, keyname):
                    vals.append(getattr(relobj, keyname))
                    continue
            raise ValueError("unpluckable")
        return vals
    try:
        first_key = _pluck_key_from_obj(items[0]) if items else None
        last_key = _pluck_key_from_obj(items[-1]) if items else None
    except Exception:
        # Any pluck failure drops both cursor keys instead of failing the request.
        first_key = None
        last_key = None
    # Count DISTINCT ids with mirrored joins
    total = None
    if include_total:
        base = session.query(getattr(root_alias, "id"))
        base = self._apply_not_deleted(base, root_alias, params)
        # same joins as above for correctness
        for base_alias, rel_attr, target_alias in join_paths:
            # do not join collections for COUNT mirror
            # NOTE(review): a filter that references a collection alias would
            # have no matching join in this count query — confirm CRUDSpec
            # never produces such filters.
            if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
        if filters:
            base = base.filter(*filters)
        total = session.query(func.count()).select_from(
            base.order_by(None).distinct().subquery()
        ).scalar() or 0
    # Report 0 for "unlimited", otherwise the limit that was actually applied.
    window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50)
    if log.isEnabledFor(logging.DEBUG):
        log.debug("QUERY: %s", str(query))
    return SeekWindow(
        items=items,
        limit=window_limit_for_body,
        first_key=first_key,
        last_key=last_key,
        order=order_spec,
        total=total,
    )
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias):
    """Return the model's primary-key columns as a default ORDER BY list.

    Each primary-key column is looked up by key on ``root_alias``; when the
    alias has no such attribute the mapper's own column is used. A model
    without primary-key columns yields ``[text("1")]``.
    """
    mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
    resolved = []
    for pk in mapper.primary_key:
        try:
            attr = getattr(root_alias, pk.key)
        except AttributeError:
            attr = pk
        resolved.append(attr)
    if not resolved:
        return [text("1")]
    return resolved
def _stable_order_by(self, root_alias, given_order_by):
    """
    Return a deterministic ORDER BY list for the given request order.

    With no requested order, the de-duplicated default primary-key order is
    returned. Otherwise the requested entries are de-duplicated and every
    primary-key column not already present is appended ascending as a
    tie-breaker. No column appears twice in the result.
    """
    requested = list(given_order_by or [])
    if not requested:
        return _dedupe_order_by(self._default_order_by(root_alias))

    ordered = _dedupe_order_by(requested)
    mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))

    seen = set()
    for entry in ordered:
        column, _ = _unwrap_ob(entry)
        seen.add(_order_identity(column))

    for pk in mapper.primary_key:
        try:
            pk_col = getattr(root_alias, pk.key)
        except AttributeError:
            pk_col = pk
        ident = _order_identity(pk_col)
        if ident in seen:
            continue
        ordered.append(pk_col.asc())
        seen.add(ident)
    return ordered
def get(self, id: int, params=None) -> T | None:
    """
    Fetch a single row by id with conflict-free eager loading and clean projection.
    Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so
    related-column filters never create cartesian products.

    Args:
        id: Value matched against the root entity's ``id`` attribute.
        params: Optional request parameters (filters, fields, includes,
            "include_deleted").

    Returns:
        The matching object, tagged with ``__crudkit_projection__`` when a
        projection could be derived, or ``None`` when no row matches
        (soft-deleted rows do not match unless "include_deleted" is truthy).
    """
    query, root_alias = self.get_query()
    # Defaults so we can build a projection even if params is None
    req_fields: list[str] = _normalize_fields_param(params)
    root_fields: list[Any] = []
    root_field_names: dict[str, str] = {}
    rel_field_names: dict[tuple[str, ...], list[str]] = {}
    # Bound up front: the join loop below reads it even when parse_fields()
    # was skipped because params is empty.
    collection_field_names: dict[str, list[str]] = {}
    # Soft-delete guard first
    query = self._apply_not_deleted(query, root_alias, params)
    spec = CRUDSpec(self.model, params or {}, root_alias)
    # Parse everything so CRUDSpec records any join paths it needed to resolve
    filters = spec.parse_filters()
    # no ORDER BY for get()
    if params:
        root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
        spec.parse_includes()
    join_paths = tuple(spec.get_join_paths())
    # Root-column projection (load_only)
    only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
    if only_cols:
        query = query.options(Load(root_alias).load_only(*only_cols))
    # First hops that also have deeper tails requested (path length > 1)
    nested_first_hops = {path[0] for path in (rel_field_names or {}).keys() if len(path) > 1}
    used_contains_eager = False
    for _base_alias, rel_attr, target_alias in join_paths:
        is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
        is_nested_firsthop = rel_attr.key in nested_first_hops
        if is_collection or is_nested_firsthop:
            # Collections and nested first hops are loaded with selectinload.
            opt = selectinload(rel_attr)
            if is_collection:
                # Narrow the child columns when the requested names are known.
                child_names = (collection_field_names or {}).get(rel_attr.key, [])
                if child_names:
                    target_cls = rel_attr.property.mapper.class_
                    cols = [getattr(target_cls, n, None) for n in child_names]
                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                    if cols:
                        opt = opt.load_only(*cols)
            query = query.options(opt)
        else:
            # Simple scalar first hop: outer join + contains_eager.
            query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            query = query.options(contains_eager(rel_attr, alias=target_alias))
            used_contains_eager = True
    # Apply filters (joins are in place → no cartesian products)
    if filters:
        query = query.filter(*filters)
    # And the id filter
    query = query.filter(getattr(root_alias, "id") == id)
    # Projection loader options compiled from requested fields.
    # Skip if we used contains_eager to avoid loader-strategy conflicts.
    _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
    if proj_opts and not used_contains_eager:
        query = query.options(*proj_opts)
    obj = query.first()
    # Emit exactly what the client requested (plus id), or a reasonable fallback
    if req_fields:
        proj = list(dict.fromkeys(req_fields))
        if "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    else:
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields if hasattr(c, "key"))
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    if proj and obj is not None:
        try:
            setattr(obj, "__crudkit_projection__", tuple(proj))
        except Exception:
            # Best effort: an object that refuses the attribute stays untagged.
            pass
    if log.isEnabledFor(logging.DEBUG):
        log.debug("QUERY: %s", str(query))
    # first() already yields None when nothing matched; returning it directly
    # avoids dropping a found row whose model happens to be falsy.
    return obj
def list(self, params=None) -> list[T]:
    """
    Offset/limit listing with relationship-safe filtering.
    We always JOIN every CRUDSpec-discovered path before applying filters/sorts.

    Soft-deleted rows are excluded whether or not ``params`` is given; pass a
    truthy "include_deleted" in ``params`` to include them.

    Args:
        params: Optional request parameters (filters, sort, fields, includes,
            pagination, "include_deleted").

    Returns:
        The matching objects, each tagged with ``__crudkit_projection__`` when
        a projection could be derived.
    """
    query, root_alias = self.get_query()
    # Defaults so we can reference them later even if params is None
    req_fields: list[str] = _normalize_fields_param(params)
    root_fields: list[Any] = []
    root_field_names: dict[str, str] = {}
    rel_field_names: dict[tuple[str, ...], list[str]] = {}
    # Soft delete: applied unconditionally, as in get() and seek_window().
    # Guarding it behind ``if params:`` leaked deleted rows from a bare list().
    query = self._apply_not_deleted(query, root_alias, params)
    if params:
        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        # Includes / fields (populates join_paths)
        root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())
        # Root column projection (load_only)
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        # First hops that also have deeper tails requested (path length > 1)
        nested_first_hops = {path[0] for path in (rel_field_names or {}).keys() if len(path) > 1}
        used_contains_eager = False
        for _base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            is_nested_firsthop = rel_attr.key in nested_first_hops
            if is_collection or is_nested_firsthop:
                # Collections and nested first hops are loaded with selectinload.
                opt = selectinload(rel_attr)
                if is_collection:
                    # Narrow the child columns when the requested names are known.
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)
                query = query.options(opt)
            else:
                # Simple scalar first hop: outer join + contains_eager.
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr, alias=target_alias))
                used_contains_eager = True
        # Filters AFTER joins → no cartesian products
        if filters:
            query = query.filter(*filters)
        # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers
        paginating = (limit is not None) or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)
        order_by = _dedupe_order_by(order_by)
        if order_by:
            query = query.order_by(*order_by)
        # Offset/limit
        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)
        # Projection loaders only if we didn't use contains_eager
        _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
        if proj_opts and not used_contains_eager:
            query = query.options(*proj_opts)
    else:
        # No params; still honor projection loaders if any
        _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
        if proj_opts:
            query = query.options(*proj_opts)
    rows = query.all()
    # Build projection meta for renderers
    if req_fields:
        proj = [*dict.fromkeys(req_fields)]
        if "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    else:
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields if hasattr(c, "key"))
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    if proj:
        for obj in rows:
            try:
                setattr(obj, "__crudkit_projection__", tuple(proj))
            except Exception:
                # Best effort: objects that refuse the attribute stay untagged.
                pass
    if log.isEnabledFor(logging.DEBUG):
        log.debug("QUERY: %s", str(query))
    return rows
def create(self, data: dict, actor=None) -> T:
    """Create, persist and version-log a new model instance.

    Args:
        data: Keyword arguments passed to the model constructor.
        actor: Optional identifier of who made the change; forwarded to the
            version log.

    Returns:
        The newly created object.

    Raises:
        Exception: Whatever ``commit()`` raises. The session is rolled back
            before the error is re-raised so it remains usable.
    """
    session = self.session
    obj = self.model(**data)
    session.add(obj)
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the transaction in a failed state; reset it.
        session.rollback()
        raise
    self._log_version("create", obj, actor)
    return obj
def update(self, id: int, data: dict, actor=None) -> T:
    """Update columns of an existing row and version-log the change.

    Args:
        id: Primary key passed to ``session.get``.
        data: Mapping of column name to new value. Every key must be a
            column name of the model's table.
        actor: Optional identifier of who made the change; forwarded to the
            version log.

    Returns:
        The updated object.

    Raises:
        ValueError: If no row with ``id`` exists, or ``data`` contains keys
            that are not column names.
        Exception: Whatever ``commit()`` raises. The session is rolled back
            before the error is re-raised.
    """
    session = self.session
    obj = session.get(self.model, id)
    # ``is None`` rather than truthiness: a found row must not be mistaken
    # for "missing" just because its model is falsy.
    if obj is None:
        raise ValueError(f"{self.model.__name__} with ID {id} not found.")
    valid_fields = {c.name for c in self.model.__table__.columns}
    unknown = set(data) - valid_fields
    if unknown:
        raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
    # Every key is known to be valid at this point.
    for k, v in data.items():
        setattr(obj, k, v)
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the transaction in a failed state; reset it.
        session.rollback()
        raise
    self._log_version("update", obj, actor)
    return obj
def delete(self, id: int, hard: bool = False, actor = None):
    """Delete a row (softly when supported) and version-log the change.

    Args:
        id: Primary key passed to ``session.get``.
        hard: Force a real DELETE even when the model supports soft delete.
        actor: Optional identifier of who made the change; forwarded to the
            version log.

    Returns:
        The affected object, or ``None`` when no row with ``id`` exists.

    Raises:
        Exception: Whatever ``commit()`` raises. The session is rolled back
            before the error is re-raised.
    """
    session = self.session
    obj = session.get(self.model, id)
    # ``is None`` rather than truthiness: a found row must not be treated as
    # missing just because its model is falsy.
    if obj is None:
        return None
    if hard or not self.supports_soft_delete:
        session.delete(obj)
    else:
        # Soft delete: only flip the flag; the row itself stays in place.
        soft = cast("_SoftDeletable", obj)
        soft.is_deleted = True
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the transaction in a failed state; reset it.
        session.rollback()
        raise
    self._log_version("delete", obj, actor)
    return obj
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
    """Persist a ``Version`` row describing a change to *obj*.

    Args:
        change_type: Label of the change (the CRUD methods pass "create",
            "update" or "delete").
        obj: The affected object; its ``as_dict()`` output is stored as the
            version data and its ``id`` as the object id.
        actor: Optional identifier of who made the change. Stored as
            ``str(actor)``; falsy values are stored as ``None``.
        metadata: Optional extra data stored in the version's ``meta`` field.

    Raises:
        Exception: Whatever ``commit()`` raises. The session is rolled back
            before the error is re-raised.
    """
    session = self.session
    try:
        data = obj.as_dict()
    except Exception:
        # Best effort: still write the version row, but leave a trace of why
        # the payload is missing instead of swallowing the error silently.
        log.warning(
            "Failed to serialize %s for version log (%s)",
            self.model.__name__,
            change_type,
            exc_info=True,
        )
        data = {"error": "Failed to serialize object."}
    version = Version(
        model_name=self.model.__name__,
        object_id=obj.id,
        change_type=change_type,
        data=data,
        actor=str(actor) if actor else None,
        meta=metadata
    )
    session.add(version)
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the transaction in a failed state; reset it.
        session.rollback()
        raise