More fixes.

This commit is contained in:
Yaro Kasear 2025-09-24 15:04:00 -05:00
parent 2a9fb389d7
commit c6165af40e
2 changed files with 155 additions and 81 deletions

View file

@ -2,12 +2,12 @@ from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
@ -40,6 +40,25 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
T = TypeVar("T", bound=_CRUDModelProto)
def _hops_from_sort(params: dict | None) -> set[str]:
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'."""
if not params:
return set()
raw = params.get("sort")
tokens: list[str] = []
if isinstance(raw, str):
tokens = [t.strip() for t in raw.split(",") if t.strip()]
elif isinstance(raw, (list, tuple)):
for item in raw:
if isinstance(item, str):
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
hops: set[str] = set()
for tok in tokens:
tok = tok.lstrip("+-")
if "." in tok:
hops.add(tok.split(".", 1)[0])
return hops
def _belongs_to_alias(col: Any, alias: Any) -> bool:
# Try to detect if a column/expression ultimately comes from this alias.
# Works for most ORM columns; complex expressions may need more.
@ -47,14 +66,15 @@ def _belongs_to_alias(col: Any, alias: Any) -> bool:
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]:
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
hops: set[str] = set()
paths: set[tuple[str, ...]] = set()
# Sort columns
for ob in order_by or []:
col = getattr(ob, "element", ob) # unwrap UnaryExpression
for path, _rel_attr, target_alias in join_paths:
for _path, rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple(path))
hops.add(rel_attr.key)
# Filter columns (best-effort)
# Walk simple binary expressions
def _extract_cols(expr: Any) -> Iterable[Any]:
@ -68,18 +88,18 @@ def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_
for flt in filters or []:
for col in _extract_cols(flt):
for path, _rel_attr, target_alias in join_paths:
for _path, rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple[path])
return paths
hops.add(rel_attr.key)
return hops
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
out: set[tuple[str, ...]] = set()
def _paths_from_fields(req_fields: list[str]) -> set[str]:
out: set[str] = set()
for f in req_fields:
if "." in f:
parts = tuple(f.split(".")[:-1])
if parts:
out.add(parts)
parent = f.split(".", 1)[0]
if parent:
out.add(parent)
return out
def _is_truthy(val):
@ -230,50 +250,24 @@ class CRUDService(Generic[T]):
spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
# Decide which relationship *names* are required for SQL (filters/sort) vs display-only
def _belongs_to_alias(col: Any, alias: Any) -> bool:
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
# Relationship names required by ORDER BY / WHERE
sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
# Also include relationships mentioned directly in the sort spec
sql_hops |= _hops_from_sort(params)
# 1) which relationship aliases are referenced by sort/filter
sql_hops: set[str] = set()
for path, relationship_attr, target_alias in join_paths:
# If any ORDER BY column comes from this alias, mark it
for ob in (order_by or []):
col = getattr(ob, "element", ob) # unwrap UnaryExpression
if _belongs_to_alias(col, target_alias):
sql_hops.add(relationship_attr.key)
break
# If any filter expr touches this alias, mark it (best effort)
if relationship_attr.key not in sql_hops:
def _walk_cols(expr: Any):
# Primitive walker for ColumnElement trees
from sqlalchemy.sql.elements import ColumnElement
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _walk_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _walk_cols(ch)
for flt in (filters or []):
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
sql_hops.add(relationship_attr.key)
break
# 2) first-hop relationship names implied by dotted projection fields
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
# First-hop relationship names implied by dotted projection fields
proj_hops: set[str] = _paths_from_fields(fields)
# Root column projection
from sqlalchemy.orm import Load # local import to match your style
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path (avoid loader strategy conflicts)
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
name = relationship_attr.key
if name in sql_hops:
@ -281,12 +275,20 @@ class CRUDService(Generic[T]):
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
joined_names.add(name)
elif name in proj_hops:
# Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr))
else:
# Not needed
pass
joined_names.add(name)
# Force-join any SQL-needed relationships that weren't in join_paths
missing_sql = sql_hops - joined_names
for name in missing_sql:
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# Apply projection loader options only if they won't conflict with contains_eager
if proj_opts and not used_contains_eager:
@ -348,8 +350,43 @@ class CRUDService(Generic[T]):
pass
# Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None
# When ORDER BY includes related columns (e.g., owner.first_name),
# pluck values from the related object we hydrated with contains_eager/selectinload.
def _pluck_key_from_obj(obj: Any) -> list[Any]:
vals: list[Any] = []
# Build a quick map: selectable -> relationship name
alias_to_rel: dict[Any, str] = {}
for _p, relationship_attr, target_alias in join_paths:
sel = getattr(target_alias, "selectable", None)
if sel is not None:
alias_to_rel[sel] = relationship_attr.key
for col in order_spec.cols:
key = getattr(col, "key", None) or getattr(col, "name", None)
# Try root attribute first
if key and hasattr(obj, key):
vals.append(getattr(obj, key))
continue
# Try relationship hop by matching the column's table/selectable
table = getattr(col, "table", None)
relname = alias_to_rel.get(table)
if relname and key:
relobj = getattr(obj, relname, None)
if relobj is not None and hasattr(relobj, key):
vals.append(getattr(relobj, key))
continue
# Give up: unsupported expression for cursor purposes
raise ValueError("unpluckable")
return vals
try:
first_key = _pluck_key_from_obj(items[0]) if items else None
last_key = _pluck_key_from_obj(items[-1]) if items else None
except Exception:
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
# disable cursors for this response rather than exploding.
first_key = None
last_key = None
# Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None
@ -359,10 +396,15 @@ class CRUDService(Generic[T]):
if filters:
base = base.filter(*filters)
# Mirror join structure for any SQL-needed relationships
for path, relationship_attr, target_alias in join_paths:
for _path, relationship_attr, target_alias in join_paths:
if relationship_attr.key in sql_hops:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
# Also mirror any forced joins
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
base = base.join(rel_attr, isouter=True)
total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery()
).scalar() or 0
@ -444,8 +486,8 @@ class CRUDService(Generic[T]):
# Decide which relationship paths are needed for SQL vs display-only
# For get(), there is no ORDER BY; only filters might force SQL use.
sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
proj_paths = _paths_from_fields(req_fields)
sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
proj_hops = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
@ -454,15 +496,15 @@ class CRUDService(Generic[T]):
# Relationship handling per path: avoid loader strategy conflicts
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
name = relationship_attr.key
if name in sql_hops:
# Needed in WHERE: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
elif name in proj_hops:
# Display-only: bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
@ -534,8 +576,9 @@ class CRUDService(Generic[T]):
query = query.filter(*filters)
# Determine which relationship paths are needed for SQL vs display-only
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths)
proj_paths = _paths_from_fields(req_fields)
sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
proj_hops = _paths_from_fields(req_fields)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
@ -544,20 +587,30 @@ class CRUDService(Generic[T]):
# Relationship handling per path
used_contains_eager = False
for path, relationship_attr, target_alias in join_paths:
joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
name = relationship_attr.key
if name in sql_hops:
# Needed for WHERE/ORDER BY: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
joined_names.add(name)
elif name in proj_hops:
# Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
# Not needed at all; do nothing
pass
joined_names.add(name)
# Force-join any SQL-needed relationships that weren't in join_paths
missing_sql = sql_hops - joined_names
for name in missing_sql:
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
paginating = (limit is not None) or (offset is not None and offset != 0)
@ -617,6 +670,7 @@ class CRUDService(Generic[T]):
return rows
def create(self, data: dict, actor=None) -> T:
session = self.session
obj = self.model(**data)
@ -627,7 +681,7 @@ class CRUDService(Generic[T]):
def update(self, id: int, data: dict, actor=None) -> T:
session = self.session
obj = self.get(id)
obj = session.get(self.model, id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns}