CRUDService overhaul.

This commit is contained in:
Yaro Kasear 2025-09-24 14:19:01 -05:00
parent a0ee1caeb7
commit 2a9fb389d7

View file

@ -1,10 +1,13 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, RelationshipProperty from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec
@ -15,13 +18,6 @@ from crudkit.projection import compile_projection
import logging import logging
log = logging.getLogger("crudkit.service") log = logging.getLogger("crudkit.service")
def _is_rel(model_cls, name: str) -> bool:
try:
prop = model_cls.__mapper__.relationships.get(name)
return isinstance(prop, RelationshipProperty)
except Exception:
return False
@runtime_checkable @runtime_checkable
class _HasID(Protocol): class _HasID(Protocol):
id: int id: int
@ -44,9 +40,65 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
T = TypeVar("T", bound=_CRUDModelProto) T = TypeVar("T", bound=_CRUDModelProto)
def _belongs_to_alias(col: Any, alias: Any) -> bool:
# Try to detect if a column/expression ultimately comes from this alias.
# Works for most ORM columns; complex expressions may need more.
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]:
paths: set[tuple[str, ...]] = set()
# Sort columns
for ob in order_by or []:
col = getattr(ob, "element", ob) # unwrap UnaryExpression
for path, _rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple(path))
# Filter columns (best-effort)
# Walk simple binary expressions
def _extract_cols(expr: Any) -> Iterable[Any]:
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _extract_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _extract_cols(ch)
for flt in filters or []:
for col in _extract_cols(flt):
for path, _rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple[path])
return paths
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
out: set[tuple[str, ...]] = set()
for f in req_fields:
if "." in f:
parts = tuple(f.split(".")[:-1])
if parts:
out.add(parts)
return out
def _is_truthy(val): def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on') return str(val).lower() in ('1', 'true', 'yes', 'on')
def _normalize_fields_param(params: dict | None) -> list[str]:
if not params:
return []
raw = params.get("fields")
if isinstance(raw, (list, tuple)):
out: list[str] = []
for item in raw:
if isinstance(item, str):
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
return out
if isinstance(raw, str):
return [p for p in (s.strip() for s in raw.split(",")) if p]
return []
class CRUDService(Generic[T]): class CRUDService(Generic[T]):
def __init__( def __init__(
self, self,
@ -86,8 +138,6 @@ class CRUDService(Generic[T]):
Normalize order_by into (cols, desc_flags). Supports plain columns and Normalize order_by into (cols, desc_flags). Supports plain columns and
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
""" """
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
given = self._stable_order_by(root_alias, given_order_by) given = self._stable_order_by(root_alias, given_order_by)
@ -109,7 +159,6 @@ class CRUDService(Generic[T]):
cols.append(col) cols.append(col)
desc_flags.append(bool(is_desc)) desc_flags.append(bool(is_desc))
from crudkit.core.types import OrderSpec
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
@ -156,49 +205,100 @@ class CRUDService(Generic[T]):
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
""" """
session = self.session session = self.session
fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
query, root_alias = self.get_query() query, root_alias = self.get_query()
if proj_opts:
query = query.options(*proj_opts) # Normalize requested fields and compile projection (may skip later to avoid conflicts)
fields = _normalize_fields_param(params)
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
# Field parsing for root load_only fallback
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter # Soft delete filter
# if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
# query = query.filter(getattr(root_alias, "is_deleted") == False)
query = self._apply_not_deleted(query, root_alias, params) query = self._apply_not_deleted(query, root_alias, params)
# Parse filters first # Apply filters first
if filters: if filters:
query = query.filter(*filters) query = query.filter(*filters)
# Includes + joins (so relationship fields like brand.name, location.label work) # Includes + join paths (dotted fields etc.)
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
for _, relationship_attr, target_alias in spec.get_join_paths(): # Decide which relationship *names* are required for SQL (filters/sort) vs display-only
rel_attr = cast(InstrumentedAttribute, relationship_attr) def _belongs_to_alias(col: Any, alias: Any) -> bool:
target = cast(Any, target_alias) t = getattr(col, "table", None)
query = query.join(target, rel_attr.of_type(target), isouter=True) selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
# Fields/projection: load_only for root columns, eager loads for relationships # 1) which relationship aliases are referenced by sort/filter
sql_hops: set[str] = set()
for path, relationship_attr, target_alias in join_paths:
# If any ORDER BY column comes from this alias, mark it
for ob in (order_by or []):
col = getattr(ob, "element", ob) # unwrap UnaryExpression
if _belongs_to_alias(col, target_alias):
sql_hops.add(relationship_attr.key)
break
# If any filter expr touches this alias, mark it (best effort)
if relationship_attr.key not in sql_hops:
def _walk_cols(expr: Any):
# Primitive walker for ColumnElement trees
from sqlalchemy.sql.elements import ColumnElement
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _walk_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _walk_cols(ch)
for flt in (filters or []):
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
sql_hops.add(relationship_attr.key)
break
# 2) first-hop relationship names implied by dotted projection fields
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
# Root column projection
from sqlalchemy.orm import Load # local import to match your style
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path (avoid loader strategy conflicts)
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
name = relationship_attr.key
if name in sql_hops:
# Needed for WHERE/ORDER BY: join + hydrate from that join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif name in proj_hops:
# Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr))
else:
# Not needed
pass
# Apply projection loader options only if they won't conflict with contains_eager
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
# Order + limit # Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper
limit, _ = spec.parse_pagination() limit, _ = spec.parse_pagination()
if limit is None: if limit is None:
effective_limit = 50 effective_limit = 50
elif limit == 0: elif limit == 0:
effective_limit = None effective_limit = None # unlimited
else: else:
effective_limit = limit effective_limit = limit
@ -222,16 +322,17 @@ class CRUDService(Generic[T]):
query = query.limit(effective_limit) query = query.limit(effective_limit)
items = list(reversed(query.all())) items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested # Tag projection so your renderer knows what fields were requested
if expanded_fields: if fields:
proj = list(expanded_fields) proj = list(dict.fromkeys(fields)) # dedupe, preserve order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else: else:
proj = [] proj = []
if root_field_names: if root_field_names:
proj.extend(root_field_names) proj.extend(root_field_names)
if root_fields: if root_fields:
proj.extend(c.key for c in root_fields) proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items(): for path, names in (rel_field_names or {}).items():
prefix = ".".join(path) prefix = ".".join(path)
for n in names: for n in names:
@ -257,10 +358,11 @@ class CRUDService(Generic[T]):
base = self._apply_not_deleted(base, root_alias, params) base = self._apply_not_deleted(base, root_alias, params)
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
for _, relationship_attr, target_alias in join_paths: # reuse # Mirror join structure for any SQL-needed relationships
for path, relationship_attr, target_alias in join_paths:
if relationship_attr.key in sql_hops:
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
base = base.join(target, rel_attr.of_type(target), isouter=True)
total = session.query(func.count()).select_from( total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery() base.order_by(None).distinct().subquery()
).scalar() or 0 ).scalar() or 0
@ -270,7 +372,6 @@ class CRUDService(Generic[T]):
if log.isEnabledFor(logging.DEBUG): if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query)) log.debug("QUERY: %s", str(query))
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow( return SeekWindow(
items=items, items=items,
limit=window_limit_for_body, limit=window_limit_for_body,
@ -311,50 +412,81 @@ class CRUDService(Generic[T]):
return [*order_by, *pk_cols] return [*order_by, *pk_cols]
def get(self, id: int, params=None) -> T | None: def get(self, id: int, params=None) -> T | None:
"""Fetch a single row by id with conflict-free eager loading and clean projection."""
query, root_alias = self.get_query() query, root_alias = self.get_query()
include_deleted = False # Defaults so we can build a projection even if params is None
root_fields = [] root_fields: list[Any] = []
root_field_names = {} root_field_names: dict[str, str] = {}
rel_field_names = {} rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
# Soft-delete guard
query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
if params:
if self.supports_soft_delete: # Optional extra filters (in addition to id); keep parity with list()
include_deleted = _is_truthy(params.get('include_deleted')) filters = spec.parse_filters()
if self.supports_soft_delete and not include_deleted: if filters:
query = query.filter(getattr(root_alias, "is_deleted") == False) query = query.filter(*filters)
# Always filter by id
query = query.filter(getattr(root_alias, "id") == id) query = query.filter(getattr(root_alias, "id") == id)
# Includes + join paths we may need
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths(): # Field parsing to enable root load_only
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params: if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
req_fields = list((params or {}).get("fields", [])) # Decide which relationship paths are needed for SQL vs display-only
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) # For get(), there is no ORDER BY; only filters might force SQL use.
if proj_opts: sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
query = query.options(*proj_opts) proj_paths = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path: avoid loader strategy conflicts
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed in WHERE: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
pass
# Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
obj = query.first() obj = query.first()
if expanded_fields: # Emit exactly what the client requested (plus id), or a reasonable fallback
proj = list(expanded_fields) if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else: else:
proj = [] proj = []
if root_field_names: if root_field_names:
proj.extend(root_field_names) proj.extend(root_field_names)
if root_fields: if root_fields:
proj.extend(c.key for c in root_fields) proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items(): for path, names in (rel_field_names or {}).items():
prefix = ".".join(path) prefix = ".".join(path)
for n in names: for n in names:
@ -374,40 +506,60 @@ class CRUDService(Generic[T]):
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
"""Offset/limit listing with smart relationship loading and clean projection."""
query, root_alias = self.get_query() query, root_alias = self.get_query()
root_fields = [] # Defaults so we can reference them later even if params is None
root_field_names = {} root_fields: list[Any] = []
rel_field_names = {} root_field_names: dict[str, str] = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
if params: if params:
if self.supports_soft_delete: query = self._apply_not_deleted(query, root_alias, params)
include_deleted = _is_truthy(params.get('include_deleted'))
if not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
limit, offset = spec.parse_pagination() limit, offset = spec.parse_pagination()
# Includes + join paths we might need
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths(): # Field parsing for load_only on root columns
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
if filters: if filters:
query = query.filter(*filters) query = query.filter(*filters)
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset). # Determine which relationship paths are needed for SQL vs display-only
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths)
proj_paths = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed for WHERE/ORDER BY: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
# Not needed at all; do nothing
pass
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
paginating = (limit is not None) or (offset is not None and offset != 0) paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset: if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias) order_by = self._default_order_by(root_alias)
@ -415,27 +567,37 @@ class CRUDService(Generic[T]):
if order_by: if order_by:
query = query.order_by(*order_by) query = query.order_by(*order_by)
# Only apply offset/limit when not None. # Only apply offset/limit when not None and not zero
if offset is not None and offset != 0: if offset is not None and offset != 0:
query = query.offset(offset) query = query.offset(offset)
if limit is not None and limit > 0: if limit is not None and limit > 0:
query = query.limit(limit) query = query.limit(limit)
req_fields = list((params or {}).get("fields", [])) # Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
else:
# No params means no filters/sorts/limits; still honor projection loaders if any
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts: if proj_opts:
query = query.options(*proj_opts) query = query.options(*proj_opts)
rows = query.all() rows = query.all()
if expanded_fields: # Emit exactly what the client requested (plus id), or a reasonable fallback
proj = list(expanded_fields) if req_fields:
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order
if "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
else: else:
proj = [] proj = []
if root_field_names: if root_field_names:
proj.extend(root_field_names) proj.extend(root_field_names)
if root_fields: if root_fields:
proj.extend(c.key for c in root_fields) proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items(): for path, names in (rel_field_names or {}).items():
prefix = ".".join(path) prefix = ".".join(path)
for n in names: for n in names: