More fixes.

This commit is contained in:
Yaro Kasear 2025-09-24 15:04:00 -05:00
parent 2a9fb389d7
commit c6165af40e
2 changed files with 155 additions and 81 deletions

View file

@ -2,12 +2,12 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec
@ -40,6 +40,25 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
T = TypeVar("T", bound=_CRUDModelProto) T = TypeVar("T", bound=_CRUDModelProto)
def _hops_from_sort(params: dict | None) -> set[str]:
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'."""
if not params:
return set()
raw = params.get("sort")
tokens: list[str] = []
if isinstance(raw, str):
tokens = [t.strip() for t in raw.split(",") if t.strip()]
elif isinstance(raw, (list, tuple)):
for item in raw:
if isinstance(item, str):
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
hops: set[str] = set()
for tok in tokens:
tok = tok.lstrip("+-")
if "." in tok:
hops.add(tok.split(".", 1)[0])
return hops
def _belongs_to_alias(col: Any, alias: Any) -> bool: def _belongs_to_alias(col: Any, alias: Any) -> bool:
# Try to detect if a column/expression ultimately comes from this alias. # Try to detect if a column/expression ultimately comes from this alias.
# Works for most ORM columns; complex expressions may need more. # Works for most ORM columns; complex expressions may need more.
@ -47,14 +66,15 @@ def _belongs_to_alias(col: Any, alias: Any) -> bool:
selectable = getattr(alias, "selectable", None) selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable return t is not None and selectable is not None and t is selectable
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]: def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
hops: set[str] = set()
paths: set[tuple[str, ...]] = set() paths: set[tuple[str, ...]] = set()
# Sort columns # Sort columns
for ob in order_by or []: for ob in order_by or []:
col = getattr(ob, "element", ob) # unwrap UnaryExpression col = getattr(ob, "element", ob) # unwrap UnaryExpression
for path, _rel_attr, target_alias in join_paths: for _path, rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias): if _belongs_to_alias(col, target_alias):
paths.add(tuple(path)) hops.add(rel_attr.key)
# Filter columns (best-effort) # Filter columns (best-effort)
# Walk simple binary expressions # Walk simple binary expressions
def _extract_cols(expr: Any) -> Iterable[Any]: def _extract_cols(expr: Any) -> Iterable[Any]:
@ -68,18 +88,18 @@ def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_
for flt in filters or []: for flt in filters or []:
for col in _extract_cols(flt): for col in _extract_cols(flt):
for path, _rel_attr, target_alias in join_paths: for _path, rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias): if _belongs_to_alias(col, target_alias):
paths.add(tuple[path]) hops.add(rel_attr.key)
return paths return hops
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]: def _paths_from_fields(req_fields: list[str]) -> set[str]:
out: set[tuple[str, ...]] = set() out: set[str] = set()
for f in req_fields: for f in req_fields:
if "." in f: if "." in f:
parts = tuple(f.split(".")[:-1]) parent = f.split(".", 1)[0]
if parts: if parent:
out.add(parts) out.add(parent)
return out return out
def _is_truthy(val): def _is_truthy(val):
@ -230,50 +250,24 @@ class CRUDService(Generic[T]):
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias) join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
# Decide which relationship *names* are required for SQL (filters/sort) vs display-only # Relationship names required by ORDER BY / WHERE
def _belongs_to_alias(col: Any, alias: Any) -> bool: sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
t = getattr(col, "table", None) # Also include relationships mentioned directly in the sort spec
selectable = getattr(alias, "selectable", None) sql_hops |= _hops_from_sort(params)
return t is not None and selectable is not None and t is selectable
# 1) which relationship aliases are referenced by sort/filter # First-hop relationship names implied by dotted projection fields
sql_hops: set[str] = set() proj_hops: set[str] = _paths_from_fields(fields)
for path, relationship_attr, target_alias in join_paths:
# If any ORDER BY column comes from this alias, mark it
for ob in (order_by or []):
col = getattr(ob, "element", ob) # unwrap UnaryExpression
if _belongs_to_alias(col, target_alias):
sql_hops.add(relationship_attr.key)
break
# If any filter expr touches this alias, mark it (best effort)
if relationship_attr.key not in sql_hops:
def _walk_cols(expr: Any):
# Primitive walker for ColumnElement trees
from sqlalchemy.sql.elements import ColumnElement
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _walk_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _walk_cols(ch)
for flt in (filters or []):
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
sql_hops.add(relationship_attr.key)
break
# 2) first-hop relationship names implied by dotted projection fields
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
# Root column projection # Root column projection
from sqlalchemy.orm import Load # local import to match your style
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path (avoid loader strategy conflicts) # Relationship handling per path (avoid loader strategy conflicts)
used_contains_eager = False used_contains_eager = False
for path, relationship_attr, target_alias in join_paths: joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
name = relationship_attr.key name = relationship_attr.key
if name in sql_hops: if name in sql_hops:
@ -281,12 +275,20 @@ class CRUDService(Generic[T]):
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
joined_names.add(name)
elif name in proj_hops: elif name in proj_hops:
# Display-only: bulk-load efficiently, no join # Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr)) query = query.options(selectinload(rel_attr))
else: joined_names.add(name)
# Not needed
pass # Force-join any SQL-needed relationships that weren't in join_paths
missing_sql = sql_hops - joined_names
for name in missing_sql:
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# Apply projection loader options only if they won't conflict with contains_eager # Apply projection loader options only if they won't conflict with contains_eager
if proj_opts and not used_contains_eager: if proj_opts and not used_contains_eager:
@ -348,8 +350,43 @@ class CRUDService(Generic[T]):
pass pass
# Boundary keys for cursor encoding in the API layer # Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None # When ORDER BY includes related columns (e.g., owner.first_name),
last_key = self._pluck_key(items[-1], order_spec) if items else None # pluck values from the related object we hydrated with contains_eager/selectinload.
def _pluck_key_from_obj(obj: Any) -> list[Any]:
vals: list[Any] = []
# Build a quick map: selectable -> relationship name
alias_to_rel: dict[Any, str] = {}
for _p, relationship_attr, target_alias in join_paths:
sel = getattr(target_alias, "selectable", None)
if sel is not None:
alias_to_rel[sel] = relationship_attr.key
for col in order_spec.cols:
key = getattr(col, "key", None) or getattr(col, "name", None)
# Try root attribute first
if key and hasattr(obj, key):
vals.append(getattr(obj, key))
continue
# Try relationship hop by matching the column's table/selectable
table = getattr(col, "table", None)
relname = alias_to_rel.get(table)
if relname and key:
relobj = getattr(obj, relname, None)
if relobj is not None and hasattr(relobj, key):
vals.append(getattr(relobj, key))
continue
# Give up: unsupported expression for cursor purposes
raise ValueError("unpluckable")
return vals
try:
first_key = _pluck_key_from_obj(items[0]) if items else None
last_key = _pluck_key_from_obj(items[-1]) if items else None
except Exception:
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
# disable cursors for this response rather than exploding.
first_key = None
last_key = None
# Optional total thats safe under JOINs (COUNT DISTINCT ids) # Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None total = None
@ -359,10 +396,15 @@ class CRUDService(Generic[T]):
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
# Mirror join structure for any SQL-needed relationships # Mirror join structure for any SQL-needed relationships
for path, relationship_attr, target_alias in join_paths: for _path, relationship_attr, target_alias in join_paths:
if relationship_attr.key in sql_hops: if relationship_attr.key in sql_hops:
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
# Also mirror any forced joins
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
base = base.join(rel_attr, isouter=True)
total = session.query(func.count()).select_from( total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery() base.order_by(None).distinct().subquery()
).scalar() or 0 ).scalar() or 0
@ -444,8 +486,8 @@ class CRUDService(Generic[T]):
# Decide which relationship paths are needed for SQL vs display-only # Decide which relationship paths are needed for SQL vs display-only
# For get(), there is no ORDER BY; only filters might force SQL use. # For get(), there is no ORDER BY; only filters might force SQL use.
sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
proj_paths = _paths_from_fields(req_fields) proj_hops = _paths_from_fields(req_fields)
# Root column projection # Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
@ -454,15 +496,15 @@ class CRUDService(Generic[T]):
# Relationship handling per path: avoid loader strategy conflicts # Relationship handling per path: avoid loader strategy conflicts
used_contains_eager = False used_contains_eager = False
for path, relationship_attr, target_alias in join_paths: for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path) name = relationship_attr.key
if ptuple in sql_paths: if name in sql_hops:
# Needed in WHERE: join + hydrate from the join # Needed in WHERE: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
elif ptuple in proj_paths: elif name in proj_hops:
# Display-only: bulk-load efficiently # Display-only: bulk-load efficiently
query = query.options(selectinload(rel_attr)) query = query.options(selectinload(rel_attr))
else: else:
@ -534,8 +576,9 @@ class CRUDService(Generic[T]):
query = query.filter(*filters) query = query.filter(*filters)
# Determine which relationship paths are needed for SQL vs display-only # Determine which relationship paths are needed for SQL vs display-only
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths) sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
proj_paths = _paths_from_fields(req_fields) sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
proj_hops = _paths_from_fields(req_fields)
# Root column projection # Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
@ -544,20 +587,30 @@ class CRUDService(Generic[T]):
# Relationship handling per path # Relationship handling per path
used_contains_eager = False used_contains_eager = False
for path, relationship_attr, target_alias in join_paths: joined_names: set[str] = set()
for _path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path) name = relationship_attr.key
if ptuple in sql_paths: if name in sql_hops:
# Needed for WHERE/ORDER BY: join + hydrate from the join # Needed for WHERE/ORDER BY: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
elif ptuple in proj_paths: joined_names.add(name)
elif name in proj_hops:
# Display-only: no join, bulk-load efficiently # Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr)) query = query.options(selectinload(rel_attr))
else: joined_names.add(name)
# Not needed at all; do nothing
pass # Force-join any SQL-needed relationships that weren't in join_paths
missing_sql = sql_hops - joined_names
for name in missing_sql:
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
query = query.join(rel_attr, isouter=True)
query = query.options(contains_eager(rel_attr))
used_contains_eager = True
joined_names.add(name)
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) # MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
paginating = (limit is not None) or (offset is not None and offset != 0) paginating = (limit is not None) or (offset is not None and offset != 0)
@ -617,6 +670,7 @@ class CRUDService(Generic[T]):
return rows return rows
def create(self, data: dict, actor=None) -> T: def create(self, data: dict, actor=None) -> T:
session = self.session session = self.session
obj = self.model(**data) obj = self.model(**data)
@ -627,7 +681,7 @@ class CRUDService(Generic[T]):
def update(self, id: int, data: dict, actor=None) -> T: def update(self, id: int, data: dict, actor=None) -> T:
session = self.session session = self.session
obj = self.get(id) obj = session.get(self.model, id)
if not obj: if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.") raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns} valid_fields = {c.name for c in self.model.__table__.columns}

View file

@ -17,6 +17,14 @@ def init_listing_routes(app):
if cls is None: if cls is None:
abort(404) abort(404)
# read query args
limit = int(request.args.get("limit", 15))
sort = request.args.get("sort") # <- capture sort from URL
cursor = request.args.get("cursor")
# your decode returns (key, _desc, backward) in this project
key, _desc, backward = decode_cursor(cursor)
# base spec per model
spec = {} spec = {}
columns = [] columns = []
row_classes = [] row_classes = []
@ -42,7 +50,8 @@ def init_listing_routes(app):
{"field": "model"}, {"field": "model"},
{"field": "device_type.description", "label": "Device Type"}, {"field": "device_type.description", "label": "Device Type"},
{"field": "condition"}, {"field": "condition"},
{"field": "owner.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}}, {"field": "owner.label", "label": "Contact",
"link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}},
{"field": "location.label", "label": "Room"}, {"field": "location.label", "label": "Room"},
] ]
elif model.lower() == 'user': elif model.lower() == 'user':
@ -54,12 +63,13 @@ def init_listing_routes(app):
"robot.overlord", "robot.overlord",
"staff", "staff",
"active", "active",
], "sort": "first_name,last_name"} ], "sort": "first_name,last_name"} # default for users
columns = [ columns = [
{"field": "label", "label": "Full Name"}, {"field": "label", "label": "Full Name"},
{"field": "last_name"}, {"field": "last_name"},
{"field": "first_name"}, {"field": "first_name"},
{"field": "supervisor.label", "label": "Supervisor", "link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}}, {"field": "supervisor.label", "label": "Supervisor",
"link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}},
{"field": "staff", "format": "yesno"}, {"field": "staff", "format": "yesno"},
{"field": "active", "format": "yesno"}, {"field": "active", "format": "yesno"},
] ]
@ -79,8 +89,10 @@ def init_listing_routes(app):
"complete", "complete",
]} ]}
columns = [ columns = [
{"field": "work_item.label", "label": "Work Item", "link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}}, {"field": "work_item.label", "label": "Work Item",
{"field": "contact.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}}, "link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}},
{"field": "contact.label", "label": "Contact",
"link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}},
{"field": "start_time", "format": "datetime"}, {"field": "start_time", "format": "datetime"},
{"field": "end_time", "format": "datetime"}, {"field": "end_time", "format": "datetime"},
{"field": "complete", "format": "yesno"}, {"field": "complete", "format": "yesno"},
@ -89,19 +101,27 @@ def init_listing_routes(app):
{"when": {"field": "complete", "is": True}, "class": "table-success"}, {"when": {"field": "complete", "is": True}, "class": "table-success"},
{"when": {"field": "complete", "is": False}, "class": "table-danger"} {"when": {"field": "complete", "is": False}, "class": "table-danger"}
] ]
limit = int(request.args.get("limit", 15))
cursor = request.args.get("cursor") # overlay URL-provided sort if present
key, _desc, backward = decode_cursor(cursor) if sort:
spec["sort"] = sort
service = crudkit.crud.get_service(cls) service = crudkit.crud.get_service(cls)
# include limit and go
window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True) window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)
table = render_table(window.items, columns=columns, opts={"object_class": model, "row_classes": row_classes}) table = render_table(window.items, columns=columns,
return render_template("listing.html", model=model, table=table, pagination={ opts={"object_class": model, "row_classes": row_classes})
# pass sort through so templates can preserve it on pager links, if they care
pagination_ctx = {
"limit": window.limit, "limit": window.limit,
"total": window.total, "total": window.total,
"next_cursor": encode_cursor(window.last_key, list(window.order.desc), backward=False), "next_cursor": encode_cursor(window.last_key, list(window.order.desc), backward=False),
"prev_cursor": encode_cursor(window.first_key, list(window.order.desc), backward=True), "prev_cursor": encode_cursor(window.first_key, list(window.order.desc), backward=True),
}) "sort": sort or spec.get("sort") # expose current sort to the template
}
return render_template("listing.html", model=model, table=table, pagination=pagination_ctx)
app.register_blueprint(bp_listing) app.register_blueprint(bp_listing)