CRUDService overhaul.

This commit is contained in:
Yaro Kasear 2025-09-24 14:19:01 -05:00
parent a0ee1caeb7
commit 2a9fb389d7

View file

@ -1,10 +1,13 @@
from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, RelationshipProperty
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
@ -15,13 +18,6 @@ from crudkit.projection import compile_projection
import logging
log = logging.getLogger("crudkit.service")
def _is_rel(model_cls, name: str) -> bool:
try:
prop = model_cls.__mapper__.relationships.get(name)
return isinstance(prop, RelationshipProperty)
except Exception:
return False
@runtime_checkable
class _HasID(Protocol):
id: int
@ -44,9 +40,65 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
# Generic model type parameter used by CRUDService; bound to _CRUDModelProto.
T = TypeVar("T", bound=_CRUDModelProto)
def _belongs_to_alias(col: Any, alias: Any) -> bool:
# Try to detect if a column/expression ultimately comes from this alias.
# Works for most ORM columns; complex expressions may need more.
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]:
    """Collect the relationship paths whose aliases are referenced by SQL clauses.

    Args:
        order_by: Sort expressions (plain columns or ``col.asc()``/``col.desc()``
            wrappers); may be ``None`` or empty.
        filters: Filter expressions; may be ``None`` or empty. Columns are
            extracted on a best-effort basis by walking the expression tree.
        join_paths: Iterable of ``(path, relationship_attr, target_alias)``
            triples describing the joinable relationships.

    Returns:
        The set of ``tuple(path)`` values for every join path whose target
        alias owns at least one column used in ``order_by`` or ``filters``.
    """

    def _extract_cols(expr: Any) -> Iterable[Any]:
        # Best-effort walk: yield every ColumnElement in the tree, descending
        # through get_children() or, for clause lists, through .clauses.
        if isinstance(expr, ColumnElement):
            yield expr
            for child in getattr(expr, "get_children", lambda: [])():
                yield from _extract_cols(child)
        elif hasattr(expr, "clauses"):
            for child in expr.clauses:
                yield from _extract_cols(child)

    def _referenced_cols() -> Iterable[Any]:
        # Sort columns: unwrap UnaryExpression (asc()/desc()) to the column.
        for ob in order_by or ():
            yield getattr(ob, "element", ob)
        # Filter columns: every column reachable inside each filter expression.
        for flt in filters or ():
            yield from _extract_cols(flt)

    paths: set[tuple[str, ...]] = set()
    for col in _referenced_cols():
        for path, _rel_attr, target_alias in join_paths:
            if _belongs_to_alias(col, target_alias):
                # Must be tuple(path): subscripting the type (tuple[path])
                # never yields a real path tuple, so filter-driven joins
                # were being missed.
                paths.add(tuple(path))
    return paths
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
out: set[tuple[str, ...]] = set()
for f in req_fields:
if "." in f:
parts = tuple(f.split(".")[:-1])
if parts:
out.add(parts)
return out
def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
def _normalize_fields_param(params: dict | None) -> list[str]:
if not params:
return []
raw = params.get("fields")
if isinstance(raw, (list, tuple)):
out: list[str] = []
for item in raw:
if isinstance(item, str):
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
return out
if isinstance(raw, str):
return [p for p in (s.strip() for s in raw.split(",")) if p]
return []
class CRUDService(Generic[T]):
def __init__(
self,
@ -86,8 +138,6 @@ class CRUDService(Generic[T]):
Normalize order_by into (cols, desc_flags). Supports plain columns and
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
"""
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
given = self._stable_order_by(root_alias, given_order_by)
@ -109,7 +159,6 @@ class CRUDService(Generic[T]):
cols.append(col)
desc_flags.append(bool(is_desc))
from crudkit.core.types import OrderSpec
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
@ -156,49 +205,100 @@ class CRUDService(Generic[T]):
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
"""
session = self.session
fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
query, root_alias = self.get_query()
if proj_opts:
query = query.options(*proj_opts)
# Normalize requested fields and compile projection (may skip later to avoid conflicts)
fields = _normalize_fields_param(params)
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
# Field parsing for root load_only fallback
root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter
# if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
# query = query.filter(getattr(root_alias, "is_deleted") == False)
query = self._apply_not_deleted(query, root_alias, params)
# Parse filters first
# Apply filters first
if filters:
query = query.filter(*filters)
# Includes + joins (so relationship fields like brand.name, location.label work)
# Includes + join paths (dotted fields etc.)
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Decide which relationship *names* are required for SQL (filters/sort) vs display-only
def _belongs_to_alias(col: Any, alias: Any) -> bool:
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
# Fields/projection: load_only for root columns, eager loads for relationships
# 1) which relationship aliases are referenced by sort/filter
sql_hops: set[str] = set()
for path, relationship_attr, target_alias in join_paths:
# If any ORDER BY column comes from this alias, mark it
for ob in (order_by or []):
col = getattr(ob, "element", ob) # unwrap UnaryExpression
if _belongs_to_alias(col, target_alias):
sql_hops.add(relationship_attr.key)
break
# If any filter expr touches this alias, mark it (best effort)
if relationship_attr.key not in sql_hops:
def _walk_cols(expr: Any):
# Primitive walker for ColumnElement trees
from sqlalchemy.sql.elements import ColumnElement
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _walk_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _walk_cols(ch)
for flt in (filters or []):
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
sql_hops.add(relationship_attr.key)
break
# 2) first-hop relationship names implied by dotted projection fields
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
# Root column projection
from sqlalchemy.orm import Load # local import to match your style
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path (avoid loader strategy conflicts)
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
name = relationship_attr.key
if name in sql_hops:
# Needed for WHERE/ORDER BY: join + hydrate from that join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif name in proj_hops:
# Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr))
else:
# Not needed
pass
# Apply projection loader options only if they won't conflict with contains_eager
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
# Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper
limit, _ = spec.parse_pagination()
if limit is None:
effective_limit = 50
elif limit == 0:
effective_limit = None
effective_limit = None # unlimited
else:
effective_limit = limit
@ -222,16 +322,17 @@ class CRUDService(Generic[T]):
query = query.limit(effective_limit)
items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested
if expanded_fields:
proj = list(expanded_fields)
if fields:
proj = list(dict.fromkeys(fields)) # dedupe, preserve order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else:
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
@ -248,7 +349,7 @@ class CRUDService(Generic[T]):
# Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None
# Optional total that's safe under JOINs (COUNT DISTINCT ids)
total = None
@ -257,10 +358,11 @@ class CRUDService(Generic[T]):
base = self._apply_not_deleted(base, root_alias, params)
if filters:
base = base.filter(*filters)
for _, relationship_attr, target_alias in join_paths: # reuse
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True)
# Mirror join structure for any SQL-needed relationships
for path, relationship_attr, target_alias in join_paths:
if relationship_attr.key in sql_hops:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery()
).scalar() or 0
@ -270,7 +372,6 @@ class CRUDService(Generic[T]):
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow(
items=items,
limit=window_limit_for_body,
@ -311,50 +412,81 @@ class CRUDService(Generic[T]):
return [*order_by, *pk_cols]
def get(self, id: int, params=None) -> T | None:
"""Fetch a single row by id with conflict-free eager loading and clean projection."""
query, root_alias = self.get_query()
include_deleted = False
root_fields = []
root_field_names = {}
rel_field_names = {}
# Defaults so we can build a projection even if params is None
root_fields: list[Any] = []
root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
# Soft-delete guard
query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias)
if params:
if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if self.supports_soft_delete and not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
# Optional extra filters (in addition to id); keep parity with list()
filters = spec.parse_filters()
if filters:
query = query.filter(*filters)
# Always filter by id
query = query.filter(getattr(root_alias, "id") == id)
# Includes + join paths we may need
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Field parsing to enable root load_only
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
req_fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
# Decide which relationship paths are needed for SQL vs display-only
# For get(), there is no ORDER BY; only filters might force SQL use.
sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
proj_paths = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path: avoid loader strategy conflicts
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed in WHERE: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
pass
# Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
obj = query.first()
if expanded_fields:
proj = list(expanded_fields)
# Emit exactly what the client requested (plus id), or a reasonable fallback
if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else:
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
@ -374,40 +506,60 @@ class CRUDService(Generic[T]):
return obj or None
def list(self, params=None) -> list[T]:
"""Offset/limit listing with smart relationship loading and clean projection."""
query, root_alias = self.get_query()
root_fields = []
root_field_names = {}
rel_field_names = {}
# Defaults so we can reference them later even if params is None
root_fields: list[Any] = []
root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
if params:
if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
limit, offset = spec.parse_pagination()
# Includes + join paths we might need
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Field parsing for load_only on root columns
root_fields, rel_field_names, root_field_names = spec.parse_fields()
if filters:
query = query.filter(*filters)
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
# Determine which relationship paths are needed for SQL vs display-only
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths)
proj_paths = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed for WHERE/ORDER BY: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
# Not needed at all; do nothing
pass
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias)
@ -415,27 +567,37 @@ class CRUDService(Generic[T]):
if order_by:
query = query.order_by(*order_by)
# Only apply offset/limit when not None.
# Only apply offset/limit when not None and not zero
if offset is not None and offset != 0:
query = query.offset(offset)
if limit is not None and limit > 0:
query = query.limit(limit)
req_fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
# Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
else:
# No params means no filters/sorts/limits; still honor projection loaders if any
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
rows = query.all()
if expanded_fields:
proj = list(expanded_fields)
# Emit exactly what the client requested (plus id), or a reasonable fallback
if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else:
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names: