644 lines
24 KiB
Python
644 lines
24 KiB
Python
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||
from sqlalchemy import and_, func, inspect, or_, text
|
||
from sqlalchemy.engine import Engine, Connection
|
||
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty
|
||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||
from sqlalchemy.orm.base import NO_VALUE
|
||
from sqlalchemy.orm.util import AliasedClass
|
||
from sqlalchemy.sql import operators
|
||
from sqlalchemy.sql.elements import UnaryExpression
|
||
|
||
from crudkit.core.base import Version
|
||
from crudkit.core.spec import CRUDSpec
|
||
from crudkit.core.types import OrderSpec, SeekWindow
|
||
from crudkit.backend import BackendInfo, make_backend_info
|
||
|
||
def _expand_requires(model_cls, fields):
|
||
out, seen = [], set()
|
||
def add(f):
|
||
if f not in seen:
|
||
seen.add(f); out.append(f)
|
||
|
||
for f in fields:
|
||
add(f)
|
||
parts = f.split(".")
|
||
cur_cls = model_cls
|
||
prefix = []
|
||
|
||
for p in parts[:-1]:
|
||
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
|
||
if not rel:
|
||
cur_cls = None
|
||
break
|
||
cur_cls = rel.mapper.class_
|
||
prefix.append(p)
|
||
|
||
if cur_cls is None:
|
||
continue
|
||
|
||
leaf = parts[-1]
|
||
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
|
||
if not deps:
|
||
continue
|
||
|
||
pre = ".".join(prefix)
|
||
for dep in deps:
|
||
add(f"{pre + '.' if pre else ''}{dep}")
|
||
return out
|
||
|
||
def _is_rel(model_cls, name: str) -> bool:
|
||
try:
|
||
prop = model_cls.__mapper__.relationships.get(name)
|
||
return isinstance(prop, RelationshipProperty)
|
||
except Exception:
|
||
return False
|
||
|
||
def _is_instrumented_column(attr) -> bool:
|
||
try:
|
||
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
|
||
except Exception:
|
||
return False
|
||
|
||
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
    """Build selectinload options for bare MANYTOONE relationship names.

    For each entry in ``fields`` that is exactly a many-to-one relationship
    name on ``model_cls`` (e.g. ``"location"``), selectinload it from
    ``root_alias``. If the target class has an ``id`` attribute, only that
    column is loaded. This is enough to preselect ``<select>`` inputs without
    projecting the FK column on the root model. Other names are ignored.
    """
    if not fields:
        return []

    relationships = class_mapper(model_cls).relationships
    options: list[Load] = []
    for name in fields:
        prop = relationships.get(name)
        if not isinstance(prop, RelationshipProperty) or prop.direction.name != "MANYTOONE":
            continue
        loader = selectinload(getattr(root_alias, name))
        pk_attr = getattr(prop.mapper.class_, "id", None)
        options.append(loader if pk_attr is None else loader.load_only(pk_attr))
    return options
|
||
|
||
@runtime_checkable
|
||
class _HasID(Protocol):
|
||
id: int
|
||
|
||
@runtime_checkable
|
||
class _HasTable(Protocol):
|
||
__table__: Any
|
||
|
||
@runtime_checkable
|
||
class _HasADict(Protocol):
|
||
def as_dict(self) -> dict: ...
|
||
|
||
@runtime_checkable
|
||
class _SoftDeletable(Protocol):
|
||
is_deleted: bool
|
||
|
||
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface the CRUD service relies on: ``id``, ``__table__`` and
    ``as_dict()``. Soft delete (``is_deleted``) is optional and is not part of
    this protocol.
    """
|
||
|
||
# Model type a CRUDService is parameterized over; bound to _CRUDModelProto.
T = TypeVar("T", bound=_CRUDModelProto)
|
||
|
||
def _is_truthy(val):
|
||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||
|
||
class CRUDService(Generic[T]):
    """Generic CRUD service for one SQLAlchemy model.

    Read paths (``get``, ``list``, ``seek_window``) translate request params via
    ``CRUDSpec`` into filters, sorting, pagination, outer joins, eager loads and
    ``load_only`` projections. When the model has an ``is_deleted`` attribute,
    soft-deleted rows are excluded unless ``params["include_deleted"]`` is
    truthy. Write paths (``create``, ``update``, ``delete``) commit and then
    record a ``Version`` row via ``_log_version``.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None,
    ):
        """Create a service for ``model``.

        Args:
            model: Mapped model class to operate on.
            session_factory: Zero-argument callable returning the Session to
                use; it is invoked on every ``self.session`` access.
            polymorphic: Query through ``with_polymorphic(model, "*")``.
            backend: Precomputed backend info. When omitted it is derived once
                from the session's bind.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        self.supports_soft_delete = hasattr(model, "is_deleted")

        # Cache backend info once. Only resolve the session bind when the caller
        # did not supply it, so an explicit backend needs no bound session.
        if backend is None:
            bind = self.session.get_bind()
            eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
            backend = make_backend_info(eng)
        self.backend = backend

    @property
    def session(self) -> Session:
        """Return the session produced by the factory (called on every access)."""
        return self._session_factory()

    def get_query(self):
        """Return ``(query, root_alias)`` for the model or its polymorphic alias."""
        if self.polymorphic:
            poly = with_polymorphic(self.model, "*")
            return self.session.query(poly), poly
        return self.session.query(self.model), self.model

    # -- shared query-building helpers ---------------------------------------

    def _exclude_deleted(self, query, root_alias, params):
        """Filter out soft-deleted rows unless ``params["include_deleted"]`` is truthy.

        No-op for models without an ``is_deleted`` attribute. Used by every
        read path so they share the same default (deleted rows hidden).
        """
        if not self.supports_soft_delete:
            return query
        if params and _is_truthy(params.get("include_deleted")):
            return query
        # `== False` (not `is False`) is required to build the SQL comparison.
        return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712

    @staticmethod
    def _apply_join_paths(query, spec):
        """Outer-join each ``(parent, relationship, target alias)`` path from ``spec``."""
        for _parent_alias, relationship_attr, target_alias in spec.get_join_paths():
            rel_attr = cast(InstrumentedAttribute, relationship_attr)
            target = cast(Any, target_alias)
            query = query.join(target, rel_attr.of_type(target), isouter=True)
        return query

    def _apply_rel_root_loads(self, query, root_alias, rel_field_names):
        """selectinload the first relationship of every requested path, once each."""
        seen_rel_roots = set()
        for path in (rel_field_names or {}):
            if not path:
                continue
            rel_name = path[0]
            if rel_name in seen_rel_roots:
                continue
            seen_rel_roots.add(rel_name)
            if _is_rel(self.model, rel_name):
                rel_attr = getattr(root_alias, rel_name, None)
                if rel_attr is not None:
                    query = query.options(selectinload(rel_attr))
        return query

    def _projection_names(self, root_fields, root_field_names, rel_field_names) -> Tuple[str, ...]:
        """Build the field names recorded on results as ``__crudkit_projection__``.

        Order: explicit root names, root column keys, then ``"<path>.<name>"``
        for relationship fields. ``"id"`` is prepended when missing and the
        model has it. Returns an empty tuple when nothing was requested.
        """
        proj: List[str] = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields)
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            proj.extend(f"{prefix}.{n}" for n in names)
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
        return tuple(proj)

    @staticmethod
    def _tag_projection(objs: Iterable[Any], proj: Tuple[str, ...]) -> None:
        """Best-effort: set ``__crudkit_projection__`` on each non-None object."""
        if not proj:
            return
        for obj in objs:
            if obj is None:
                continue
            try:
                setattr(obj, "__crudkit_projection__", proj)
            except (AttributeError, TypeError):
                # Tagging is advisory for the renderer; skip objects that refuse it.
                pass

    def _pk_columns(self, root_alias) -> List[Any]:
        """Primary-key columns of the model, taken from ``root_alias`` when it has them."""
        mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
        cols: List[Any] = []
        for col in mapper.primary_key:
            try:
                cols.append(getattr(root_alias, col.key))
            except AttributeError:
                cols.append(col)
        return cols

    def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
        """Eager-load relationships that requested related fields depend on.

        For each relationship path (e.g. ``("location",) -> ["label"]``), look
        up the target model's ``__crudkit_field_requires__`` for each requested
        field. For every required name that is a relationship, emit a
        selectinload chain along the path ending in it, e.g.::

            Room.__crudkit_field_requires__ = {"label": ["room_function"]}
            => selectinload(root.location).selectinload(Room.room_function)

        Paths that do not resolve to relationships are skipped.
        """
        opts: List[Any] = []
        root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))

        for path, names in (rel_field_names or {}).items():
            if not path:
                continue

            # Resolve every hop; an unknown hop invalidates the whole path.
            current_mapper = root_mapper
            rel_props: List[RelationshipProperty] = []
            for step in path:
                rel = current_mapper.relationships.get(step)
                if not isinstance(rel, RelationshipProperty):
                    rel_props = []
                    break
                rel_props.append(rel)
                current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
            if not rel_props:
                continue

            base_loader = selectinload(getattr(root_alias, rel_props[0].key))
            for prev, nxt in zip(rel_props, rel_props[1:]):
                base_loader = base_loader.selectinload(getattr(prev.mapper.class_, nxt.key))

            target_cls = rel_props[-1].mapper.class_
            requires = getattr(target_cls, "__crudkit_field_requires__", None)
            if not isinstance(requires, dict):
                continue

            for field_name in names:
                needed: Iterable[str] = requires.get(field_name, []) or []
                for rel_need in needed:
                    rel_prop2 = target_cls.__mapper__.relationships.get(rel_need)
                    if not isinstance(rel_prop2, RelationshipProperty):
                        continue
                    dep_attr = getattr(target_cls, rel_prop2.key)
                    opts.append(base_loader.selectinload(dep_attr))

        return opts

    def _extract_order_spec(self, root_alias, given_order_by):
        """Normalize ``given_order_by`` (plus PK tiebreakers) into an ``OrderSpec``.

        Accepts plain column expressions and ``col.asc()`` / ``col.desc()``
        clauses, optionally wrapped in ``nulls_first()`` / ``nulls_last()``.
        Each entry is reduced to its bare column plus a DESC flag, without
        boolean evaluation of SQL clauses.
        """
        cols: List[Any] = []
        desc_flags: List[bool] = []
        for ob in self._stable_order_by(root_alias, given_order_by):
            col, is_desc = self._split_order_clause(ob)
            cols.append(col)
            desc_flags.append(is_desc)
        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))

    @staticmethod
    def _split_order_clause(ob) -> Tuple[Any, bool]:
        """Return ``(column, is_desc)`` for one ORDER BY entry.

        ``col.desc()``, ``col.asc()``, ``nulls_first()`` and ``nulls_last()``
        wrap the column in a ``UnaryExpression`` whose ``modifier`` carries the
        ordering operator; ``operator`` stays None. Peel those layers and
        remember whether any of them was DESC.
        """
        is_desc = False
        node = ob
        while isinstance(node, UnaryExpression) and node.modifier is not None:
            if node.modifier is operators.desc_op:
                is_desc = True
            node = node.element
        return node, is_desc

    def _key_predicate(self, spec: OrderSpec, key_vals: List[Any], backward: bool):
        """Build the lexicographic keyset-seek predicate for ``key_vals``.

        A row qualifies when, for some position i, every earlier sort column
        equals the key and column i is strictly past the key in traversal
        order. Backward traversal inverts every comparison. Returns None when
        ``key_vals`` is empty.
        """
        if not key_vals:
            return None
        conds = []
        for i, col in enumerate(spec.cols):
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            # `<` for DESC going forward or ASC going backward; `>` otherwise.
            use_less_than = spec.desc[i] != backward
            op = col < key_vals[i] if use_less_than else col > key_vals[i]
            conds.append(and_(*ties, op))
        return or_(*conds)

    def _pluck_key(self, obj: Any, spec: OrderSpec) -> List[Any]:
        """Read the sort-key values of ``obj`` in ``spec.cols`` order (by key, else name)."""
        out = []
        for c in spec.cols:
            key = getattr(c, "key", None) or getattr(c, "name", None)
            out.append(getattr(obj, key))
        return out

    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """Transport-agnostic keyset pagination with the same features as ``list()``.

        - filters, includes, joins, field projection, eager loading, soft delete
        - deterministic ordering (user sort + PK tiebreakers)
        - forward/backward seek via ``key`` and ``backward``

        Returns a SeekWindow with items, first/last keys, order spec, limit and,
        when ``include_total`` is set, the total row count (ignoring ``key``).
        """
        # Copy so the caller's dict is never mutated; also tolerates params=None.
        params = dict(params or {})
        fields = list(params.get("fields") or [])
        if fields:
            fields = _expand_requires(self.model, fields)
            params["fields"] = fields

        query, root_alias = self.get_query()
        spec = CRUDSpec(self.model, params, root_alias)

        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        root_fields, rel_field_names, root_field_names = spec.parse_fields()

        query = self._apply_rel_root_loads(query, root_alias, rel_field_names)
        query = self._exclude_deleted(query, root_alias, params)
        if filters:
            query = query.filter(*filters)

        # Includes + joins (so relationship fields like brand.name, location.label work)
        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        # Projection: load_only for root columns, eager loads for required relationships.
        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))
        for opt in self._resolve_required_includes(root_alias, rel_field_names):
            query = query.options(opt)

        order_spec = self._extract_order_spec(root_alias, order_by)
        limit, _ = spec.parse_pagination()
        if not limit or limit <= 0:
            limit = 50  # default page size when none (or a non-positive one) is given

        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        # Backward: flip every direction in SQL, then reverse rows for display order.
        clauses = [
            col.desc() if is_desc != backward else col.asc()
            for col, is_desc in zip(order_spec.cols, order_spec.desc)
        ]
        rows = query.order_by(*clauses).limit(limit).all()
        items = rows[::-1] if backward else rows

        self._tag_projection(items, self._projection_names(root_fields, root_field_names, rel_field_names))

        # Boundary keys for cursor encoding in the API layer
        first_key = self._pluck_key(items[0], order_spec) if items else None
        last_key = self._pluck_key(items[-1], order_spec) if items else None

        # Optional total that is safe under JOINs (COUNT over DISTINCT ids)
        total = None
        if include_total:
            session = self.session
            base = session.query(getattr(root_alias, "id"))
            base = self._exclude_deleted(base, root_alias, params)
            if filters:
                base = base.filter(*filters)
            base = self._apply_join_paths(base, spec)
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        return SeekWindow(
            items=items,
            limit=limit,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )

    def _default_order_by(self, root_alias):
        """Default ORDER BY: the model's primary-key columns, or ``text("1")`` if none.

        Used when paginating without an explicit sort (some backends, e.g.
        MSSQL, require ORDER BY with OFFSET) and as the keyset default.
        """
        return self._pk_columns(root_alias) or [text("1")]

    def _stable_order_by(self, root_alias, given_order_by):
        """Make ordering deterministic by appending PK columns as tiebreakers.

        If no order is provided, fall back to the default primary-key order.
        """
        order_by = list(given_order_by or [])
        if not order_by:
            return self._default_order_by(root_alias)
        return [*order_by, *self._pk_columns(root_alias)]

    def get(self, id: int, params=None) -> T | None:
        """Fetch one row by ``id``, honouring soft delete, includes and field projection.

        Returns None when no visible row matches.
        """
        query, root_alias = self.get_query()

        root_fields = []
        root_field_names = {}
        rel_field_names = {}

        spec = CRUDSpec(self.model, params or {}, root_alias)
        query = self._exclude_deleted(query, root_alias, params)
        query = query.filter(getattr(root_alias, "id") == id)

        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            query = self._apply_rel_root_loads(query, root_alias, rel_field_names)
            # PK-only selectinload for bare many-to-one names (applied once).
            for opt in _loader_options_for_fields(root_alias, self.model, params.get("fields") or []):
                query = query.options(opt)

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        obj = query.first()
        if obj is not None:
            self._tag_projection([obj], self._projection_names(root_fields, root_field_names, rel_field_names))
        return obj or None

    def list(self, params=None) -> list[T]:
        """Return rows matching ``params`` (filters, sort, includes, fields, limit/offset).

        Soft-deleted rows are excluded unless ``params["include_deleted"]`` is
        truthy, including when ``params`` is empty (same default as ``get`` and
        ``seek_window``).
        """
        query, root_alias = self.get_query()

        root_fields = []
        root_field_names = {}
        rel_field_names = {}

        query = self._exclude_deleted(query, root_alias, params)

        spec = CRUDSpec(self.model, params or {}, root_alias)
        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        spec.parse_includes()
        query = self._apply_join_paths(query, spec)

        if params:
            root_fields, rel_field_names, root_field_names = spec.parse_fields()
            query = self._apply_rel_root_loads(query, root_alias, rel_field_names)

        only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        if params:
            for opt in _loader_options_for_fields(root_alias, self.model, params.get("fields") or []):
                query = query.options(opt)

        if filters:
            query = query.filter(*filters)

        # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
        paginating = (limit is not None) or (offset is not None and offset != 0)
        if paginating and not order_by and self.backend.requires_order_by_for_offset:
            order_by = self._default_order_by(root_alias)

        if order_by:
            query = query.order_by(*order_by)

        # Only apply offset/limit when meaningful.
        if offset is not None and offset != 0:
            query = query.offset(offset)
        if limit is not None and limit > 0:
            query = query.limit(limit)

        rows = query.all()
        self._tag_projection(rows, self._projection_names(root_fields, root_field_names, rel_field_names))
        return rows

    def create(self, data: dict, actor=None) -> T:
        """Instantiate the model from ``data``, commit it, and record a version."""
        obj = self.model(**data)
        session = self.session  # one session for add + commit
        session.add(obj)
        session.commit()
        self._log_version("create", obj, actor)
        return obj

    def update(self, id: int, data: dict, actor=None) -> T:
        """Set the table-column keys present in ``data`` on row ``id``, commit, and record a version.

        Raises:
            ValueError: no visible (non-soft-deleted) row with that id exists.
        """
        obj = self.get(id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")
        valid_fields = {c.name for c in self.model.__table__.columns}
        for k, v in data.items():
            if k in valid_fields:
                setattr(obj, k, v)
        self.session.commit()
        self._log_version("update", obj, actor)
        return obj

    def delete(self, id: int, hard: bool = False, actor=False):
        """Delete row ``id`` and record a version; return the object, or None if absent.

        Soft-deletes (``is_deleted = True``) when the model supports it,
        unless ``hard`` is set; otherwise the row is deleted.
        """
        session = self.session  # one session for lookup, delete and commit
        obj = session.get(self.model, id)
        if not obj:
            return None
        if hard or not self.supports_soft_delete:
            session.delete(obj)
        else:
            cast(_SoftDeletable, obj).is_deleted = True
        session.commit()
        self._log_version("delete", obj, actor)
        return obj

    def _log_version(self, change_type: str, obj: T, actor=None, metadata: Optional[dict] = None):
        """Add and commit a ``Version`` row describing ``change_type`` on ``obj``.

        ``metadata`` defaults to a fresh empty dict per call (never a shared default).
        """
        try:
            data = obj.as_dict()
        except Exception:
            # Serialization failure is recorded instead of raised; the caller's
            # change has already been committed at this point.
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata if metadata is not None else {},
        )
        session = self.session
        session.add(version)
        session.commit()