Compare commits

..

2 commits

Author SHA1 Message Date
Yaro Kasear
2ad327fcd9 Refactored the service to be less painful and redundant. 2025-10-06 14:36:08 -05:00
Yaro Kasear
5ad652d372 Trying in vain to fix filtering. 2025-10-06 14:28:24 -05:00
2 changed files with 317 additions and 495 deletions

View file

@@ -1,11 +1,12 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass
from flask import current_app from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
@@ -37,41 +38,68 @@ class _SoftDeletable(Protocol):
is_deleted: bool is_deleted: bool
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
"""Minimal surface that our CRUD service relies on. Soft-delete is optional.""" """Surface expected by CRUDService."""
pass pass
T = TypeVar("T", bound=_CRUDModelProto) T = TypeVar("T", bound=_CRUDModelProto)
# ---------------------------- utilities ----------------------------
def _collect_tables_from_filters(filters) -> set:
"""Walk SQLA expressions to collect Table/Alias objects that appear in filters."""
seen = set()
def visit(node):
if node is None:
return
tbl = getattr(node, "table", None)
if tbl is not None:
cur = tbl
while cur is not None:
seen.add(cur)
cur = getattr(cur, "element", None)
for attr in ("get_children",):
fn = getattr(node, attr, None)
if fn:
for ch in fn():
visit(ch)
for attr in ("left", "right", "element", "clause", "clauses"):
val = getattr(node, attr, None)
if val is None:
continue
if isinstance(val, (list, tuple)):
for v in val: visit(v)
else:
visit(val)
for f in (filters or []):
visit(f)
return seen
def _unwrap_ob(ob): def _unwrap_ob(ob):
"""Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).""" elem = getattr(ob, "element", None)
col = getattr(ob, "element", None) col = elem if elem is not None else ob
if col is None:
col = ob d = getattr(ob, "_direction", None)
is_desc = False if d is not None:
dir_attr = getattr(ob, "_direction", None) is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC")
if dir_attr is not None:
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
elif isinstance(ob, UnaryExpression): elif isinstance(ob, UnaryExpression):
op = getattr(ob, "operator", None) op = getattr(ob, "operator", None)
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
else:
is_desc = False
return col, bool(is_desc) return col, bool(is_desc)
def _order_identity(col: ColumnElement): def _order_identity(col: ColumnElement):
"""
Build a stable identity for a column suitable for deduping.
We ignore direction here. Duplicates are duplicates regardless of ASC/DESC.
"""
table = getattr(col, "table", None) table = getattr(col, "table", None)
table_key = getattr(table, "key", None) or id(table) table_key = getattr(table, "key", None) or id(table)
col_key = getattr(col, "key", None) or getattr(col, "name", None) col_key = getattr(col, "key", None) or getattr(col, "name", None)
return (table_key, col_key) return (table_key, col_key)
def _dedupe_order_by(order_by): def _dedupe_order_by(order_by):
"""Remove duplicate ORDER BY entries (by column identity, ignoring direction)."""
if not order_by: if not order_by:
return [] return []
seen = set() seen, out = set(), []
out = []
for ob in order_by: for ob in order_by:
col, _ = _unwrap_ob(ob) col, _ = _unwrap_ob(ob)
ident = _order_identity(col) ident = _order_identity(col)
@@ -98,6 +126,8 @@ def _normalize_fields_param(params: dict | None) -> list[str]:
return [p for p in (s.strip() for s in raw.split(",")) if p] return [p for p in (s.strip() for s in raw.split(",")) if p]
return [] return []
# ---------------------------- CRUD service ----------------------------
class CRUDService(Generic[T]): class CRUDService(Generic[T]):
def __init__( def __init__(
self, self,
@@ -111,21 +141,19 @@ class CRUDService(Generic[T]):
self._session_factory = session_factory self._session_factory = session_factory
self.polymorphic = polymorphic self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted') self.supports_soft_delete = hasattr(model, 'is_deleted')
self._backend: Optional[BackendInfo] = backend self._backend: Optional[BackendInfo] = backend
# ---- infra
@property @property
def session(self) -> Session: def session(self) -> Session:
"""Always return the Flask-scoped Session if available; otherwise the provided factory."""
try: try:
sess = current_app.extensions["crudkit"]["Session"] return current_app.extensions["crudkit"]["Session"]
return sess
except Exception: except Exception:
return self._session_factory() return self._session_factory()
@property @property
def backend(self) -> BackendInfo: def backend(self) -> BackendInfo:
"""Resolve backend info lazily against the active session's engine."""
if self._backend is None: if self._backend is None:
bind = self.session.get_bind() bind = self.session.get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
@@ -138,265 +166,13 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly return self.session.query(poly), poly
return self.session.query(self.model), self.model return self.session.query(self.model), self.model
def _debug_bind(self, where: str): # ---- common building blocks
try:
bind = self.session.get_bind()
eng = getattr(bind, "engine", bind)
print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}")
except Exception as e:
print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}")
def _apply_not_deleted(self, query, root_alias, params) -> Any: def _apply_not_deleted(self, query, root_alias, params):
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
return query.filter(getattr(root_alias, "is_deleted") == False) return query.filter(getattr(root_alias, "is_deleted") == False)
return query return query
def _extract_order_spec(self, root_alias, given_order_by):
    """Normalize ORDER BY input into an OrderSpec of (columns, desc flags).

    SQLAlchemy 2.x only. Accepts plain columns as well as the
    ``col.asc()`` / ``col.desc()`` UnaryExpression wrappers; never
    boolean-evaluates a clause. PK tie-breakers come from
    ``_stable_order_by``.
    """
    normalized = self._stable_order_by(root_alias, given_order_by)
    columns, flags = [], []
    for clause in normalized:
        inner = getattr(clause, "element", None)
        target = clause if inner is None else inner

        direction = getattr(clause, "_direction", None)
        if direction is not None:
            descending = (direction is operators.desc_op) or (
                getattr(direction, "name", "").upper() == "DESC"
            )
        elif isinstance(clause, UnaryExpression):
            op = getattr(clause, "operator", None)
            descending = (op is operators.desc_op) or (
                getattr(op, "name", "").upper() == "DESC"
            )
        else:
            descending = False

        columns.append(target)
        flags.append(bool(descending))
    return OrderSpec(cols=tuple(columns), desc=tuple(flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
    """Build the keyset-pagination WHERE predicate for a cursor key.

    Produces the standard row-value expansion: OR over positions i of
    (equality on all columns before i) AND (strict comparison at i),
    with the comparison direction flipped per-column for DESC order and
    flipped again when paging backward. Returns None when no key is given.

    NOTE: columns are compared as-is; if NULLs are possible they should be
    normalized (e.g. via coalesce to a sentinel) before comparison.
    """
    if not key_vals:
        return None
    branches = []
    for idx, column in enumerate(spec.cols):
        equal_prefix = [spec.cols[j] == key_vals[j] for j in range(idx)]
        descending = spec.desc[idx]
        if backward:
            comparison = (column > key_vals[idx]) if descending else (column < key_vals[idx])
        else:
            comparison = (column < key_vals[idx]) if descending else (column > key_vals[idx])
        branches.append(and_(*equal_prefix, comparison))
    return or_(*branches)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
    """Read the cursor-key values for *obj* along the ORDER BY columns.

    Only simple mapped columns are supported: each column must expose a
    ``key`` (or ``name``) that is an attribute of *obj*.

    Raises:
        ValueError: when a column has no usable attribute name on *obj*.
    """
    values: list[Any] = []
    for column in spec.cols:
        attr_name = getattr(column, "key", None) or getattr(column, "name", None)
        if attr_name is None or not hasattr(obj, attr_name):
            raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
        values.append(getattr(obj, attr_name))
    return values
def seek_window(
    self,
    params: dict | None = None,
    *,
    key: list[Any] | None = None,
    backward: bool = False,
    include_total: bool = True,
) -> "SeekWindow[T]":
    """
    Keyset pagination with relationship-safe filtering/sorting.
    Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek.

    Args:
        params: raw request params (filters, sort, fields, includes, pagination).
        key: cursor key values matching the effective ORDER BY columns.
        backward: page toward earlier rows (SQL order inverted, reversed in memory).
        include_total: also compute a DISTINCT-id total via a mirrored count query.
    """
    session = self.session
    query, root_alias = self.get_query()
    # Requested fields → projection + optional loaders
    fields = _normalize_fields_param(params)
    expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
    spec = CRUDSpec(self.model, params or {}, root_alias)
    # Parse all inputs so join_paths are populated
    filters = spec.parse_filters()
    order_by = spec.parse_sort()
    root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
    spec.parse_includes()
    join_paths = tuple(spec.get_join_paths())
    # Soft delete
    query = self._apply_not_deleted(query, root_alias, params)
    # Root column projection (load_only)
    only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
    if only_cols:
        query = query.options(Load(root_alias).load_only(*only_cols))
    # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
    nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
    # IMPORTANT:
    # - Only attach loader options for first-hop relations from the root.
    # - Always use selectinload here (avoid contains_eager joins).
    # - Let compile_projections() supply deep chained options.
    for base_alias, rel_attr, target_alias in join_paths:
        is_firsthop_from_root = (base_alias is root_alias)
        if not is_firsthop_from_root:
            # Deeper hops are handled by proj_opts below
            continue
        prop = getattr(rel_attr, "property", None)
        is_collection = bool(getattr(prop, "uselist", False))
        is_nested_firsthop = rel_attr.key in nested_first_hops
        opt = selectinload(rel_attr)
        # Optional narrowing for collections
        if is_collection:
            child_names = (collection_field_names or {}).get(rel_attr.key, [])
            if child_names:
                target_cls = prop.mapper.class_
                cols = [getattr(target_cls, n, None) for n in child_names]
                cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                if cols:
                    opt = opt.load_only(*cols)
        query = query.options(opt)
    # Filters AFTER joins → no cartesian products
    if filters:
        query = query.filter(*filters)
    # Order spec (with PK tie-breakers for stability)
    order_spec = self._extract_order_spec(root_alias, order_by)
    limit, _ = spec.parse_pagination()
    # limit semantics: None → default page of 50; 0 → explicitly unlimited.
    if limit is None:
        effective_limit = 50
    elif limit == 0:
        effective_limit = None  # unlimited
    else:
        effective_limit = limit
    # Seek predicate from cursor key (if any)
    if key:
        pred = self._key_predicate(order_spec, key, backward)
        if pred is not None:
            query = query.filter(pred)
    # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
    if not backward:
        clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        items = query.all()
    else:
        inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
        query = query.order_by(*inv_clauses)
        if effective_limit is not None:
            query = query.limit(effective_limit)
        items = list(reversed(query.all()))
    # Projection meta tag for renderers
    if fields:
        proj = list(dict.fromkeys(fields))
        if "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    else:
        proj = []
        if root_field_names:
            proj.extend(root_field_names)
        if root_fields:
            proj.extend(c.key for c in root_fields if hasattr(c, "key"))
        for path, names in (rel_field_names or {}).items():
            prefix = ".".join(path)
            for n in names:
                proj.append(f"{prefix}.{n}")
        if proj and "id" not in proj and hasattr(self.model, "id"):
            proj.insert(0, "id")
    if proj:
        for obj in items:
            try:
                setattr(obj, "__crudkit_projection__", tuple(proj))
            except Exception:
                pass
    # Cursor key pluck: support related columns we hydrated via contains_eager
    def _pluck_key_from_obj(obj: Any) -> list[Any]:
        vals: list[Any] = []
        alias_to_rel: dict[Any, str] = {}
        for _p, relationship_attr, target_alias in join_paths:
            sel = getattr(target_alias, "selectable", None)
            if sel is not None:
                alias_to_rel[sel] = relationship_attr.key
        for col in order_spec.cols:
            keyname = getattr(col, "key", None) or getattr(col, "name", None)
            if keyname and hasattr(obj, keyname):
                vals.append(getattr(obj, keyname))
                continue
            table = getattr(col, "table", None)
            relname = alias_to_rel.get(table)
            if relname and keyname:
                relobj = getattr(obj, relname, None)
                if relobj is not None and hasattr(relobj, keyname):
                    vals.append(getattr(relobj, keyname))
                    continue
            raise ValueError("unpluckable")
        return vals
    # Best-effort cursor endpoints; unpluckable orders simply yield no cursors.
    try:
        first_key = _pluck_key_from_obj(items[0]) if items else None
        last_key = _pluck_key_from_obj(items[-1]) if items else None
    except Exception:
        first_key = None
        last_key = None
    # Count DISTINCT ids with mirrored joins
    # Apply deep projection loader options (safe: we avoided contains_eager)
    # NOTE(review): these options are attached AFTER query.all() already ran
    # above, so they cannot affect the rows fetched — confirm whether this
    # should happen before execution or is dead code.
    if proj_opts:
        query = query.options(*proj_opts)
    total = None
    if include_total:
        base = session.query(getattr(root_alias, "id"))
        base = self._apply_not_deleted(base, root_alias, params)
        # same joins as above for correctness
        for base_alias, rel_attr, target_alias in join_paths:
            # do not join collections for COUNT mirror
            if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
                base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
        if filters:
            base = base.filter(*filters)
        total = session.query(func.count()).select_from(
            base.order_by(None).distinct().subquery()
        ).scalar() or 0
    window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50)
    if log.isEnabledFor(logging.DEBUG):
        log.debug("QUERY: %s", str(query))
    return SeekWindow(
        items=items,
        limit=window_limit_for_body,
        first_key=first_key,
        last_key=last_key,
        order=order_spec,
        total=total,
    )
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias): def _default_order_by(self, root_alias):
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
cols = [] cols = []
@@ -408,176 +184,114 @@ class CRUDService(Generic[T]):
return cols or [text("1")] return cols or [text("1")]
def _stable_order_by(self, root_alias, given_order_by): def _stable_order_by(self, root_alias, given_order_by):
"""
Ensure deterministic ordering by appending PK columns as tiebreakers.
If no order is provided, fall back to default primary-key order.
Never duplicates columns in ORDER BY (SQL Server requires uniqueness).
"""
order_by = list(given_order_by or []) order_by = list(given_order_by or [])
if not order_by: if not order_by:
return _dedupe_order_by(self._default_order_by(root_alias)) return _dedupe_order_by(self._default_order_by(root_alias))
order_by = _dedupe_order_by(order_by) order_by = _dedupe_order_by(order_by)
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by} present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
for pk in mapper.primary_key: for pk in mapper.primary_key:
try: try:
pk_col = getattr(root_alias, pk.key) pk_col = getattr(root_alias, pk.key)
except AttributeError: except AttributeError:
pk_col = pk pk_col = pk
if _order_identity(pk_col) not in present: ident = _order_identity(pk_col)
if ident not in present:
order_by.append(pk_col.asc()) order_by.append(pk_col.asc())
present.add(_order_identity(pk_col)) present.add(ident)
return order_by return order_by
def get(self, id: int, params=None) -> T | None: def _extract_order_spec(self, root_alias, given_order_by):
""" given = self._stable_order_by(root_alias, given_order_by)
Fetch a single row by id with conflict-free eager loading and clean projection. cols, desc_flags = [], []
Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so for ob in given:
related-column filters never create cartesian products. elem = getattr(ob, "element", None)
""" col = elem if elem is not None else ob
query, root_alias = self.get_query() is_desc = False
dir_attr = getattr(ob, "_direction", None)
if dir_attr is not None:
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
elif isinstance(ob, UnaryExpression):
op = getattr(ob, "operator", None)
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
cols.append(col)
desc_flags.append(bool(is_desc))
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
# Defaults so we can build a projection even if params is None def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
req_fields: list[str] = _normalize_fields_param(params) if not key_vals:
root_fields: list[Any] = [] return None
root_field_names: dict[str, str] = {} conds = []
rel_field_names: dict[tuple[str, ...], list[str]] = {} for i, col in enumerate(spec.cols):
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i]
op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i])
conds.append(and_(*ties, op))
return or_(*conds)
# Soft-delete guard first # ---- planning and application
query = self._apply_not_deleted(query, root_alias, params)
@dataclass(slots=True)
class _Plan:
spec: Any
filters: Any
order_by: Any
limit: Any
offset: Any
root_fields: Any
rel_field_names: Any
root_field_names: Any
collection_field_names: Any
join_paths: Any
filter_tables: Any
req_fields: Any
proj_opts: Any
def _plan(self, params, root_alias) -> _Plan:
req_fields = _normalize_fields_param(params)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
# Parse everything so CRUDSpec records any join paths it needed to resolve
filters = spec.parse_filters() filters = spec.parse_filters()
# no ORDER BY for get() order_by = spec.parse_sort()
if params: limit, offset = spec.parse_pagination()
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() if params else ([], {}, {}, {})
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
filter_tables = _collect_tables_from_filters(filters)
_, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
# Root-column projection (load_only) return self._Plan(
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset,
if only_cols: root_fields=root_fields, rel_field_names=rel_field_names,
query = query.options(Load(root_alias).load_only(*only_cols)) root_field_names=root_field_names, collection_field_names=collection_field_names,
join_paths=join_paths, filter_tables=filter_tables,
req_fields=req_fields, proj_opts=proj_opts
)
nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } def _apply_projection_load_only(self, query, root_alias, plan: _Plan):
only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)]
return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query
# First-hop only; use selectinload (no contains_eager) def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
for base_alias, rel_attr, target_alias in join_paths: nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 }
is_firsthop_from_root = (base_alias is root_alias) for base_alias, rel_attr, target_alias in plan.join_paths:
if not is_firsthop_from_root: if base_alias is not root_alias:
continue continue
prop = getattr(rel_attr, "property", None) prop = getattr(rel_attr, "property", None)
is_collection = bool(getattr(prop, "uselist", False)) is_collection = bool(getattr(prop, "uselist", False))
_is_nested_firsthop = rel_attr.key in nested_first_hops
opt = selectinload(rel_attr) sel = getattr(target_alias, "selectable", None)
if is_collection: sel_elem = getattr(sel, "element", None)
child_names = (collection_field_names or {}).get(rel_attr.key, []) base_sel = sel_elem if sel_elem is not None else sel
if child_names:
target_cls = prop.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
# Apply filters (joins are in place → no cartesian products) needed_for_filter = (sel in plan.filter_tables) or (base_sel in plan.filter_tables)
if filters:
query = query.filter(*filters)
# And the id filter
query = query.filter(getattr(root_alias, "id") == id)
# Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
obj = query.first()
# Emit exactly what the client requested (plus id), or a reasonable fallback
if req_fields:
proj = list(dict.fromkeys(req_fields))
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else:
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj and obj is not None:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return obj or None
def list(self, params=None) -> list[T]:
"""
Offset/limit listing with relationship-safe filtering.
We always JOIN every CRUDSpec-discovered path before applying filters/sorts.
"""
query, root_alias = self.get_query()
# Defaults so we can reference them later even if params is None
req_fields: list[str] = _normalize_fields_param(params)
root_fields: list[Any] = []
root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {}
if params:
# Soft delete
query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
limit, offset = spec.parse_pagination()
# Includes / fields (populates join_paths)
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
# Root column projection (load_only)
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
# First-hop only; use selectinload
for base_alias, rel_attr, target_alias in join_paths:
is_firsthop_from_root = (base_alias is root_alias)
if not is_firsthop_from_root:
continue
prop = getattr(rel_attr, "property", None)
is_collection = bool(getattr(prop, "uselist", False))
_is_nested_firsthop = rel_attr.key in nested_first_hops
if needed_for_filter and not is_collection:
query = query.join(rel_attr, isouter=True)
else:
opt = selectinload(rel_attr) opt = selectinload(rel_attr)
if is_collection: if is_collection:
child_names = (collection_field_names or {}).get(rel_attr.key, []) child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
if child_names: if child_names:
target_cls = prop.mapper.class_ target_cls = prop.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names] cols = [getattr(target_cls, n, None) for n in child_names]
@@ -585,79 +299,194 @@ class CRUDService(Generic[T]):
if cols: if cols:
opt = opt.load_only(*cols) opt = opt.load_only(*cols)
query = query.options(opt) query = query.options(opt)
return query
# Filters AFTER joins → no cartesian products def _apply_proj_opts(self, query, plan: _Plan):
if filters: return query.options(*plan.proj_opts) if plan.proj_opts else query
query = query.filter(*filters)
# MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers def _projection_meta(self, plan: _Plan):
paginating = (limit is not None) or (offset is not None and offset != 0) if plan.req_fields:
if paginating and not order_by and self.backend.requires_order_by_for_offset: proj = list(dict.fromkeys(plan.req_fields))
order_by = self._default_order_by(root_alias) return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj
order_by = _dedupe_order_by(order_by) proj: list[str] = []
if plan.root_field_names:
proj.extend(plan.root_field_names)
if plan.root_fields:
proj.extend(c.key for c in plan.root_fields if hasattr(c, "key"))
for path, names in (plan.rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
return proj
if order_by: def _tag_projection(self, items, proj):
query = query.order_by(*order_by) if not proj:
return
for obj in items if isinstance(items, list) else [items]:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
# Offset/limit # ---- public read ops
if offset is not None and offset != 0:
query = query.offset(offset)
if limit is not None and limit > 0:
query = query.limit(limit)
# Projection loaders only if we didnt use contains_eager def seek_window(
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) self,
if proj_opts: params: dict | None = None,
query = query.options(*proj_opts) *,
key: list[Any] | None = None,
backward: bool = False,
include_total: bool = True,
) -> "SeekWindow[T]":
session = self.session
query, root_alias = self.get_query()
query = self._apply_not_deleted(query, root_alias, params)
else: plan = self._plan(params, root_alias)
# No params; still honor projection loaders if any query = self._apply_projection_load_only(query, root_alias, plan)
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) query = self._apply_firsthop_strategies(query, root_alias, plan)
if proj_opts: if plan.filters:
query = query.options(*proj_opts) query = query.filter(*plan.filters)
order_spec = self._extract_order_spec(root_alias, plan.order_by)
limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)
if key:
pred = self._key_predicate(order_spec, key, backward)
if pred is not None:
query = query.filter(pred)
clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)]
if backward:
clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)]
query = query.order_by(*clauses)
if limit is not None:
query = query.limit(limit)
query = self._apply_proj_opts(query, plan)
rows = query.all() rows = query.all()
items = list(reversed(rows)) if backward else rows
# Build projection meta for renderers proj = self._projection_meta(plan)
if req_fields: self._tag_projection(items, proj)
proj = list(dict.fromkeys(req_fields))
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else:
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj: # cursor keys
for obj in rows: def pluck(obj):
try: vals = []
setattr(obj, "__crudkit_projection__", tuple(proj)) alias_to_rel = {}
except Exception: for _p, rel_attr, target_alias in plan.join_paths:
pass sel = getattr(target_alias, "selectable", None)
if sel is not None:
alias_to_rel[sel] = rel_attr.key
for col in order_spec.cols:
keyname = getattr(col, "key", None) or getattr(col, "name", None)
if keyname and hasattr(obj, keyname):
vals.append(getattr(obj, keyname)); continue
table = getattr(col, "table", None)
relname = alias_to_rel.get(table)
if relname and keyname:
relobj = getattr(obj, relname, None)
if relobj is not None and hasattr(relobj, keyname):
vals.append(getattr(relobj, keyname)); continue
raise ValueError("unpluckable")
return vals
try:
first_key = pluck(items[0]) if items else None
last_key = pluck(items[-1]) if items else None
except Exception:
first_key = last_key = None
total = None
if include_total:
base = session.query(getattr(root_alias, "id"))
base = self._apply_not_deleted(base, root_alias, params)
for _b, rel_attr, target_alias in plan.join_paths:
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
base = base.join(rel_attr, isouter=True)
if plan.filters:
base = base.filter(*plan.filters)
total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery()
).scalar() or 0
if log.isEnabledFor(logging.DEBUG): if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query)) log.debug("QUERY: %s", str(query))
window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50)
return SeekWindow(
items=items,
limit=window_limit_for_body,
first_key=first_key,
last_key=last_key,
order=order_spec,
total=total,
)
def get(self, id: int, params=None) -> T | None:
query, root_alias = self.get_query()
query = self._apply_not_deleted(query, root_alias, params)
plan = self._plan(params, root_alias)
query = self._apply_projection_load_only(query, root_alias, plan)
query = self._apply_firsthop_strategies(query, root_alias, plan)
if plan.filters:
query = query.filter(*plan.filters)
query = query.filter(getattr(root_alias, "id") == id)
query = self._apply_proj_opts(query, plan)
obj = query.first()
proj = self._projection_meta(plan)
if obj:
self._tag_projection(obj, proj)
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return obj or None
def list(self, params=None) -> list[T]:
query, root_alias = self.get_query()
plan = self._plan(params, root_alias)
query = self._apply_not_deleted(query, root_alias, params)
query = self._apply_projection_load_only(query, root_alias, plan)
query = self._apply_firsthop_strategies(query, root_alias, plan)
if plan.filters:
query = query.filter(*plan.filters)
order_by = plan.order_by
paginating = (plan.limit is not None) or (plan.offset not in (None, 0))
if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias)
order_by = _dedupe_order_by(order_by)
if order_by:
query = query.order_by(*order_by)
if plan.offset: query = query.offset(plan.offset)
if plan.limit and plan.limit > 0: query = query.limit(plan.limit)
query = self._apply_proj_opts(query, plan)
rows = query.all()
proj = self._projection_meta(plan)
self._tag_projection(rows, proj)
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return rows return rows
# ---- write ops
def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
    """Insert a new model row built from *data* and audit-log the creation.

    Flushes before logging so the generated primary key is available to
    the version record; commits only when *commit* is true.
    """
    db = self.session
    record = self.model(**data)
    db.add(record)
    # Flush now: the audit entry needs record.id populated.
    db.flush()
    self._log_version("create", record, actor, commit=commit)
    if commit:
        db.commit()
    return record
@ -669,59 +498,34 @@ class CRUDService(Generic[T]):
raise ValueError(f"{self.model.__name__} with ID {id} not found.") raise ValueError(f"{self.model.__name__} with ID {id} not found.")
before = obj.as_dict() before = obj.as_dict()
# Normalize and restrict payload to real columns
norm = normalize_payload(data, self.model) norm = normalize_payload(data, self.model)
incoming = filter_to_columns(norm, self.model) incoming = filter_to_columns(norm, self.model)
# Build a synthetic "desired" state for top-level columns
desired = {**before, **incoming} desired = {**before, **incoming}
# Compute intended change set (before vs intended) proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
proposed = deep_diff(
before, desired,
ignore_keys={"id", "created_at", "updated_at"},
list_mode="index",
)
patch = diff_to_patch(proposed) patch = diff_to_patch(proposed)
# Nothing to do
if not patch: if not patch:
return obj return obj
# Apply only what actually changes
for k, v in patch.items(): for k, v in patch.items():
setattr(obj, k, v) setattr(obj, k, v)
# Optional: skip commit if ORM says no real change (paranoid check)
# Note: is_modified can lie if attrs are expired; use history for certainty.
dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys()) dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
if not dirty: if not dirty:
return obj return obj
# Commit atomically
if commit: if commit:
session.commit() session.commit()
# AFTER snapshot for audit
after = obj.as_dict() after = obj.as_dict()
actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
# Actual diff (captures triggers/defaults, still ignoring noisy keys)
actual = deep_diff(
before, after,
ignore_keys={"id", "created_at", "updated_at"},
list_mode="index",
)
# If truly nothing changed post-commit (rare), skip version spam
if not (actual["added"] or actual["removed"] or actual["changed"]): if not (actual["added"] or actual["removed"] or actual["changed"]):
return obj return obj
# Log both what we *intended* and what *actually* happened
self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit) self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit)
return obj return obj
def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True):
    """Delete the row with *id*: hard delete, or soft delete when supported.

    Returns the deleted (or soft-deleted) object and audit-logs the change.
    """
    db = self.session
    record = db.get(self.model, id)
    if not record:
        # NOTE(review): the not-found branch body is elided from this diff
        # view; mirrored from update()'s visible pattern — confirm against
        # the original file.
        raise ValueError(f"{self.model.__name__} with ID {id} not found.")
    if hard or not self.supports_soft_delete:
        db.delete(record)
    else:
        # Soft path: flip the flag instead of removing the row.
        cast(_SoftDeletable, record).is_deleted = True
    if commit:
        db.commit()
    # NOTE(review): version is logged after the commit above — for a hard
    # delete the instance may already be expired/removed; verify intended.
    self._log_version("delete", record, actor, commit=commit)
    return record
# ---- audit
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True): def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True):
session = self.session session = self.session
try: try:
snapshot = {}
try: try:
snapshot = obj.as_dict() snapshot = obj.as_dict()
except Exception: except Exception:
snapshot = {"error": "serialize failed"} snapshot = {"error": "serialize failed"}
version = Version( version = Version(
model_name=self.model.__name__, model_name=self.model.__name__,
object_id=obj.id, object_id=obj.id,

View file

@ -25,6 +25,15 @@
{% endfor %} {% endfor %}
</ul> </ul>
<div class="mt-3">
<button type="button" class="btn btn-outline-primary btn-sm" onclick="addNewUpdate()">
Add update
</button>
</div>
<input type="hidden" name="updates" id="updatesPayload">
<input type="hidden" name="delete_update_ids" id="deleteUpdatesPayload">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/github-markdown-css@5/github-markdown.min.css"> <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/github-markdown-css@5/github-markdown.min.css">
<style> <style>
textarea.auto-md { textarea.auto-md {
@ -45,6 +54,16 @@
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/dompurify/dist/purify.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/dompurify/dist/purify.min.js"></script>
<script> <script>
// Track deletions for existing notes
const deletedUpdateIds = new Set();
// NOTE(review): stub — builds a list item but never appends it to the
// updates <ul> or fills in any content, so the "Add update" button has no
// visible effect; `tempId` and `li` are unused after creation.
// TODO: attach `li` to the list (and wire it into the hidden
// `updates`/`delete_update_ids` payloads) or remove this handler.
function addNewUpdate() {
// Create a temporary client-only id so we can manage the DOM before saving
const tempId = `new_${Date.now()}`;
const li = document.createElement('li');
li.className = 'list-group-item';
}
// Initial render // Initial render
document.addEventListener('DOMContentLoaded', () => { document.addEventListener('DOMContentLoaded', () => {
const ids = [ {% for n in items %} {{ n.id }}{% if not loop.last %}, {% endif %}{% endfor %} ]; const ids = [ {% for n in items %} {{ n.id }}{% if not loop.last %}, {% endif %}{% endfor %} ];