Lots of downstream changes.

This commit is contained in:
Yaro Kasear 2025-09-26 15:55:02 -05:00
parent d34654834b
commit d4e51affd5
5 changed files with 521 additions and 309 deletions

View file

@ -1,5 +1,6 @@
from typing import Any, Dict, Iterable, List, Tuple, Set
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty
Base = declarative_base() Base = declarative_base()
@ -16,45 +17,296 @@ def _safe_get_loaded_attr(obj, name):
except Exception: except Exception:
return None return None
def _identity_key(obj) -> Tuple[type, Any]:
try:
st = inspect(obj)
return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
except Exception:
return (type(obj), id(obj))
def _is_collection_rel(prop: RelationshipProperty) -> bool:
    """True when *prop* maps a collection (``uselist`` is exactly ``True``).

    Any error while reading the attribute is treated as "not a collection".
    """
    try:
        uselist = prop.uselist
    except Exception:
        return False
    return uselist is True
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for cls in obj.__class__.__mro__:
if hasattr(cls, "__table__"):
for col in cls.__table__.columns:
name = col.name
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """Serialize the ALREADY-LOADED relationship *name* on *obj*.

    Recurses through ``as_tree()`` when depth budget remains or the name is
    in *embed*; otherwise emits a flat, columns-only view. Never triggers a
    lazy load; cycles are cut via *seen* and collapse to ``{"id": ...}``.
    """
    value = _safe_get_loaded_attr(obj, name)
    if value is None:
        return None

    recurse = depth > 0 or name in embed

    def _one(child):
        # Cycle guard: a node already on the current path collapses to its id.
        ident = _identity_key(child)
        if ident in seen:
            return {"id": getattr(child, "id", None)}
        seen.add(ident)
        return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)

    if isinstance(value, list):
        if not recurse:
            # Light list of child column data, no recursion.
            return [_serialize_simple_obj(child) for child in value]
        return [_one(child) for child in value]

    # Scalar relationship.
    if not recurse:
        return _serialize_simple_obj(value)
    return _one(value)
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
"""
Split requested fields into:
- scalars: ["label", "name"]
- collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]}
Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
Bare tokens ("foo") land in scalars.
"""
scalars: List[str] = []
groups: Dict[str, List[str]] = {}
for raw in fields:
f = str(raw).strip()
if not f:
continue
# bare token -> scalar
if "." not in f:
scalars.append(f)
continue
# dotted token -> group under root
root, tail = f.split(".", 1)
if not root or not tail:
continue
groups.setdefault(root, []).append(tail)
return scalars, groups
def _deep_get_loaded(obj: Any, dotted: str) -> Any:
"""
Deep get with no lazy loads:
- For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr).
- For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr()
to allow computed properties/hybrids that rely on already-loaded columns.
"""
parts = dotted.split(".")
if not parts:
return None
cur = obj
# Traverse up to the parent of the last token safely
for part in parts[:-1]:
if cur is None:
return None
cur = _safe_get_loaded_attr(cur, part)
if cur is None:
return None
last = parts[-1]
# Try safe fetch on the last hop first
val = _safe_get_loaded_attr(cur, last)
if val is not None:
return val
# Fall back to getattr for computed/hybrid attributes on an already-loaded object
try:
return getattr(cur, last, None)
except Exception:
return None
def _serialize_leaf(obj: Any) -> Any:
"""
Lead serialization for values we put into as_dict():
- If object has as_dict(), call as_dict() with no args (caller controls field shapes).
- Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config).
"""
if obj is None:
return None
ad = getattr(obj, "as_dict", None)
if callable(ad):
try:
return ad(None)
except Exception:
return str(obj)
return obj
def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
"""
Turn a collection of ORM objects into list[dict] with exactly requested_tails,
where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load.
"""
out: List[Dict[str, Any]] = []
# Deduplicate while preserving order
uniq_tails = list(dict.fromkeys(requested_tails))
for child in (items or []):
row: Dict[str, Any] = {}
for tail in uniq_tails:
row[tail] = _deep_get_loaded(child, tail)
# ensure id present if exists and not already requested
try:
if "id" not in row and hasattr(child, "id"):
row["id"] = getattr(child, "id")
except Exception:
pass
out.append(row)
return out
@declarative_mixin @declarative_mixin
class CRUDMixin: class CRUDMixin:
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
created_at = Column(DateTime, default=func.now(), nullable=False) created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
def as_tree(
    self,
    *,
    embed_depth: int = 0,
    embed: Iterable[str] | None = None,
    _seen: Set[Tuple[type, Any]] | None = None,
) -> Dict[str, Any]:
    """Recursive, NON-LAZY serializer.

    - Always includes mapped columns.
    - Relationships are serialized only when ALREADY LOADED; never lazy-loads.
    - Recurses up to *embed_depth*, or for the top-level names in *embed*.
    - Keeps *_id columns alongside embedded objects.
    - Cycle-safe via *_seen* (cycles collapse to ``{"id": ...}``).
    """
    seen = _seen or set()
    ident = _identity_key(self)
    if ident in seen:
        return {"id": getattr(self, "id", None)}
    seen.add(ident)

    data = _serialize_simple_obj(self)

    try:
        state = inspect(self)
        # Only the first hop of each embed token matters at this level.
        roots = {str(token).split(".", 1)[0] for token in (embed or [])}
        for name in state.mapper.relationships.keys():
            attr_state = state.attrs.get(name)
            # Skip anything not already loaded; never lazy-load here.
            if attr_state is None or attr_state.loaded_value is NO_VALUE:
                continue
            data[name] = _serialize_loaded_rel(
                self, name, depth=embed_depth, seen=seen, embed=roots
            )
    except Exception:
        # Inspection failed: fall back to columns only.
        pass
    return data
def as_dict(self, fields: list[str] | None = None): def as_dict(self, fields: list[str] | None = None):
""" """
Serialize the instance. Serialize the instance.
- If 'fields' (possibly dotted) is provided, emit exactly those keys.
- Else, if '__crudkit_projection__' is set on the instance, emit those keys.
- Else, fall back to all mapped columns on this class hierarchy.
Always includes 'id' when present unless explicitly excluded.
"""
if fields is None:
fields = getattr(self, "__crudkit_projection__", None)
if fields: Behavior:
out = {} - If 'fields' (possibly dotted) is provided, emit exactly those keys.
if "id" not in fields and hasattr(self, "id"): * Bare tokens (e.g., "label", "owner") return the current loaded value.
out["id"] = getattr(self, "id") * Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp")
for f in fields: produce a single "updates" key containing a list of dicts with the requested child keys.
cur = self * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label".
for part in f.split("."): - Else, if '__crudkit_projection__' is set on the instance, use that.
if cur is None: - Else, fall back to all mapped columns on this class hierarchy.
break
cur = getattr(cur, part, None) Always includes 'id' when present unless explicitly excluded (i.e., fields explicitly provided without id).
out[f] = cur """
req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)
if req:
# Normalize and split into (scalars, groups of dotted by root)
req_list = [p for p in (str(x).strip() for x in req) if p]
scalars, groups = _split_field_tokens(req_list)
out: Dict[str, Any] = {}
# Always include id unless the caller explicitly listed fields containing id
if "id" not in req_list and hasattr(self, "id"):
try:
out["id"] = getattr(self, "id")
except Exception:
pass
# Handle scalar tokens (may be columns, hybrids/properties, or relationships)
for name in scalars:
# Try loaded value first (never lazy-load)
val = _safe_get_loaded_attr(self, name)
# Final-hop getattr for root scalars (hybrids/@property) so they can compute.
if val is None:
try:
val = getattr(self, name)
except Exception:
val = None
# If it's a scalar ORM object (relationship), serialize its columns
try:
st = inspect(val) # will raise if not an ORM object
if getattr(st, "mapper", None) is not None:
out[name] = _serialize_simple_obj(val)
continue
except Exception:
pass
# If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)):
out[name] = [_serialize_leaf(v) for v in val]
else:
out[name] = val
# Handle dotted groups: root -> [tails]
for root, tails in groups.items():
root_val = _safe_get_loaded_attr(self, root)
if isinstance(root_val, (list, tuple)):
# one-to-many collection → list of dicts with the requested tails
out[root] = _serialize_collection(root_val, tails)
else:
# many-to-one or scalar dotted; place each full dotted path as key
for tail in tails:
dotted = f"{root}.{tail}"
out[dotted] = _deep_get_loaded(self, dotted)
# ← This was the placeholder before. We return the dict we just built.
return out return out
result = {} # Fallback: all mapped columns on this class hierarchy
result: Dict[str, Any] = {}
for cls in self.__class__.__mro__: for cls in self.__class__.__mro__:
if hasattr(cls, "__table__"): if hasattr(cls, "__table__"):
for column in cls.__table__.columns: for column in cls.__table__.columns:
name = column.name name = column.name
result[name] = getattr(self, name) try:
result[name] = getattr(self, name)
except Exception:
result[name] = None
return result return result
class Version(Base): class Version(Base):
__tablename__ = "versions" __tablename__ = "versions"

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
@ -40,66 +41,43 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
T = TypeVar("T", bound=_CRUDModelProto) T = TypeVar("T", bound=_CRUDModelProto)
def _hops_from_sort(params: dict | None) -> set[str]: def _unwrap_ob(ob):
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'.""" """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc())."""
if not params: col = getattr(ob, "element", None)
return set() if col is None:
raw = params.get("sort") col = ob
tokens: list[str] = [] is_desc = False
if isinstance(raw, str): dir_attr = getattr(ob, "_direction", None)
tokens = [t.strip() for t in raw.split(",") if t.strip()] if dir_attr is not None:
elif isinstance(raw, (list, tuple)): is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
for item in raw: elif isinstance(ob, UnaryExpression):
if isinstance(item, str): op = getattr(ob, "operator", None)
tokens.extend([t.strip() for t in item.split(",") if t.strip()]) is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
hops: set[str] = set() return col, bool(is_desc)
for tok in tokens:
tok = tok.lstrip("+-")
if "." in tok:
hops.add(tok.split(".", 1)[0])
return hops
def _belongs_to_alias(col: Any, alias: Any) -> bool: def _order_identity(col: ColumnElement):
# Try to detect if a column/expression ultimately comes from this alias. """
# Works for most ORM columns; complex expressions may need more. Build a stable identity for a column suitable for deduping.
t = getattr(col, "table", None) We ignore direction here. Duplicates are duplicates regardless of ASC/DESC.
selectable = getattr(alias, "selectable", None) """
return t is not None and selectable is not None and t is selectable table = getattr(col, "table", None)
table_key = getattr(table, "key", None) or id(table)
col_key = getattr(col, "key", None) or getattr(col, "name", None)
return (table_key, col_key)
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]: def _dedupe_order_by(order_by):
hops: set[str] = set() """Remove duplicate ORDER BY entries (by column identity, ignoring direction)."""
paths: set[tuple[str, ...]] = set() if not order_by:
# Sort columns return []
for ob in order_by or []: seen = set()
col = getattr(ob, "element", ob) # unwrap UnaryExpression out = []
for _path, rel_attr, target_alias in join_paths: for ob in order_by:
if _belongs_to_alias(col, target_alias): col, _ = _unwrap_ob(ob)
hops.add(rel_attr.key) ident = _order_identity(col)
# Filter columns (best-effort) if ident in seen:
# Walk simple binary expressions continue
def _extract_cols(expr: Any) -> Iterable[Any]: seen.add(ident)
if isinstance(expr, ColumnElement): out.append(ob)
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _extract_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _extract_cols(ch)
for flt in filters or []:
for col in _extract_cols(flt):
for _path, rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
hops.add(rel_attr.key)
return hops
def _paths_from_fields(req_fields: list[str]) -> set[str]:
out: set[str] = set()
for f in req_fields:
if "." in f:
parent = f.split(".", 1)[0]
if parent:
out.add(parent)
return out return out
def _is_truthy(val): def _is_truthy(val):
@ -133,24 +111,25 @@ class CRUDService(Generic[T]):
self.polymorphic = polymorphic self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted') self.supports_soft_delete = hasattr(model, 'is_deleted')
# Derive engine WITHOUT leaking a session/connection self._backend: Optional[BackendInfo] = backend
bind = getattr(session_factory, "bind", None)
if bind is None:
tmp_sess = session_factory()
try:
bind = tmp_sess.get_bind()
finally:
try:
tmp_sess.close()
except Exception:
pass
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self.backend = backend or make_backend_info(eng)
@property @property
def session(self) -> Session: def session(self) -> Session:
return self._session_factory() """Always return the Flask-scoped Session if available; otherwise the provided factory."""
try:
sess = current_app.extensions["crudkit"]["Session"]
return sess
except Exception:
return self._session_factory()
@property
def backend(self) -> BackendInfo:
"""Resolve backend info lazily against the active session's engine."""
if self._backend is None:
bind = self.session.get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self._backend = make_backend_info(eng)
return self._backend
def get_query(self): def get_query(self):
if self.polymorphic: if self.polymorphic:
@ -237,85 +216,69 @@ class CRUDService(Generic[T]):
include_total: bool = True, include_total: bool = True,
) -> "SeekWindow[T]": ) -> "SeekWindow[T]":
""" """
Transport-agnostic keyset pagination that preserves all the goodies from `list()`: Keyset pagination with relationship-safe filtering/sorting.
- filters, includes, joins, field projection, eager loading, soft-delete Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek.
- deterministic ordering (user sort + PK tiebreakers)
- forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
""" """
self._debug_bind("seek_window")
session = self.session session = self.session
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Normalize requested fields and compile projection (may skip later to avoid conflicts) # Requested fields → projection + optional loaders
fields = _normalize_fields_param(params) fields = _normalize_fields_param(params)
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
# Parse all inputs so join_paths are populated
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
# Field parsing for root load_only fallback # Soft delete
root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter
query = self._apply_not_deleted(query, root_alias, params) query = self._apply_not_deleted(query, root_alias, params)
# Apply filters first # Root column projection (load_only)
if filters:
query = query.filter(*filters)
# Includes + join paths (dotted fields etc.)
spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
# Relationship names required by ORDER BY / WHERE
sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
# Also include relationships mentioned directly in the sort spec
sql_hops |= _hops_from_sort(params)
# First-hop relationship names implied by dotted projection fields
proj_hops: set[str] = _paths_from_fields(fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path (avoid loader strategy conflicts) # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
used_contains_eager = False nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths: used_contains_eager = False
rel_attr = cast(InstrumentedAttribute, relationship_attr) for base_alias, rel_attr, target_alias in join_paths:
name = relationship_attr.key is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
if name in sql_hops: is_nested_firsthop = rel_attr.key in nested_first_hops
# Needed for WHERE/ORDER BY: join + hydrate from that join
if is_collection or is_nested_firsthop:
# Use selectinload so deeper hops can chain cleanly (and to avoid
# contains_eager/loader conflicts on nested paths).
opt = selectinload(rel_attr)
# Narrow columns for collections if we know child scalar names
if is_collection:
child_names = (collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
else:
# Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
joined_names.add(name)
elif name in proj_hops:
# Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr))
joined_names.add(name)
# Force-join any SQL-needed relationships that weren't in join_paths # Filters AFTER joins → no cartesian products
missing_sql = sql_hops - joined_names if filters:
for name in missing_sql: query = query.filter(*filters)
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# Apply projection loader options only if they won't conflict with contains_eager # Order spec (with PK tie-breakers for stability)
if proj_opts and not used_contains_eager: order_spec = self._extract_order_spec(root_alias, order_by)
query = query.options(*proj_opts)
# Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper
limit, _ = spec.parse_pagination() limit, _ = spec.parse_pagination()
if limit is None: if limit is None:
effective_limit = 50 effective_limit = 50
@ -324,13 +287,13 @@ class CRUDService(Generic[T]):
else: else:
effective_limit = limit effective_limit = limit
# Keyset predicate # Seek predicate from cursor key (if any)
if key: if key:
pred = self._key_predicate(order_spec, key, backward) pred = self._key_predicate(order_spec, key, backward)
if pred is not None: if pred is not None:
query = query.filter(pred) query = query.filter(pred)
# Apply ordering. For backward, invert SQL order then reverse in-memory for display. # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory.
if not backward: if not backward:
clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
query = query.order_by(*clauses) query = query.order_by(*clauses)
@ -344,9 +307,9 @@ class CRUDService(Generic[T]):
query = query.limit(effective_limit) query = query.limit(effective_limit)
items = list(reversed(query.all())) items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested # Projection meta tag for renderers
if fields: if fields:
proj = list(dict.fromkeys(fields)) # dedupe, preserve order proj = list(dict.fromkeys(fields))
if "id" not in proj and hasattr(self.model, "id"): if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id") proj.insert(0, "id")
else: else:
@ -369,12 +332,9 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
# Boundary keys for cursor encoding in the API layer # Cursor key pluck: support related columns we hydrated via contains_eager
# When ORDER BY includes related columns (e.g., owner.first_name),
# pluck values from the related object we hydrated with contains_eager/selectinload.
def _pluck_key_from_obj(obj: Any) -> list[Any]: def _pluck_key_from_obj(obj: Any) -> list[Any]:
vals: list[Any] = [] vals: list[Any] = []
# Build a quick map: selectable -> relationship name
alias_to_rel: dict[Any, str] = {} alias_to_rel: dict[Any, str] = {}
for _p, relationship_attr, target_alias in join_paths: for _p, relationship_attr, target_alias in join_paths:
sel = getattr(target_alias, "selectable", None) sel = getattr(target_alias, "selectable", None)
@ -382,20 +342,17 @@ class CRUDService(Generic[T]):
alias_to_rel[sel] = relationship_attr.key alias_to_rel[sel] = relationship_attr.key
for col in order_spec.cols: for col in order_spec.cols:
key = getattr(col, "key", None) or getattr(col, "name", None) keyname = getattr(col, "key", None) or getattr(col, "name", None)
# Try root attribute first if keyname and hasattr(obj, keyname):
if key and hasattr(obj, key): vals.append(getattr(obj, keyname))
vals.append(getattr(obj, key))
continue continue
# Try relationship hop by matching the column's table/selectable
table = getattr(col, "table", None) table = getattr(col, "table", None)
relname = alias_to_rel.get(table) relname = alias_to_rel.get(table)
if relname and key: if relname and keyname:
relobj = getattr(obj, relname, None) relobj = getattr(obj, relname, None)
if relobj is not None and hasattr(relobj, key): if relobj is not None and hasattr(relobj, keyname):
vals.append(getattr(relobj, key)) vals.append(getattr(relobj, keyname))
continue continue
# Give up: unsupported expression for cursor purposes
raise ValueError("unpluckable") raise ValueError("unpluckable")
return vals return vals
@ -403,33 +360,26 @@ class CRUDService(Generic[T]):
first_key = _pluck_key_from_obj(items[0]) if items else None first_key = _pluck_key_from_obj(items[0]) if items else None
last_key = _pluck_key_from_obj(items[-1]) if items else None last_key = _pluck_key_from_obj(items[-1]) if items else None
except Exception: except Exception:
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
# disable cursors for this response rather than exploding.
first_key = None first_key = None
last_key = None last_key = None
# Optional total thats safe under JOINs (COUNT DISTINCT ids) # Count DISTINCT ids with mirrored joins
total = None total = None
if include_total: if include_total:
base = session.query(getattr(root_alias, "id")) base = session.query(getattr(root_alias, "id"))
base = self._apply_not_deleted(base, root_alias, params) base = self._apply_not_deleted(base, root_alias, params)
# same joins as above for correctness
for base_alias, rel_attr, target_alias in join_paths:
# do not join collections for COUNT mirror
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
# Mirror join structure for any SQL-needed relationships
for _path, relationship_attr, target_alias in join_paths:
if relationship_attr.key in sql_hops:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
# Also mirror any forced joins
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
base = base.join(rel_attr, isouter=True)
total = session.query(func.count()).select_from( total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery() base.order_by(None).distinct().subquery()
).scalar() or 0 ).scalar() or 0
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50) window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50)
if log.isEnabledFor(logging.DEBUG): if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query)) log.debug("QUERY: %s", str(query))
@ -458,81 +408,93 @@ class CRUDService(Generic[T]):
""" """
Ensure deterministic ordering by appending PK columns as tiebreakers. Ensure deterministic ordering by appending PK columns as tiebreakers.
If no order is provided, fall back to default primary-key order. If no order is provided, fall back to default primary-key order.
Never duplicates columns in ORDER BY (SQL Server requires uniqueness).
""" """
order_by = list(given_order_by or []) order_by = list(given_order_by or [])
if not order_by: if not order_by:
return self._default_order_by(root_alias) return _dedupe_order_by(self._default_order_by(root_alias))
order_by = _dedupe_order_by(order_by)
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
pk_cols = [] present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by}
for col in mapper.primary_key:
try:
pk_cols.append(getattr(root_alias, col.key))
except AttributeError:
pk_cols.append(col)
return [*order_by, *pk_cols] for pk in mapper.primary_key:
try:
pk_col = getattr(root_alias, pk.key)
except AttributeError:
pk_col = pk
if _order_identity(pk_col) not in present:
order_by.append(pk_col.asc())
present.add(_order_identity(pk_col))
return order_by
def get(self, id: int, params=None) -> T | None: def get(self, id: int, params=None) -> T | None:
"""Fetch a single row by id with conflict-free eager loading and clean projection.""" """
self._debug_bind("get") Fetch a single row by id with conflict-free eager loading and clean projection.
Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so
related-column filters never create cartesian products.
"""
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Defaults so we can build a projection even if params is None # Defaults so we can build a projection even if params is None
req_fields: list[str] = _normalize_fields_param(params)
root_fields: list[Any] = [] root_fields: list[Any] = []
root_field_names: dict[str, str] = {} root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {} rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
# Soft-delete guard # Soft-delete guard first
query = self._apply_not_deleted(query, root_alias, params) query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
# Optional extra filters (in addition to id); keep parity with list() # Parse everything so CRUDSpec records any join paths it needed to resolve
filters = spec.parse_filters() filters = spec.parse_filters()
if filters: # no ORDER BY for get()
query = query.filter(*filters) if params:
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
# Always filter by id
query = query.filter(getattr(root_alias, "id") == id)
# Includes + join paths we may need
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
# Field parsing to enable root load_only # Root-column projection (load_only)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Decide which relationship paths are needed for SQL vs display-only
# For get(), there is no ORDER BY; only filters might force SQL use.
sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
proj_hops = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path: avoid loader strategy conflicts nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
used_contains_eager = False used_contains_eager = False
for _path, relationship_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
name = relationship_attr.key is_nested_firsthop = rel_attr.key in nested_first_hops
if name in sql_hops:
# Needed in WHERE: join + hydrate from the join if is_collection or is_nested_firsthop:
opt = selectinload(rel_attr)
if is_collection:
child_names = (collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
elif name in proj_hops:
# Display-only: bulk-load efficiently # Apply filters (joins are in place → no cartesian products)
query = query.options(selectinload(rel_attr)) if filters:
else: query = query.filter(*filters)
pass
# And the id filter
query = query.filter(getattr(root_alias, "id") == id)
# Projection loader options compiled from requested fields. # Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid strategy conflicts. # Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager: if proj_opts and not used_contains_eager:
query = query.options(*proj_opts) query = query.options(*proj_opts)
@ -541,7 +503,7 @@ class CRUDService(Generic[T]):
# Emit exactly what the client requested (plus id), or a reasonable fallback # Emit exactly what the client requested (plus id), or a reasonable fallback
if req_fields: if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order proj = list(dict.fromkeys(req_fields))
if "id" not in proj and hasattr(self.model, "id"): if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id") proj.insert(0, "id")
else: else:
@ -569,17 +531,20 @@ class CRUDService(Generic[T]):
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
"""Offset/limit listing with smart relationship loading and clean projection.""" """
self._debug_bind("list") Offset/limit listing with relationship-safe filtering.
We always JOIN every CRUDSpec-discovered path before applying filters/sorts.
"""
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Defaults so we can reference them later even if params is None # Defaults so we can reference them later even if params is None
req_fields: list[str] = _normalize_fields_param(params)
root_fields: list[Any] = [] root_fields: list[Any] = []
root_field_names: dict[str, str] = {} root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {} rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
if params: if params:
# Soft delete
query = self._apply_not_deleted(query, root_alias, params) query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
@ -587,84 +552,75 @@ class CRUDService(Generic[T]):
order_by = spec.parse_sort() order_by = spec.parse_sort()
limit, offset = spec.parse_pagination() limit, offset = spec.parse_pagination()
# Includes + join paths we might need # Includes / fields (populates join_paths)
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
# Field parsing for load_only on root columns # Root column projection (load_only)
root_fields, rel_field_names, root_field_names = spec.parse_fields()
if filters:
query = query.filter(*filters)
# Determine which relationship paths are needed for SQL vs display-only
sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
proj_hops = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
used_contains_eager = False
joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths: used_contains_eager = False
rel_attr = cast(InstrumentedAttribute, relationship_attr) for _base_alias, rel_attr, target_alias in join_paths:
name = relationship_attr.key is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
if name in sql_hops: is_nested_firsthop = rel_attr.key in nested_first_hops
# Needed for WHERE/ORDER BY: join + hydrate from the join
if is_collection or is_nested_firsthop:
opt = selectinload(rel_attr)
if is_collection:
child_names = (collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
joined_names.add(name)
elif name in proj_hops:
# Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr))
joined_names.add(name)
# Force-join any SQL-needed relationships that weren't in join_paths # Filters AFTER joins → no cartesian products
missing_sql = sql_hops - joined_names if filters:
for name in missing_sql: query = query.filter(*filters)
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers
paginating = (limit is not None) or (offset is not None and offset != 0) paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset: if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias) order_by = self._default_order_by(root_alias)
order_by = _dedupe_order_by(order_by)
if order_by: if order_by:
query = query.order_by(*order_by) query = query.order_by(*order_by)
# Only apply offset/limit when not None and not zero # Offset/limit
if offset is not None and offset != 0: if offset is not None and offset != 0:
query = query.offset(offset) query = query.offset(offset)
if limit is not None and limit > 0: if limit is not None and limit > 0:
query = query.limit(limit) query = query.limit(limit)
# Projection loader options compiled from requested fields. # Projection loaders only if we didnt use contains_eager
# Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager: if proj_opts and not used_contains_eager:
query = query.options(*proj_opts) query = query.options(*proj_opts)
else: else:
# No params means no filters/sorts/limits; still honor projection loaders if any # No params; still honor projection loaders if any
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts: if proj_opts:
query = query.options(*proj_opts) query = query.options(*proj_opts)
rows = query.all() rows = query.all()
# Emit exactly what the client requested (plus id), or a reasonable fallback # Build projection meta for renderers
if req_fields: if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order proj = list(dict.fromkeys(req_fields))
if "id" not in proj and hasattr(self.model, "id"): if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id") proj.insert(0, "id")
else: else:
@ -692,7 +648,6 @@ class CRUDService(Generic[T]):
return rows return rows
def create(self, data: dict, actor=None) -> T: def create(self, data: dict, actor=None) -> T:
session = self.session session = self.session
obj = self.model(**data) obj = self.model(**data)

View file

@ -1,4 +1,4 @@
from typing import List, Tuple, Set, Dict, Optional from typing import List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import asc, desc from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm import aliased, selectinload
@ -20,10 +20,14 @@ class CRUDSpec:
self.params = params self.params = params
self.root_alias = root_alias self.root_alias = root_alias
self.eager_paths: Set[Tuple[str, ...]] = set() self.eager_paths: Set[Tuple[str, ...]] = set()
# (parent_alias. relationship_attr, alias_for_target)
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {} self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = [] self._root_fields: List[InstrumentedAttribute] = []
self._rel_field_names: Dict[Tuple[str, ...], object] = {} # dotted non-collection fields (MANYTOONE etc)
self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
# dotted collection fields (ONETOMANY)
self._collection_field_names: Dict[str, List[str]] = {}
self.include_paths: Set[Tuple[str, ...]] = set() self.include_paths: Set[Tuple[str, ...]] = set()
def _resolve_column(self, path: str): def _resolve_column(self, path: str):
@ -117,11 +121,12 @@ class CRUDSpec:
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias. - Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
Returns (root_fields, rel_field_names). - Collection (uselist=True) relationships record child names by relationship key.
Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel).
""" """
raw = self.params.get('fields') raw = self.params.get('fields')
if not raw: if not raw:
return [], {}, {} return [], {}, {}, {}
if isinstance(raw, list): if isinstance(raw, list):
tokens = [] tokens = []
@ -133,14 +138,36 @@ class CRUDSpec:
root_fields: List[InstrumentedAttribute] = [] root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = [] root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {} rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
collection_field_names: Dict[str, List[str]] = {}
for token in tokens: for token in tokens:
col, join_path = self._resolve_column(token) col, join_path = self._resolve_column(token)
if not col: if not col:
continue continue
if join_path: if join_path:
rel_field_names.setdefault(join_path, []).append(col.key) # rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path) # self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
for nm in names:
rel_attr = getattr(cur_cls, nm)
cur_cls = rel_attr.property.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
# Fallback: inspect the InstrumentedAttribute we recorded on join_paths
is_collection = False
for _pa, rel_attr, _al in self.join_paths:
if rel_attr.key == (join_path[-1] if join_path else ""):
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
break
if is_collection:
collection_field_names.setdefault(join_path[-1], []).append(col.key)
else:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else: else:
root_fields.append(col) root_fields.append(col)
root_field_names.append(getattr(col, "key", token)) root_field_names.append(getattr(col, "key", token))
@ -153,7 +180,11 @@ class CRUDSpec:
self._root_fields = root_fields self._root_fields = root_fields
self._rel_field_names = rel_field_names self._rel_field_names = rel_field_names
return root_fields, rel_field_names, root_field_names # return root_fields, rel_field_names, root_field_names
for r, names in collection_field_names.items():
seen3 = set()
collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
return root_field_names, rel_field_names, root_field_names, collection_field_names
def get_eager_loads(self, root_alias, *, fields_map=None): def get_eager_loads(self, root_alias, *, fields_map=None):
loads = [] loads = []

View file

@ -26,32 +26,6 @@ def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[
try: try:
bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine
pool = bound_engine.pool pool = bound_engine.pool
from sqlalchemy import event
@event.listens_for(pool, "checkout")
def _on_checkout(dbapi_conn, conn_record, conn_proxy):
sz = pool.size()
chk = pool.checkedout()
try:
conns_in_pool = pool.checkedin()
except Exception:
conns_in_pool = "?"
print(f"POOL CHECKOUT: Pool size: {sz} Connections in pool: {conns_in_pool} "
f"Current Overflow: {pool.overflow()} Current Checked out connections: {chk} "
f"engine id= {id(bound_engine)}")
@event.listens_for(pool, "checkin")
def _on_checkin(dbapi_conn, conn_record):
sz = pool.size()
chk = pool.checkedout()
try:
conns_in_pool = pool.checkedin()
except Exception:
conns_in_pool = "?"
print(f"POOL CHECKIN: Pool size: {sz} Connections in pool: {conns_in_pool} "
f"Current Overflow: {pool.overflow()} Current Checked out connections: {chk} "
f"engine id= {id(bound_engine)}")
except Exception as e: except Exception as e:
print(f"[crudkit.init_app] Failed to attach pool listeners: {e}") print(f"[crudkit.init_app] Failed to attach pool listeners: {e}")

View file

@ -1153,15 +1153,15 @@ def render_form(
field["wrap"] = _sanitize_attrs(field["wrap"]) field["wrap"] = _sanitize_attrs(field["wrap"])
fields.append(field) fields.append(field)
if submit_attrs: if submit_attrs:
submit_attrs = _sanitize_attrs(submit_attrs) submit_attrs = _sanitize_attrs(submit_attrs)
common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session} common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session}
for f in fields: for f in fields:
if f.get("type") == "template": if f.get("type") == "template":
base = dict(common_ctx) base = dict(common_ctx)
base.update(f.get("template_ctx") or {}) base.update(f.get("template_ctx") or {})
f["template_ctx"] = base f["template_ctx"] = base
for f in fields: for f in fields:
# existing FK label resolution # existing FK label resolution