diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 9eb6d8b..5ec4757 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -2,12 +2,12 @@ from __future__ import annotations from collections.abc import Iterable from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression +from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators -from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec @@ -40,6 +40,25 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): T = TypeVar("T", bound=_CRUDModelProto) +def _hops_from_sort(params: dict | None) -> set[str]: + """Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'.""" + if not params: + return set() + raw = params.get("sort") + tokens: list[str] = [] + if isinstance(raw, str): + tokens = [t.strip() for t in raw.split(",") if t.strip()] + elif isinstance(raw, (list, tuple)): + for item in raw: + if isinstance(item, str): + tokens.extend([t.strip() for t in item.split(",") if t.strip()]) + hops: set[str] = set() + for tok in tokens: + tok = tok.lstrip("+-") + if "." in tok: + hops.add(tok.split(".", 1)[0]) + return hops + def _belongs_to_alias(col: Any, alias: Any) -> bool: # Try to detect if a column/expression ultimately comes from this alias. # Works for most ORM columns; complex expressions may need more. @@ -47,14 +66,15 @@ def _belongs_to_alias(col: Any, alias: Any) -> bool: selectable = getattr(alias, "selectable", None) return t is not None and selectable is not None and t is selectable -def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]: +def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]: + hops: set[str] = set() paths: set[tuple[str, ...]] = set() # Sort columns for ob in order_by or []: col = getattr(ob, "element", ob) # unwrap UnaryExpression - for path, _rel_attr, target_alias in join_paths: + for _path, rel_attr, target_alias in join_paths: if _belongs_to_alias(col, target_alias): - paths.add(tuple(path)) + hops.add(rel_attr.key) # Filter columns (best-effort) # Walk simple binary expressions def _extract_cols(expr: Any) -> Iterable[Any]: @@ -68,18 +88,18 @@ def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_ for flt in filters or []: for col in _extract_cols(flt): - for path, _rel_attr, target_alias in join_paths: + for _path, rel_attr, target_alias in join_paths: if _belongs_to_alias(col, target_alias): - paths.add(tuple[path]) - return paths + hops.add(rel_attr.key) + return hops -def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]: - out: set[tuple[str, ...]] = set() +def _paths_from_fields(req_fields: list[str]) -> set[str]: + out: set[str] = set() for f in req_fields: if "." in f: - parts = tuple(f.split(".")[:-1]) - if parts: - out.add(parts) + parent = f.split(".", 1)[0] + if parent: + out.add(parent) return out def _is_truthy(val): @@ -230,50 +250,24 @@ class CRUDService(Generic[T]): spec.parse_includes() join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias) - # Decide which relationship *names* are required for SQL (filters/sort) vs display-only - def _belongs_to_alias(col: Any, alias: Any) -> bool: - t = getattr(col, "table", None) - selectable = getattr(alias, "selectable", None) - return t is not None and selectable is not None and t is selectable + # Relationship names required by ORDER BY / WHERE + sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths) + # Also include relationships mentioned directly in the sort spec + sql_hops |= _hops_from_sort(params) - # 1) which relationship aliases are referenced by sort/filter - sql_hops: set[str] = set() - for path, relationship_attr, target_alias in join_paths: - # If any ORDER BY column comes from this alias, mark it - for ob in (order_by or []): - col = getattr(ob, "element", ob) # unwrap UnaryExpression - if _belongs_to_alias(col, target_alias): - sql_hops.add(relationship_attr.key) - break - # If any filter expr touches this alias, mark it (best effort) - if relationship_attr.key not in sql_hops: - def _walk_cols(expr: Any): - # Primitive walker for ColumnElement trees - from sqlalchemy.sql.elements import ColumnElement - if isinstance(expr, ColumnElement): - yield expr - for ch in getattr(expr, "get_children", lambda: [])(): - yield from _walk_cols(ch) - elif hasattr(expr, "clauses"): - for ch in expr.clauses: - yield from _walk_cols(ch) - for flt in (filters or []): - if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)): - sql_hops.add(relationship_attr.key) - break - - # 2) first-hop relationship names implied by dotted projection fields - proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f} + # First-hop relationship names implied by dotted projection fields + proj_hops: set[str] = _paths_from_fields(fields) # Root column projection - from sqlalchemy.orm import Load # local import to match your style only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) # Relationship handling per path (avoid loader strategy conflicts) used_contains_eager = False - for path, relationship_attr, target_alias in join_paths: + joined_names: set[str] = set() + + for _path, relationship_attr, target_alias in join_paths: rel_attr = cast(InstrumentedAttribute, relationship_attr) name = relationship_attr.key if name in sql_hops: @@ -281,12 +275,20 @@ class CRUDService(Generic[T]): query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True + joined_names.add(name) elif name in proj_hops: # Display-only: bulk-load efficiently, no join query = query.options(selectinload(rel_attr)) - else: - # Not needed - pass + joined_names.add(name) + + # Force-join any SQL-needed relationships that weren't in join_paths + missing_sql = sql_hops - joined_names + for name in missing_sql: + rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) + query = query.join(rel_attr, isouter=True) + query = query.options(contains_eager(rel_attr)) + used_contains_eager = True + joined_names.add(name) # Apply projection loader options only if they won't conflict with contains_eager if proj_opts and not used_contains_eager: @@ -348,8 +350,43 @@ class CRUDService(Generic[T]): pass # Boundary keys for cursor encoding in the API layer - first_key = self._pluck_key(items[0], order_spec) if items else None - last_key = self._pluck_key(items[-1], order_spec) if items else None + # When ORDER BY includes related columns (e.g., owner.first_name), + # pluck values from the related object we hydrated with contains_eager/selectinload. + def _pluck_key_from_obj(obj: Any) -> list[Any]: + vals: list[Any] = [] + # Build a quick map: selectable -> relationship name + alias_to_rel: dict[Any, str] = {} + for _p, relationship_attr, target_alias in join_paths: + sel = getattr(target_alias, "selectable", None) + if sel is not None: + alias_to_rel[sel] = relationship_attr.key + + for col in order_spec.cols: + key = getattr(col, "key", None) or getattr(col, "name", None) + # Try root attribute first + if key and hasattr(obj, key): + vals.append(getattr(obj, key)) + continue + # Try relationship hop by matching the column's table/selectable + table = getattr(col, "table", None) + relname = alias_to_rel.get(table) + if relname and key: + relobj = getattr(obj, relname, None) + if relobj is not None and hasattr(relobj, key): + vals.append(getattr(relobj, key)) + continue + # Give up: unsupported expression for cursor purposes + raise ValueError("unpluckable") + return vals + + try: + first_key = _pluck_key_from_obj(items[0]) if items else None + last_key = _pluck_key_from_obj(items[-1]) if items else None + except Exception: + # If we can't derive cursor keys (e.g., ORDER BY expression/aggregate), + # disable cursors for this response rather than exploding. + first_key = None + last_key = None # Optional total that’s safe under JOINs (COUNT DISTINCT ids) total = None @@ -359,10 +396,15 @@ class CRUDService(Generic[T]): if filters: base = base.filter(*filters) # Mirror join structure for any SQL-needed relationships - for path, relationship_attr, target_alias in join_paths: + for _path, relationship_attr, target_alias in join_paths: if relationship_attr.key in sql_hops: rel_attr = cast(InstrumentedAttribute, relationship_attr) base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + # Also mirror any forced joins + for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}): + rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) + base = base.join(rel_attr, isouter=True) + total = session.query(func.count()).select_from( base.order_by(None).distinct().subquery() ).scalar() or 0 @@ -444,8 +486,8 @@ class CRUDService(Generic[T]): # Decide which relationship paths are needed for SQL vs display-only # For get(), there is no ORDER BY; only filters might force SQL use. - sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) - proj_paths = _paths_from_fields(req_fields) + sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) + proj_hops = _paths_from_fields(req_fields) # Root column projection only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] @@ -454,15 +496,15 @@ class CRUDService(Generic[T]): # Relationship handling per path: avoid loader strategy conflicts used_contains_eager = False - for path, relationship_attr, target_alias in join_paths: + for _path, relationship_attr, target_alias in join_paths: rel_attr = cast(InstrumentedAttribute, relationship_attr) - ptuple = tuple(path) - if ptuple in sql_paths: + name = relationship_attr.key + if name in sql_hops: # Needed in WHERE: join + hydrate from the join query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True - elif ptuple in proj_paths: + elif name in proj_hops: # Display-only: bulk-load efficiently query = query.options(selectinload(rel_attr)) else: @@ -534,8 +576,9 @@ class CRUDService(Generic[T]): query = query.filter(*filters) # Determine which relationship paths are needed for SQL vs display-only - sql_paths = _paths_needed_for_sql(order_by, filters, join_paths) - proj_paths = _paths_from_fields(req_fields) + sql_hops = _paths_needed_for_sql(order_by, filters, join_paths) + sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist + proj_hops = _paths_from_fields(req_fields) # Root column projection only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] @@ -544,20 +587,30 @@ class CRUDService(Generic[T]): # Relationship handling per path used_contains_eager = False - for path, relationship_attr, target_alias in join_paths: + joined_names: set[str] = set() + + for _path, relationship_attr, target_alias in join_paths: rel_attr = cast(InstrumentedAttribute, relationship_attr) - ptuple = tuple(path) - if ptuple in sql_paths: + name = relationship_attr.key + if name in sql_hops: # Needed for WHERE/ORDER BY: join + hydrate from the join query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True - elif ptuple in proj_paths: + joined_names.add(name) + elif name in proj_hops: # Display-only: no join, bulk-load efficiently query = query.options(selectinload(rel_attr)) - else: - # Not needed at all; do nothing - pass + joined_names.add(name) + + # Force-join any SQL-needed relationships that weren't in join_paths + missing_sql = sql_hops - joined_names + for name in missing_sql: + rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) + query = query.join(rel_attr, isouter=True) + query = query.options(contains_eager(rel_attr)) + used_contains_eager = True + joined_names.add(name) # MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) paginating = (limit is not None) or (offset is not None and offset != 0) @@ -617,6 +670,7 @@ class CRUDService(Generic[T]): return rows + def create(self, data: dict, actor=None) -> T: session = self.session obj = self.model(**data) @@ -627,7 +681,7 @@ class CRUDService(Generic[T]): def update(self, id: int, data: dict, actor=None) -> T: session = self.session - obj = self.get(id) + obj = session.get(self.model, id) if not obj: raise ValueError(f"{self.model.__name__} with ID {id} not found.") valid_fields = {c.name for c in self.model.__table__.columns} diff --git a/inventory/routes/listing.py b/inventory/routes/listing.py index f5b474e..6f529ca 100644 --- a/inventory/routes/listing.py +++ b/inventory/routes/listing.py @@ -17,6 +17,14 @@ def init_listing_routes(app): if cls is None: abort(404) + # read query args + limit = int(request.args.get("limit", 15)) + sort = request.args.get("sort") # <- capture sort from URL + cursor = request.args.get("cursor") + # your decode returns (key, _desc, backward) in this project + key, _desc, backward = decode_cursor(cursor) + + # base spec per model spec = {} columns = [] row_classes = [] @@ -42,7 +50,8 @@ def init_listing_routes(app): {"field": "model"}, {"field": "device_type.description", "label": "Device Type"}, {"field": "condition"}, - {"field": "owner.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}}, + {"field": "owner.label", "label": "Contact", + "link": {"endpoint": "entry.entry", "params": {"id": "{owner.id}", "model": "user"}}}, {"field": "location.label", "label": "Room"}, ] elif model.lower() == 'user': @@ -54,12 +63,13 @@ def init_listing_routes(app): "robot.overlord", "staff", "active", - ], "sort": "first_name,last_name"} + ], "sort": "first_name,last_name"} # default for users columns = [ {"field": "label", "label": "Full Name"}, {"field": "last_name"}, {"field": "first_name"}, - {"field": "supervisor.label", "label": "Supervisor", "link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}}, + {"field": "supervisor.label", "label": "Supervisor", + "link": {"endpoint": "entry.entry", "params": {"id": "{supervisor.id}", "model": "user"}}}, {"field": "staff", "format": "yesno"}, {"field": "active", "format": "yesno"}, ] @@ -79,8 +89,10 @@ def init_listing_routes(app): "complete", ]} columns = [ - {"field": "work_item.label", "label": "Work Item", "link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}}, - {"field": "contact.label", "label": "Contact", "link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}}, + {"field": "work_item.label", "label": "Work Item", + "link": {"endpoint": "entry.entry", "params": {"id": "{work_item.id}", "model": "inventory"}}}, + {"field": "contact.label", "label": "Contact", + "link": {"endpoint": "entry.entry", "params": {"id": "{contact.id}", "model": "user"}}}, {"field": "start_time", "format": "datetime"}, {"field": "end_time", "format": "datetime"}, {"field": "complete", "format": "yesno"}, @@ -89,19 +101,27 @@ def init_listing_routes(app): {"when": {"field": "complete", "is": True}, "class": "table-success"}, {"when": {"field": "complete", "is": False}, "class": "table-danger"} ] - limit = int(request.args.get("limit", 15)) - cursor = request.args.get("cursor") - key, _desc, backward = decode_cursor(cursor) + + # overlay URL-provided sort if present + if sort: + spec["sort"] = sort service = crudkit.crud.get_service(cls) + # include limit and go window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True) - table = render_table(window.items, columns=columns, opts={"object_class": model, "row_classes": row_classes}) - return render_template("listing.html", model=model, table=table, pagination={ + table = render_table(window.items, columns=columns, + opts={"object_class": model, "row_classes": row_classes}) + + # pass sort through so templates can preserve it on pager links, if they care + pagination_ctx = { "limit": window.limit, "total": window.total, "next_cursor": encode_cursor(window.last_key, list(window.order.desc), backward=False), "prev_cursor": encode_cursor(window.first_key, list(window.order.desc), backward=True), - }) + "sort": sort or spec.get("sort") # expose current sort to the template + } + + return render_template("listing.html", model=model, table=table, pagination=pagination_ctx) app.register_blueprint(bp_listing) \ No newline at end of file