More fixes.
This commit is contained in:
parent
2a9fb389d7
commit
c6165af40e
2 changed files with 155 additions and 81 deletions
|
|
@ -2,12 +2,12 @@ from __future__ import annotations
|
|||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||
from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
|
||||
from sqlalchemy import and_, func, inspect, or_, text
|
||||
from sqlalchemy.engine import Engine, Connection
|
||||
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||||
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
|
|
@ -40,6 +40,25 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
|
|||
|
||||
T = TypeVar("T", bound=_CRUDModelProto)
|
||||
|
||||
def _hops_from_sort(params: dict | None) -> set[str]:
|
||||
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'."""
|
||||
if not params:
|
||||
return set()
|
||||
raw = params.get("sort")
|
||||
tokens: list[str] = []
|
||||
if isinstance(raw, str):
|
||||
tokens = [t.strip() for t in raw.split(",") if t.strip()]
|
||||
elif isinstance(raw, (list, tuple)):
|
||||
for item in raw:
|
||||
if isinstance(item, str):
|
||||
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
hops: set[str] = set()
|
||||
for tok in tokens:
|
||||
tok = tok.lstrip("+-")
|
||||
if "." in tok:
|
||||
hops.add(tok.split(".", 1)[0])
|
||||
return hops
|
||||
|
||||
def _belongs_to_alias(col: Any, alias: Any) -> bool:
|
||||
# Try to detect if a column/expression ultimately comes from this alias.
|
||||
# Works for most ORM columns; complex expressions may need more.
|
||||
|
|
@ -47,14 +66,15 @@ def _belongs_to_alias(col: Any, alias: Any) -> bool:
|
|||
selectable = getattr(alias, "selectable", None)
|
||||
return t is not None and selectable is not None and t is selectable
|
||||
|
||||
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]:
|
||||
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
|
||||
hops: set[str] = set()
|
||||
paths: set[tuple[str, ...]] = set()
|
||||
# Sort columns
|
||||
for ob in order_by or []:
|
||||
col = getattr(ob, "element", ob) # unwrap UnaryExpression
|
||||
for path, _rel_attr, target_alias in join_paths:
|
||||
for _path, rel_attr, target_alias in join_paths:
|
||||
if _belongs_to_alias(col, target_alias):
|
||||
paths.add(tuple(path))
|
||||
hops.add(rel_attr.key)
|
||||
# Filter columns (best-effort)
|
||||
# Walk simple binary expressions
|
||||
def _extract_cols(expr: Any) -> Iterable[Any]:
|
||||
|
|
@ -68,18 +88,18 @@ def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_
|
|||
|
||||
for flt in filters or []:
|
||||
for col in _extract_cols(flt):
|
||||
for path, _rel_attr, target_alias in join_paths:
|
||||
for _path, rel_attr, target_alias in join_paths:
|
||||
if _belongs_to_alias(col, target_alias):
|
||||
paths.add(tuple[path])
|
||||
return paths
|
||||
hops.add(rel_attr.key)
|
||||
return hops
|
||||
|
||||
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
|
||||
out: set[tuple[str, ...]] = set()
|
||||
def _paths_from_fields(req_fields: list[str]) -> set[str]:
|
||||
out: set[str] = set()
|
||||
for f in req_fields:
|
||||
if "." in f:
|
||||
parts = tuple(f.split(".")[:-1])
|
||||
if parts:
|
||||
out.add(parts)
|
||||
parent = f.split(".", 1)[0]
|
||||
if parent:
|
||||
out.add(parent)
|
||||
return out
|
||||
|
||||
def _is_truthy(val):
|
||||
|
|
@ -230,50 +250,24 @@ class CRUDService(Generic[T]):
|
|||
spec.parse_includes()
|
||||
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
|
||||
|
||||
# Decide which relationship *names* are required for SQL (filters/sort) vs display-only
|
||||
def _belongs_to_alias(col: Any, alias: Any) -> bool:
|
||||
t = getattr(col, "table", None)
|
||||
selectable = getattr(alias, "selectable", None)
|
||||
return t is not None and selectable is not None and t is selectable
|
||||
# Relationship names required by ORDER BY / WHERE
|
||||
sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
|
||||
# Also include relationships mentioned directly in the sort spec
|
||||
sql_hops |= _hops_from_sort(params)
|
||||
|
||||
# 1) which relationship aliases are referenced by sort/filter
|
||||
sql_hops: set[str] = set()
|
||||
for path, relationship_attr, target_alias in join_paths:
|
||||
# If any ORDER BY column comes from this alias, mark it
|
||||
for ob in (order_by or []):
|
||||
col = getattr(ob, "element", ob) # unwrap UnaryExpression
|
||||
if _belongs_to_alias(col, target_alias):
|
||||
sql_hops.add(relationship_attr.key)
|
||||
break
|
||||
# If any filter expr touches this alias, mark it (best effort)
|
||||
if relationship_attr.key not in sql_hops:
|
||||
def _walk_cols(expr: Any):
|
||||
# Primitive walker for ColumnElement trees
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
if isinstance(expr, ColumnElement):
|
||||
yield expr
|
||||
for ch in getattr(expr, "get_children", lambda: [])():
|
||||
yield from _walk_cols(ch)
|
||||
elif hasattr(expr, "clauses"):
|
||||
for ch in expr.clauses:
|
||||
yield from _walk_cols(ch)
|
||||
for flt in (filters or []):
|
||||
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
|
||||
sql_hops.add(relationship_attr.key)
|
||||
break
|
||||
|
||||
# 2) first-hop relationship names implied by dotted projection fields
|
||||
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
|
||||
# First-hop relationship names implied by dotted projection fields
|
||||
proj_hops: set[str] = _paths_from_fields(fields)
|
||||
|
||||
# Root column projection
|
||||
from sqlalchemy.orm import Load # local import to match your style
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
if only_cols:
|
||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||
|
||||
# Relationship handling per path (avoid loader strategy conflicts)
|
||||
used_contains_eager = False
|
||||
for path, relationship_attr, target_alias in join_paths:
|
||||
joined_names: set[str] = set()
|
||||
|
||||
for _path, relationship_attr, target_alias in join_paths:
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
name = relationship_attr.key
|
||||
if name in sql_hops:
|
||||
|
|
@ -281,12 +275,20 @@ class CRUDService(Generic[T]):
|
|||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||
used_contains_eager = True
|
||||
joined_names.add(name)
|
||||
elif name in proj_hops:
|
||||
# Display-only: bulk-load efficiently, no join
|
||||
query = query.options(selectinload(rel_attr))
|
||||
else:
|
||||
# Not needed
|
||||
pass
|
||||
joined_names.add(name)
|
||||
|
||||
# Force-join any SQL-needed relationships that weren't in join_paths
|
||||
missing_sql = sql_hops - joined_names
|
||||
for name in missing_sql:
|
||||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||
query = query.join(rel_attr, isouter=True)
|
||||
query = query.options(contains_eager(rel_attr))
|
||||
used_contains_eager = True
|
||||
joined_names.add(name)
|
||||
|
||||
# Apply projection loader options only if they won't conflict with contains_eager
|
||||
if proj_opts and not used_contains_eager:
|
||||
|
|
@ -348,8 +350,43 @@ class CRUDService(Generic[T]):
|
|||
pass
|
||||
|
||||
# Boundary keys for cursor encoding in the API layer
|
||||
first_key = self._pluck_key(items[0], order_spec) if items else None
|
||||
last_key = self._pluck_key(items[-1], order_spec) if items else None
|
||||
# When ORDER BY includes related columns (e.g., owner.first_name),
|
||||
# pluck values from the related object we hydrated with contains_eager/selectinload.
|
||||
def _pluck_key_from_obj(obj: Any) -> list[Any]:
|
||||
vals: list[Any] = []
|
||||
# Build a quick map: selectable -> relationship name
|
||||
alias_to_rel: dict[Any, str] = {}
|
||||
for _p, relationship_attr, target_alias in join_paths:
|
||||
sel = getattr(target_alias, "selectable", None)
|
||||
if sel is not None:
|
||||
alias_to_rel[sel] = relationship_attr.key
|
||||
|
||||
for col in order_spec.cols:
|
||||
key = getattr(col, "key", None) or getattr(col, "name", None)
|
||||
# Try root attribute first
|
||||
if key and hasattr(obj, key):
|
||||
vals.append(getattr(obj, key))
|
||||
continue
|
||||
# Try relationship hop by matching the column's table/selectable
|
||||
table = getattr(col, "table", None)
|
||||
relname = alias_to_rel.get(table)
|
||||
if relname and key:
|
||||
relobj = getattr(obj, relname, None)
|
||||
if relobj is not None and hasattr(relobj, key):
|
||||
vals.append(getattr(relobj, key))
|
||||
continue
|
||||
# Give up: unsupported expression for cursor purposes
|
||||
raise ValueError("unpluckable")
|
||||
return vals
|
||||
|
||||
try:
|
||||
first_key = _pluck_key_from_obj(items[0]) if items else None
|
||||
last_key = _pluck_key_from_obj(items[-1]) if items else None
|
||||
except Exception:
|
||||
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
|
||||
# disable cursors for this response rather than exploding.
|
||||
first_key = None
|
||||
last_key = None
|
||||
|
||||
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
||||
total = None
|
||||
|
|
@ -359,10 +396,15 @@ class CRUDService(Generic[T]):
|
|||
if filters:
|
||||
base = base.filter(*filters)
|
||||
# Mirror join structure for any SQL-needed relationships
|
||||
for path, relationship_attr, target_alias in join_paths:
|
||||
for _path, relationship_attr, target_alias in join_paths:
|
||||
if relationship_attr.key in sql_hops:
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||
# Also mirror any forced joins
|
||||
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
|
||||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||
base = base.join(rel_attr, isouter=True)
|
||||
|
||||
total = session.query(func.count()).select_from(
|
||||
base.order_by(None).distinct().subquery()
|
||||
).scalar() or 0
|
||||
|
|
@ -444,8 +486,8 @@ class CRUDService(Generic[T]):
|
|||
|
||||
# Decide which relationship paths are needed for SQL vs display-only
|
||||
# For get(), there is no ORDER BY; only filters might force SQL use.
|
||||
sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
|
||||
proj_paths = _paths_from_fields(req_fields)
|
||||
sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
|
||||
proj_hops = _paths_from_fields(req_fields)
|
||||
|
||||
# Root column projection
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
|
|
@ -454,15 +496,15 @@ class CRUDService(Generic[T]):
|
|||
|
||||
# Relationship handling per path: avoid loader strategy conflicts
|
||||
used_contains_eager = False
|
||||
for path, relationship_attr, target_alias in join_paths:
|
||||
for _path, relationship_attr, target_alias in join_paths:
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
ptuple = tuple(path)
|
||||
if ptuple in sql_paths:
|
||||
name = relationship_attr.key
|
||||
if name in sql_hops:
|
||||
# Needed in WHERE: join + hydrate from the join
|
||||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||
used_contains_eager = True
|
||||
elif ptuple in proj_paths:
|
||||
elif name in proj_hops:
|
||||
# Display-only: bulk-load efficiently
|
||||
query = query.options(selectinload(rel_attr))
|
||||
else:
|
||||
|
|
@ -534,8 +576,9 @@ class CRUDService(Generic[T]):
|
|||
query = query.filter(*filters)
|
||||
|
||||
# Determine which relationship paths are needed for SQL vs display-only
|
||||
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths)
|
||||
proj_paths = _paths_from_fields(req_fields)
|
||||
sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
|
||||
sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
|
||||
proj_hops = _paths_from_fields(req_fields)
|
||||
|
||||
# Root column projection
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
|
|
@ -544,20 +587,30 @@ class CRUDService(Generic[T]):
|
|||
|
||||
# Relationship handling per path
|
||||
used_contains_eager = False
|
||||
for path, relationship_attr, target_alias in join_paths:
|
||||
joined_names: set[str] = set()
|
||||
|
||||
for _path, relationship_attr, target_alias in join_paths:
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
ptuple = tuple(path)
|
||||
if ptuple in sql_paths:
|
||||
name = relationship_attr.key
|
||||
if name in sql_hops:
|
||||
# Needed for WHERE/ORDER BY: join + hydrate from the join
|
||||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||
used_contains_eager = True
|
||||
elif ptuple in proj_paths:
|
||||
joined_names.add(name)
|
||||
elif name in proj_hops:
|
||||
# Display-only: no join, bulk-load efficiently
|
||||
query = query.options(selectinload(rel_attr))
|
||||
else:
|
||||
# Not needed at all; do nothing
|
||||
pass
|
||||
joined_names.add(name)
|
||||
|
||||
# Force-join any SQL-needed relationships that weren't in join_paths
|
||||
missing_sql = sql_hops - joined_names
|
||||
for name in missing_sql:
|
||||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||
query = query.join(rel_attr, isouter=True)
|
||||
query = query.options(contains_eager(rel_attr))
|
||||
used_contains_eager = True
|
||||
joined_names.add(name)
|
||||
|
||||
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
|
||||
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||||
|
|
@ -617,6 +670,7 @@ class CRUDService(Generic[T]):
|
|||
|
||||
return rows
|
||||
|
||||
|
||||
def create(self, data: dict, actor=None) -> T:
|
||||
session = self.session
|
||||
obj = self.model(**data)
|
||||
|
|
@ -627,7 +681,7 @@ class CRUDService(Generic[T]):
|
|||
|
||||
def update(self, id: int, data: dict, actor=None) -> T:
|
||||
session = self.session
|
||||
obj = self.get(id)
|
||||
obj = session.get(self.model, id)
|
||||
if not obj:
|
||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,14 @@ def init_listing_routes(app):
|
|||
if cls is None:
|
||||
abort(404)
|
||||
|
||||
# read query args
|
||||
limit = int(request.args.get("limit", 15))
|
||||
sort = request.args.get("sort") # <- capture sort from URL
|
||||
cursor = request.args.get("cursor")
|
||||
# your decode returns (key, _desc, backward) in this project
|
||||
key, _desc, backward = decode_cursor(cursor)
|
||||
|
||||
# base spec per model
|
||||
spec = {}
|
||||
columns = []
|
||||
row_classes = []
|
||||
|
|
@ -42,7 +50,8 @@ def init_listing_routes(app):
|
|||
{"field": "model"},
|
||||
{"field": "device_type.description", "label": "Device Type"},
|
||||
{"field": "condition"},
|
||||
{"field": "owner.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}},
|
||||
{"field": "owner.label", "label": "Contact",
|
||||
"link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}},
|
||||
{"field": "location.label", "label": "Room"},
|
||||
]
|
||||
elif model.lower() == 'user':
|
||||
|
|
@ -54,12 +63,13 @@ def init_listing_routes(app):
|
|||
"robot.overlord",
|
||||
"staff",
|
||||
"active",
|
||||
], "sort": "first_name,last_name"}
|
||||
], "sort": "first_name,last_name"} # default for users
|
||||
columns = [
|
||||
{"field": "label", "label": "Full Name"},
|
||||
{"field": "last_name"},
|
||||
{"field": "first_name"},
|
||||
{"field": "supervisor.label", "label": "Supervisor", "link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}},
|
||||
{"field": "supervisor.label", "label": "Supervisor",
|
||||
"link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}},
|
||||
{"field": "staff", "format": "yesno"},
|
||||
{"field": "active", "format": "yesno"},
|
||||
]
|
||||
|
|
@ -79,8 +89,10 @@ def init_listing_routes(app):
|
|||
"complete",
|
||||
]}
|
||||
columns = [
|
||||
{"field": "work_item.label", "label": "Work Item", "link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}},
|
||||
{"field": "contact.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}},
|
||||
{"field": "work_item.label", "label": "Work Item",
|
||||
"link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}},
|
||||
{"field": "contact.label", "label": "Contact",
|
||||
"link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}},
|
||||
{"field": "start_time", "format": "datetime"},
|
||||
{"field": "end_time", "format": "datetime"},
|
||||
{"field": "complete", "format": "yesno"},
|
||||
|
|
@ -89,19 +101,27 @@ def init_listing_routes(app):
|
|||
{"when": {"field": "complete", "is": True}, "class": "table-success"},
|
||||
{"when": {"field": "complete", "is": False}, "class": "table-danger"}
|
||||
]
|
||||
limit = int(request.args.get("limit", 15))
|
||||
cursor = request.args.get("cursor")
|
||||
key, _desc, backward = decode_cursor(cursor)
|
||||
|
||||
# overlay URL-provided sort if present
|
||||
if sort:
|
||||
spec["sort"] = sort
|
||||
|
||||
service = crudkit.crud.get_service(cls)
|
||||
# include limit and go
|
||||
window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)
|
||||
|
||||
table = render_table(window.items, columns=columns, opts={"object_class": model, "row_classes": row_classes})
|
||||
return render_template("listing.html", model=model, table=table, pagination={
|
||||
table = render_table(window.items, columns=columns,
|
||||
opts={"object_class": model, "row_classes": row_classes})
|
||||
|
||||
# pass sort through so templates can preserve it on pager links, if they care
|
||||
pagination_ctx = {
|
||||
"limit": window.limit,
|
||||
"total": window.total,
|
||||
"next_cursor": encode_cursor(window.last_key, list(window.order.desc), backward=False),
|
||||
"prev_cursor": encode_cursor(window.first_key, list(window.order.desc), backward=True),
|
||||
})
|
||||
"sort": sort or spec.get("sort") # expose current sort to the template
|
||||
}
|
||||
|
||||
return render_template("listing.html", model=model, table=table, pagination=pagination_ctx)
|
||||
|
||||
app.register_blueprint(bp_listing)
|
||||
Loading…
Add table
Add a link
Reference in a new issue