870 lines
33 KiB
Python
870 lines
33 KiB
Python
from __future__ import annotations
|
||
|
||
from collections.abc import Iterable
|
||
from dataclasses import dataclass
|
||
from flask import current_app
|
||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||
from sqlalchemy import and_, func, inspect, or_, text, select, literal
|
||
from sqlalchemy.engine import Engine, Connection
|
||
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent
|
||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||
from sqlalchemy.sql import operators, visitors
|
||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||
|
||
from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
|
||
from crudkit.core.base import Version
|
||
from crudkit.core.spec import CRUDSpec, CollPred
|
||
from crudkit.core.types import OrderSpec, SeekWindow
|
||
from crudkit.backend import BackendInfo, make_backend_info
|
||
from crudkit.projection import compile_projection
|
||
|
||
import logging
|
||
log = logging.getLogger("crudkit.service")
|
||
# logging.getLogger("crudkit.service").setLevel(logging.DEBUG)
|
||
# Ensure our debug actually prints even if the app/root logger is WARNING+
|
||
# if not log.handlers:
|
||
# _h = logging.StreamHandler()
|
||
# _h.setLevel(logging.DEBUG)
|
||
# _h.setFormatter(logging.Formatter(
|
||
# "%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||
# ))
|
||
# log.addHandler(_h)
|
||
#
|
||
# log.setLevel(logging.DEBUG)
|
||
# log.propagate = False
|
||
|
||
@runtime_checkable
class _HasID(Protocol):
    # Integer primary key expected on every CRUD model.
    id: int


@runtime_checkable
class _HasTable(Protocol):
    # SQLAlchemy Table attached by the declarative mapper.
    __table__: Any


@runtime_checkable
class _HasADict(Protocol):
    # Serialization hook used for audit snapshots and diffs.
    def as_dict(self) -> dict: ...


@runtime_checkable
class _SoftDeletable(Protocol):
    # Presence of this attribute on a model enables soft deletes.
    is_deleted: bool


class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Surface expected by CRUDService."""
    pass


# Model type handled by one CRUDService instance.
T = TypeVar("T", bound=_CRUDModelProto)
|
||
|
||
# ---------------------------- utilities ----------------------------
|
||
|
||
def _collect_tables_from_filters(filters) -> set:
|
||
seen = set()
|
||
stack = list(filters or [])
|
||
while stack:
|
||
node = stack.pop()
|
||
|
||
tbl = getattr(node, "table", None)
|
||
if tbl is not None:
|
||
cur = tbl
|
||
while cur is not None and cur not in seen:
|
||
seen.add(cur)
|
||
cur = getattr(cur, "element", None)
|
||
|
||
# follow only the common attributes; no generic visitor
|
||
left = getattr(node, "left", None)
|
||
if left is not None:
|
||
stack.append(left)
|
||
right = getattr(node, "right", None)
|
||
if right is not None:
|
||
stack.append(right)
|
||
elem = getattr(node, "element", None)
|
||
if elem is not None:
|
||
stack.append(elem)
|
||
clause = getattr(node, "clause", None)
|
||
if clause is not None:
|
||
stack.append(clause)
|
||
clauses = getattr(node, "clauses", None)
|
||
if clauses is not None:
|
||
try:
|
||
stack.extend(list(clauses))
|
||
except TypeError:
|
||
pass
|
||
|
||
return seen
|
||
|
||
def _selectable_keys(sel) -> set[str]:
|
||
"""
|
||
Return a set of stable string keys for a selectable/alias and its base,
|
||
so we can match when when different alias objects are used.
|
||
"""
|
||
keys: set[str] = set()
|
||
cur = sel
|
||
while cur is not None:
|
||
k = getattr(cur, "key", None) or getattr(cur, "name", None)
|
||
if isinstance(k, str) and k:
|
||
keys.add(k)
|
||
cur = getattr(cur, "element", None)
|
||
return keys
|
||
|
||
def _unwrap_ob(ob):
    """Split an ORDER BY element into (column, is_desc)."""
    inner = getattr(ob, "element", None)
    column = ob if inner is None else inner

    direction = getattr(ob, "_direction", None)
    if direction is not None:
        descending = (direction is operators.desc_op) or (getattr(direction, "name", "").upper() == "DESC")
    elif isinstance(ob, UnaryExpression):
        unary_op = getattr(ob, "operator", None)
        descending = (unary_op is operators.desc_op) or (getattr(unary_op, "name", "").upper() == "DESC")
    else:
        # Bare column: treated as ascending.
        descending = False
    return column, bool(descending)
||
|
||
def _order_identity(col: ColumnElement):
|
||
table = getattr(col, "table", None)
|
||
table_key = getattr(table, "key", None) or id(table)
|
||
col_key = getattr(col, "key", None) or getattr(col, "name", None)
|
||
return (table_key, col_key)
|
||
|
||
def _dedupe_order_by(order_by):
    """Drop ORDER BY entries repeating a (table, column) identity; keep first occurrence."""
    if not order_by:
        return []
    kept = []
    visited = set()
    for entry in order_by:
        column, _ = _unwrap_ob(entry)
        identity = _order_identity(column)
        if identity not in visited:
            visited.add(identity)
            kept.append(entry)
    return kept
|
||
|
||
def _is_truthy(val):
|
||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||
|
||
def _normalize_fields_param(params: dict | None) -> list[str]:
|
||
if not params:
|
||
return []
|
||
raw = params.get("fields")
|
||
if isinstance(raw, (list, tuple)):
|
||
out: list[str] = []
|
||
for item in raw:
|
||
if isinstance(item, str):
|
||
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
|
||
return out
|
||
if isinstance(raw, str):
|
||
return [p for p in (s.strip() for s in raw.split(",")) if p]
|
||
return []
|
||
|
||
# ---------------------------- CRUD service ----------------------------
|
||
|
||
class CRUDService(Generic[T]):
|
||
    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """Bind the service to one model class.

        model: mapped class this service reads/writes.
        session_factory: fallback Session source when no Flask app context exists.
        polymorphic: query via with_polymorphic("*") when True.
        backend: optional pre-resolved BackendInfo; otherwise derived lazily.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        # Soft delete is opt-in by model shape: presence of ``is_deleted``.
        self.supports_soft_delete = hasattr(model, 'is_deleted')
        self._backend: Optional[BackendInfo] = backend
|
||
|
||
# ---- infra
|
||
|
||
@property
|
||
def session(self) -> Session:
|
||
try:
|
||
return current_app.extensions["crudkit"]["Session"]
|
||
except Exception:
|
||
return self._session_factory()
|
||
|
||
@property
|
||
def backend(self) -> BackendInfo:
|
||
if self._backend is None:
|
||
bind = self.session.get_bind()
|
||
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
|
||
self._backend = make_backend_info(eng)
|
||
return self._backend
|
||
|
||
def get_query(self):
|
||
if self.polymorphic:
|
||
poly = with_polymorphic(self.model, "*")
|
||
return self.session.query(poly), poly
|
||
return self.session.query(self.model), self.model
|
||
|
||
# ---- common building blocks
|
||
|
||
def _apply_soft_delete_criteria_for_children(self, query, plan: "CRUDService._Plan", params):
|
||
# Skip if caller explicitly asked for deleted
|
||
if _is_truthy((params or {}).get("include_deleted")):
|
||
return query
|
||
|
||
seen = set()
|
||
for _base_alias, rel_attr, _target_alias in plan.join_paths:
|
||
prop = getattr(rel_attr, "property", None)
|
||
if not prop:
|
||
continue
|
||
target_cls = getattr(prop.mapper, "class_", None)
|
||
if not target_cls or target_cls in seen:
|
||
continue
|
||
seen.add(target_cls)
|
||
# Only apply to models that support soft delete
|
||
if hasattr(target_cls, "is_deleted"):
|
||
query = query.options(
|
||
with_loader_criteria(
|
||
target_cls,
|
||
lambda cls: cls.is_deleted == False,
|
||
include_aliases=True
|
||
)
|
||
)
|
||
return query
|
||
|
||
def _order_clauses(self, order_spec, invert: bool = False):
|
||
clauses = []
|
||
for c, is_desc in zip(order_spec.cols, order_spec.desc):
|
||
d = not is_desc if invert else is_desc
|
||
clauses.append(c.desc() if d else c.asc())
|
||
return clauses
|
||
|
||
    def _anchor_key_for_page(self, params, per_page: int, page: int):
        """Return the keyset tuple for the last row of the previous page, or None for page 1."""
        if page <= 1:
            return None

        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)

        plan = self._plan(params, root_alias)
        # Make sure joins/filters match the real query
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)

        order_spec = self._extract_order_spec(root_alias, plan.order_by)

        # Inner subquery must be ordered exactly like the real query
        inner = query.order_by(*self._order_clauses(order_spec, invert=False))

        # IMPORTANT: Build subquery that actually exposes the order-by columns
        # under predictable names, then select FROM that and reference subq.c[...]
        subq = inner.with_entities(*order_spec.cols).subquery()

        # Map the order columns to the subquery columns by key/name
        cols_on_subq = []
        for col in order_spec.cols:
            key = getattr(col, "key", None) or getattr(col, "name", None)
            if not key:
                # Fallback, but frankly your order cols should have names
                raise ValueError("Order-by column is missing a key/name")
            cols_on_subq.append(getattr(subq.c, key))

        # Now the outer anchor query orders and offsets on the subquery columns
        anchor_q = (
            self.session
            .query(*cols_on_subq)
            .select_from(subq)
            .order_by(*[
                (c.desc() if is_desc else c.asc())
                for c, is_desc in zip(cols_on_subq, order_spec.desc)
            ])
        )

        # Anchor row = last row of the PREVIOUS page, hence the trailing "- 1".
        offset = max(0, (page - 1) * per_page - 1)
        row = anchor_q.offset(offset).limit(1).first()
        if not row:
            return None
        return list(row)  # tuple-like -> list for _key_predicate
|
||
|
||
def _apply_not_deleted(self, query, root_alias, params):
|
||
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
|
||
return query.filter(getattr(root_alias, "is_deleted") == False)
|
||
return query
|
||
|
||
def _default_order_by(self, root_alias):
|
||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||
cols = []
|
||
for col in mapper.primary_key:
|
||
try:
|
||
cols.append(getattr(root_alias, col.key))
|
||
except AttributeError:
|
||
cols.append(col)
|
||
return cols or [text("1")]
|
||
|
||
def _stable_order_by(self, root_alias, given_order_by):
|
||
order_by = list(given_order_by or [])
|
||
if not order_by:
|
||
return _dedupe_order_by(self._default_order_by(root_alias))
|
||
order_by = _dedupe_order_by(order_by)
|
||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||
present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
|
||
for pk in mapper.primary_key:
|
||
try:
|
||
pk_col = getattr(root_alias, pk.key)
|
||
except AttributeError:
|
||
pk_col = pk
|
||
ident = _order_identity(pk_col)
|
||
if ident not in present:
|
||
order_by.append(pk_col.asc())
|
||
present.add(ident)
|
||
return order_by
|
||
|
||
def _extract_order_spec(self, root_alias, given_order_by):
|
||
given = self._stable_order_by(root_alias, given_order_by)
|
||
cols, desc_flags = [], []
|
||
for ob in given:
|
||
elem = getattr(ob, "element", None)
|
||
col = elem if elem is not None else ob
|
||
is_desc = False
|
||
dir_attr = getattr(ob, "_direction", None)
|
||
if dir_attr is not None:
|
||
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
|
||
elif isinstance(ob, UnaryExpression):
|
||
op = getattr(ob, "operator", None)
|
||
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
||
cols.append(col)
|
||
desc_flags.append(bool(is_desc))
|
||
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
||
|
||
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
|
||
if not key_vals:
|
||
return None
|
||
conds = []
|
||
for i, col in enumerate(spec.cols):
|
||
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
|
||
is_desc = spec.desc[i]
|
||
op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i])
|
||
conds.append(and_(*ties, op))
|
||
return or_(*conds)
|
||
|
||
# ---- planning and application
|
||
|
||
    @dataclass(slots=True)
    class _Plan:
        """Precomputed description of one request: parsed filters, ordering,
        pagination, projection and join information."""
        spec: Any                    # the CRUDSpec that parsed the request params
        filters: Any                 # raw filter expressions / CollPred markers
        order_by: Any                # parsed ORDER BY entries (may be empty)
        limit: Any                   # requested page size (None = default, 0 = unlimited)
        offset: Any                  # requested offset (may be None)
        root_fields: Any             # root-level attributes to load_only
        rel_field_names: Any         # {relationship path tuple: [field names]}
        root_field_names: Any        # plain field names requested on the root model
        collection_field_names: Any  # {collection rel key: [child field names]}
        join_paths: Any              # (base_alias, rel_attr, target_alias) triples
        filter_tables: Any           # tables referenced by the filter expressions
        filter_table_keys: Any       # currently always empty; reserved
        req_fields: Any              # normalized "fields" request parameter
        proj_opts: Any               # loader options compiled from req_fields
||
|
||
    def _plan(self, params, root_alias) -> _Plan:
        """Parse request params once into a _Plan consumed by the query builders."""
        req_fields = _normalize_fields_param(params)
        spec = CRUDSpec(self.model, params or {}, root_alias)

        filters = spec.parse_filters()
        order_by = spec.parse_sort()
        limit, offset = spec.parse_pagination()
        # Field parsing is skipped entirely for parameterless calls.
        root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() if params else ([], {}, {}, {})
        # Side effect: registers include paths on the spec (feeds get_join_paths).
        spec.parse_includes()
        join_paths = tuple(spec.get_join_paths())
        filter_tables = _collect_tables_from_filters(filters)
        fkeys = set()
        # Explicit "fields" projection compiles to loader options; otherwise none.
        _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])

        return self._Plan(
            spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset,
            root_fields=root_fields, rel_field_names=rel_field_names,
            root_field_names=root_field_names, collection_field_names=collection_field_names,
            join_paths=join_paths, filter_tables=filter_tables, filter_table_keys=fkeys,
            req_fields=req_fields, proj_opts=proj_opts
        )
|
||
|
||
def _apply_projection_load_only(self, query, root_alias, plan: _Plan):
|
||
only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)]
|
||
return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query
|
||
|
||
def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
|
||
nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 }
|
||
joined_rel_keys = set()
|
||
|
||
# Existing behavior: join everything in join_paths (to-one), selectinload collections
|
||
for base_alias, rel_attr, target_alias in plan.join_paths:
|
||
if base_alias is not root_alias:
|
||
continue
|
||
prop = getattr(rel_attr, "property", None)
|
||
is_collection = bool(getattr(prop, "uselist", False))
|
||
|
||
if not is_collection:
|
||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||
joined_rel_keys.add(prop.key if prop is not None else rel_attr.key)
|
||
else:
|
||
opt = selectinload(rel_attr)
|
||
child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
|
||
if child_names:
|
||
target_cls = prop.mapper.class_
|
||
cols = [getattr(target_cls, n, None) for n in child_names]
|
||
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
|
||
if cols:
|
||
opt = opt.load_only(*cols)
|
||
query = query.options(opt)
|
||
|
||
# NEW: if a first-hop to-one relationship’s target table is present in filter expressions,
|
||
# make sure we actually JOIN it (outer) so filters don’t create a cartesian product.
|
||
if plan.filter_tables:
|
||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||
for rel in mapper.relationships:
|
||
if rel.uselist:
|
||
continue # only first-hop to-one here
|
||
target_tbl = getattr(rel.mapper.class_, "__table__", None)
|
||
if target_tbl is None:
|
||
continue
|
||
if target_tbl in plan.filter_tables:
|
||
if rel.key in joined_rel_keys:
|
||
continue # already joined via join_paths
|
||
query = query.join(getattr(root_alias, rel.key), isouter=True)
|
||
joined_rel_keys.add(rel.key)
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
info = []
|
||
for base_alias, rel_attr, target_alias in plan.join_paths:
|
||
if base_alias is not root_alias:
|
||
continue
|
||
prop = getattr(rel_attr, "property", None)
|
||
sel = getattr(target_alias, "selectable", None)
|
||
info.append({
|
||
"rel": (getattr(prop, "key", getattr(rel_attr, "key", "?"))),
|
||
"collection": bool(getattr(prop, "uselist", False)),
|
||
"target_keys": list(_selectable_keys(sel)) if sel is not None else [],
|
||
"joined": (getattr(prop, "key", None) in joined_rel_keys),
|
||
})
|
||
log.debug("FIRSTHOP: %s.%s first-hop paths: %s",
|
||
self.model.__name__, getattr(root_alias, "__table__", type(root_alias)).key,
|
||
info)
|
||
|
||
return query
|
||
|
||
def _apply_proj_opts(self, query, plan: _Plan):
|
||
return query.options(*plan.proj_opts) if plan.proj_opts else query
|
||
|
||
def _projection_meta(self, plan: _Plan):
|
||
if plan.req_fields:
|
||
proj = list(dict.fromkeys(plan.req_fields))
|
||
return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj
|
||
|
||
proj: list[str] = []
|
||
if plan.root_field_names:
|
||
proj.extend(plan.root_field_names)
|
||
if plan.root_fields:
|
||
proj.extend(c.key for c in plan.root_fields if hasattr(c, "key"))
|
||
for path, names in (plan.rel_field_names or {}).items():
|
||
prefix = ".".join(path)
|
||
for n in names:
|
||
proj.append(f"{prefix}.{n}")
|
||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
return proj
|
||
|
||
def _tag_projection(self, items, proj):
|
||
if not proj:
|
||
return
|
||
for obj in items if isinstance(items, list) else [items]:
|
||
try:
|
||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||
except Exception:
|
||
pass
|
||
|
||
    def _rebind_filters_to_firsthop_aliases(self, filters, root_alias, plan):
        """Make filter expressions use the exact same alias objects as our JOINs."""
        if not filters:
            return filters

        # Map first-hop target selectable keysets -> the exact selectable object we JOINed with
        alias_map = {}
        for base_alias, _rel_attr, target_alias in plan.join_paths:
            if base_alias is not root_alias:
                continue
            sel = getattr(target_alias, "selectable", None)
            if sel is not None:
                alias_map[frozenset(_selectable_keys(sel))] = sel

        if not alias_map:
            return filters

        def replace(elem):
            # Called for every node in the expression tree by
            # replacement_traverse; return a replacement column or the
            # original element unchanged.
            tbl = getattr(elem, "table", None)
            if tbl is None:
                return elem
            keyset = frozenset(_selectable_keys(tbl))
            new_sel = alias_map.get(keyset)
            if new_sel is None or new_sel is tbl:
                return elem

            colkey = getattr(elem, "key", None) or getattr(elem, "name", None)
            if not colkey:
                return elem
            try:
                return getattr(new_sel.c, colkey)
            except Exception:
                # The alias does not expose this column; keep the original.
                return elem

        return [visitors.replacement_traverse(f, {}, replace) for f in filters]
|
||
|
||
def _final_filters(self, root_alias, plan):
|
||
"""
|
||
Return filters where:
|
||
- root/to-one predicates are kept as SQLAlchemy expressions.
|
||
- first-hop collection predicates (CollPred) are rebuilt into a single
|
||
EXISTS via rel.any(...) with one alias per collection table.
|
||
"""
|
||
filters = list(plan.filters or [])
|
||
if not filters:
|
||
return []
|
||
|
||
# 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls)
|
||
coll_map = {}
|
||
for base_alias, rel_attr, target_alias in plan.join_paths:
|
||
if base_alias is not root_alias:
|
||
continue
|
||
prop = getattr(rel_attr, "property", None)
|
||
if not prop or not getattr(prop, "uselist", False):
|
||
continue
|
||
target_cls = prop.mapper.class_
|
||
tbl = getattr(target_cls, "__table__", None)
|
||
if tbl is not None:
|
||
coll_map[tbl] = (rel_attr, target_cls)
|
||
|
||
# 2) Split raw filters into normal SQLA and CollPreds (by target table)
|
||
normal_filters = []
|
||
by_table: dict[Any, list[CollPred]] = {}
|
||
for f in filters:
|
||
if isinstance(f, CollPred):
|
||
by_table.setdefault(f.table, []).append(f)
|
||
else:
|
||
normal_filters.append(f)
|
||
|
||
# 3) Rebuild each table group into ONE .any(...) using one alias
|
||
from sqlalchemy.orm import aliased
|
||
from sqlalchemy import and_
|
||
|
||
exists_filters = []
|
||
for tbl, preds in by_table.items():
|
||
if tbl not in coll_map:
|
||
# Safety: if it's not a first-hop collection, ignore or raise
|
||
continue
|
||
rel_attr, target_cls = coll_map[tbl]
|
||
ta = aliased(target_cls)
|
||
|
||
built = []
|
||
for p in preds:
|
||
col = getattr(ta, p.col_key)
|
||
op = p.op
|
||
val = p.value
|
||
if op == 'icontains':
|
||
built.append(col.ilike(f"%{val}%"))
|
||
elif op == 'eq':
|
||
built.append(col == val)
|
||
elif op == 'ne':
|
||
built.append(col != val)
|
||
elif op == 'in':
|
||
vs = val if isinstance(val, (list, tuple, set)) else [val]
|
||
built.append(col.in_(vs))
|
||
elif op == 'nin':
|
||
vs = val if isinstance(val, (list, tuple, set)) else [val]
|
||
built.append(~col.in_(vs))
|
||
elif op == 'lt':
|
||
built.append(col < val)
|
||
elif op == 'lte':
|
||
built.append(col <= val)
|
||
elif op == 'gt':
|
||
built.append(col > val)
|
||
elif op == 'gte':
|
||
built.append(col >= val)
|
||
else:
|
||
# unknown op — skip or raise
|
||
continue
|
||
|
||
# enforce child soft delete inside the EXISTS
|
||
if hasattr(target_cls, "is_deleted"):
|
||
built.append(ta.is_deleted == False)
|
||
|
||
crit = and_(*built) if built else None
|
||
exists_filters.append(rel_attr.of_type(ta).any(crit) if crit is not None
|
||
else rel_attr.of_type(ta).any())
|
||
|
||
# 4) Final filter list = normal SQLA filters + all EXISTS filters
|
||
return normal_filters + exists_filters
|
||
|
||
# ---- public read ops
|
||
|
||
    def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True):
        """Offset-style page facade built on keyset pagination: compute the keyset
        anchor for ``page`` and delegate to seek_window."""
        # Ensure seek_window uses `per_page`
        params = dict(params or {})
        params["limit"] = per_page

        anchor_key = self._anchor_key_for_page(params, per_page, page)
        win = self.seek_window(params, key=anchor_key, backward=False, include_total=include_total)

        pages = None
        if include_total and win.total is not None and per_page:
            # ceil(total / per_page), integer-only
            pages = (win.total + per_page - 1) // per_page

        return {
            "items": win.items,
            "page": page,
            "per_page": per_page,
            "total": win.total,
            "pages": pages,
            "order": [str(c) for c in win.order.cols],
        }
|
||
|
||
    def seek_window(
        self,
        params: dict | None = None,
        *,
        key: list[Any] | None = None,
        backward: bool = False,
        include_total: bool = True,
    ) -> "SeekWindow[T]":
        """Return a keyset-paginated window of rows.

        key: anchor tuple (values of the ORDER BY columns) to seek past.
        backward: walk toward earlier rows; results are re-reversed below so
            callers always see forward order.
        include_total: also run a DISTINCT count over the same joins/filters.
        """
        session = self.session
        query, root_alias = self.get_query()
        query = self._apply_not_deleted(query, root_alias, params)

        plan = self._plan(params, root_alias)
        query = self._apply_projection_load_only(query, root_alias, plan)
        query = self._apply_firsthop_strategies(query, root_alias, plan)
        query = self._apply_soft_delete_criteria_for_children(query, plan, params)
        if plan.filters:
            filters = self._final_filters(root_alias, plan)
            if filters:
                query = query.filter(*filters)

        order_spec = self._extract_order_spec(root_alias, plan.order_by)
        # limit semantics: None -> default 50, 0 -> unlimited, else the value.
        limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)

        if key:
            pred = self._key_predicate(order_spec, key, backward)
            if pred is not None:
                query = query.filter(pred)

        clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)]
        if backward:
            # Walk the opposite direction; rows are reversed again below.
            clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*clauses)
        if limit is not None:
            query = query.limit(limit)

        query = self._apply_proj_opts(query, plan)
        rows = query.all()
        items = list(reversed(rows)) if backward else rows

        proj = self._projection_meta(plan)
        self._tag_projection(items, proj)

        # cursor keys
        def pluck(obj):
            # Extract the ORDER BY column values from a loaded instance,
            # following first-hop relationships for alias-owned columns.
            vals = []
            alias_to_rel = {}
            for _p, rel_attr, target_alias in plan.join_paths:
                sel = getattr(target_alias, "selectable", None)
                if sel is not None:
                    alias_to_rel[sel] = rel_attr.key

            for col in order_spec.cols:
                keyname = getattr(col, "key", None) or getattr(col, "name", None)
                if keyname and hasattr(obj, keyname):
                    vals.append(getattr(obj, keyname)); continue
                table = getattr(col, "table", None)
                relname = alias_to_rel.get(table)
                if relname and keyname:
                    relobj = getattr(obj, relname, None)
                    if relobj is not None and hasattr(relobj, keyname):
                        vals.append(getattr(relobj, keyname)); continue
                raise ValueError("unpluckable")
            return vals

        try:
            first_key = pluck(items[0]) if items else None
            last_key = pluck(items[-1]) if items else None
        except Exception:
            # Cursor extraction is best-effort; None disables keyset cursors.
            first_key = last_key = None

        total = None
        if include_total:
            # Count DISTINCT ids over the same (to-one) joins and filters, no ORDER BY.
            base = session.query(getattr(root_alias, "id"))
            base = self._apply_not_deleted(base, root_alias, params)
            for _b, rel_attr, target_alias in plan.join_paths:
                if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                    base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
            if plan.filters:
                filters = self._final_filters(root_alias, plan)
                if filters:
                    base = base.filter(*filters)  # <-- use base, not query
            total = session.query(func.count()).select_from(
                base.order_by(None).distinct().subquery()
            ).scalar() or 0

        if log.isEnabledFor(logging.DEBUG):
            log.debug("QUERY: %s", str(query))

        # A requested limit of 0 means "unlimited"; report it back as 0.
        window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50)
        return SeekWindow(
            items=items,
            limit=window_limit_for_body,
            first_key=first_key,
            last_key=last_key,
            order=order_spec,
            total=total,
        )
|
||
|
||
def get(self, id: int, params=None) -> T | None:
|
||
query, root_alias = self.get_query()
|
||
query = self._apply_not_deleted(query, root_alias, params)
|
||
|
||
plan = self._plan(params, root_alias)
|
||
query = self._apply_projection_load_only(query, root_alias, plan)
|
||
query = self._apply_firsthop_strategies(query, root_alias, plan)
|
||
query = self._apply_soft_delete_criteria_for_children(query, plan, params)
|
||
if plan.filters:
|
||
filters = self._final_filters(root_alias, plan)
|
||
if filters:
|
||
query = query.filter(*filters)
|
||
query = query.filter(getattr(root_alias, "id") == id)
|
||
query = self._apply_proj_opts(query, plan)
|
||
|
||
obj = query.first()
|
||
proj = self._projection_meta(plan)
|
||
if obj:
|
||
self._tag_projection(obj, proj)
|
||
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug("QUERY: %s", str(query))
|
||
return obj or None
|
||
|
||
def list(self, params=None) -> list[T]:
|
||
query, root_alias = self.get_query()
|
||
plan = self._plan(params, root_alias)
|
||
query = self._apply_not_deleted(query, root_alias, params)
|
||
query = self._apply_projection_load_only(query, root_alias, plan)
|
||
query = self._apply_firsthop_strategies(query, root_alias, plan)
|
||
query = self._apply_soft_delete_criteria_for_children(query, plan, params)
|
||
if plan.filters:
|
||
filters = self._final_filters(root_alias, plan)
|
||
if filters:
|
||
query = query.filter(*filters)
|
||
|
||
order_by = plan.order_by
|
||
paginating = (plan.limit is not None) or (plan.offset not in (None, 0))
|
||
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||
order_by = self._default_order_by(root_alias)
|
||
order_by = _dedupe_order_by(order_by)
|
||
if order_by:
|
||
query = query.order_by(*order_by)
|
||
|
||
default_cap = getattr(current_app.config, "CRUDKIT_DEFAULT_LIST_LIMIT", 200)
|
||
if plan.offset:
|
||
query = query.offset(plan.offset)
|
||
if plan.limit and plan.limit > 0:
|
||
query = query.limit(plan.limit)
|
||
elif plan.limit is None and default_cap:
|
||
query = query.limit(default_cap)
|
||
|
||
query = self._apply_proj_opts(query, plan)
|
||
rows = query.all()
|
||
|
||
proj = self._projection_meta(plan)
|
||
self._tag_projection(rows, proj)
|
||
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug("QUERY: %s", str(query))
|
||
return rows
|
||
|
||
# ---- write ops
|
||
|
||
def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
|
||
session = self.session
|
||
obj = self.model(**data)
|
||
session.add(obj)
|
||
session.flush()
|
||
self._log_version("create", obj, actor, commit=commit)
|
||
if commit:
|
||
session.commit()
|
||
return obj
|
||
|
||
    def update(self, id: int, data: dict, actor=None, *, commit: bool = True) -> T:
        """Apply a diff-based partial update; audit only changes that actually stuck.

        Raises ValueError when no row with ``id`` exists.
        """
        session = self.session
        obj = session.get(self.model, id)
        if not obj:
            raise ValueError(f"{self.model.__name__} with ID {id} not found.")

        before = obj.as_dict()
        # Normalize the payload, then drop keys that are not real model columns.
        norm = normalize_payload(data, self.model)
        incoming = filter_to_columns(norm, self.model)
        desired = {**before, **incoming}

        # Compute the minimal patch; bail out early when nothing would change.
        proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        patch = diff_to_patch(proposed)
        if not patch:
            return obj

        for k, v in patch.items():
            setattr(obj, k, v)

        # ORM-level no-op (same values re-assigned): skip commit and audit.
        dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
        if not dirty:
            return obj

        if commit:
            session.commit()

        # Audit the *actual* resulting state, not just the requested patch.
        after = obj.as_dict()
        actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
        if not (actual["added"] or actual["removed"] or actual["changed"]):
            return obj

        self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit)
        return obj
|
||
|
||
def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True):
|
||
session = self.session
|
||
obj = session.get(self.model, id)
|
||
if not obj:
|
||
return None
|
||
if hard or not self.supports_soft_delete:
|
||
session.delete(obj)
|
||
else:
|
||
cast(_SoftDeletable, obj).is_deleted = True
|
||
if commit:
|
||
session.commit()
|
||
self._log_version("delete", obj, actor, commit=commit)
|
||
return obj
|
||
|
||
# ---- audit
|
||
|
||
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True):
|
||
session = self.session
|
||
try:
|
||
try:
|
||
snapshot = obj.as_dict()
|
||
except Exception:
|
||
snapshot = {"error": "serialize failed"}
|
||
version = Version(
|
||
model_name=self.model.__name__,
|
||
object_id=obj.id,
|
||
change_type=change_type,
|
||
data=to_jsonable(snapshot),
|
||
actor=str(actor) if actor else None,
|
||
meta=to_jsonable(metadata) if metadata else None,
|
||
)
|
||
session.add(version)
|
||
if commit:
|
||
session.commit()
|
||
except Exception as e:
|
||
log.warning(f"Version logging failed for {self.model.__name__} id={getattr(obj, 'id', '?')}: {str(e)}")
|
||
session.rollback()
|