diff --git a/inventory/models/users.py b/inventory/models/users.py index d1dbe9b..36a23e2 100644 --- a/inventory/models/users.py +++ b/inventory/models/users.py @@ -32,6 +32,7 @@ class User(db.Model, ImageAttachable): image: Mapped[Optional['Image']] = relationship('Image', back_populates='user', passive_deletes=True) ui_eagerload = tuple() + ui_order_cols = ('first_name', 'last_name',) @property def identifier(self) -> str: diff --git a/inventory/ui/blueprint.py b/inventory/ui/blueprint.py index c82baee..0ed41a8 100644 --- a/inventory/ui/blueprint.py +++ b/inventory/ui/blueprint.py @@ -1,5 +1,7 @@ from flask import Blueprint, request, render_template, jsonify, abort +from sqlalchemy.engine import ScalarResult from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import Select from typing import Any, Optional, List, cast, Type, Iterable from .defaults import ( @@ -45,21 +47,49 @@ def call(Model: type, name: str, *args: Any, **kwargs: Any) -> Any: def list_items(model_name): Model = get_model_class(model_name) text = (request.args.get("q") or "").strip() or None + limit_param = request.args.get("limit") - limit: int | None = None if limit_param in (None, "", "0", "-1") else min(int(limit_param), 500) + # 0 / -1 / blank => unlimited (pass 0) + if limit_param in (None, "", "0", "-1"): + effective_limit = 0 + else: + effective_limit = min(int(limit_param), 500) + offset = int(request.args.get("offset", 0)) view = (request.args.get("view") or "json").strip() - # Build kwargs so we only include 'limit' when it's an int - qkwargs: dict[str, Any] = {"text": text, "offset": offset} - if limit is not None: - qkwargs["limit"] = limit + sort = (request.args.get("sort") or "").strip() or None + direction = (request.args.get("dir") or request.args.get("direction") or "asc").lower() + if direction not in ("asc", "desc"): + direction = "asc" - rows_iter: Iterable[Any] = ( - call(Model, "ui_query", db.session, **qkwargs) - or default_query(db.session, Model, **qkwargs) - ) - rows = list(rows_iter) + qkwargs: dict[str, Any] = { + "text": text, + "limit": effective_limit, + "offset": offset, + "sort": sort, + "direction": direction, + } + + # Prefer per-model override. Contract: return list[Model] OR a Select (SA 2.x). + rows_any: Any = call(Model, "ui_query", db.session, **qkwargs) + if rows_any is None: + rows = default_query(db.session, Model, **qkwargs) + elif isinstance(rows_any, list): + rows = rows_any + elif isinstance(rows_any, Select): + rows = list(cast(ScalarResult[Any], db.session.execute(rows_any).scalars())) + else: + # If someone returns a Result or other iterable of models + try: + # Try SQLAlchemy Result-like + scalars = getattr(rows_any, "scalars", None) + if callable(scalars): + rows = list(cast(ScalarResult[Any], scalars())) + else: + rows = list(rows_any) + except TypeError: + rows = [rows_any] items = [ (call(Model, "ui_serialize", r, view=view) or default_serialize(Model, r, view=view)) diff --git a/inventory/ui/defaults.py b/inventory/ui/defaults.py index 6661fd4..d01e1c9 100644 --- a/inventory/ui/defaults.py +++ b/inventory/ui/defaults.py @@ -1,5 +1,8 @@ -from sqlalchemy import select, or_ +from sqlalchemy import select, asc as sa_asc, desc as sa_desc from sqlalchemy.inspection import inspect +from sqlalchemy.sql import Select +from sqlalchemy.orm import Query +from typing import Any, Optional, cast PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description") @@ -25,46 +28,57 @@ def infer_label_attr(Model): return a raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})") -def default_query(session, Model, *, text=None, limit=100, offset=0, filters=None, order=None): - label_name = infer_label_attr(Model) - label_col = _mapped_column(Model, label_name) # guaranteed not None now +def default_query( + session, + Model, + *, + text: Optional[str] = None, + limit: int = 0, + offset: int = 0, + sort: Optional[str] = None, + direction: str = "asc", +) -> list[Any]: + """ + SA 2.x ONLY. Returns list[Model]. - stmt = select(Model) + Hooks: + - ui_search(stmt: Select, text: str) -> Select + - ui_sort(stmt: Select, sort: str, direction: str) -> Select + - ui_order_cols: tuple[str, ...] # default ordering columns + """ + stmt: Select[Any] = select(Model) - # Eager loads if class defines them (expects loader options like selectinload(...)) - for opt in getattr(Model, "ui_eagerload", ()) or (): - stmt = stmt.options(opt) + # Optional per-model search hook + ui_search = getattr(Model, "ui_search", None) + if callable(ui_search) and text: + stmt = cast(Select[Any], ui_search(stmt, text)) - # Text search across mapped columns only - if text: - cols = getattr(Model, "ui_search_cols", None) or (label_name,) - mapped = [ _mapped_column(Model, c) for c in cols ] - mapped = [ c for c in mapped if c is not None ] - if mapped: - stmt = stmt.where(or_(*[ c.ilike(f"%{text}%") for c in mapped ])) - - # Filters (exact-match) across mapped columns only - if filters: - for k, v in filters.items(): - if v is None: - continue - col = _mapped_column(Model, k) + # Sorting priority: + # 1. explicit sort param + # 2. per-model ui_sort hook + # 3. per-model ui_order_cols default ordering + if sort: + ui_sort = getattr(Model, "ui_sort", None) + if callable(ui_sort): + stmt = cast(Select[Any], ui_sort(stmt, sort, direction)) + else: + col = getattr(Model, sort, None) if col is not None: - stmt = stmt.where(col == v) + stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col)) + else: + order_cols = getattr(Model, "ui_order_cols", ()) + if order_cols: + for colname in order_cols: + col = getattr(Model, colname, None) + if col is not None: + stmt = stmt.order_by(sa_asc(col)) - # Order by mapped columns (fallback to label) - order_cols = order or getattr(Model, "ui_order_cols", None) or (label_name,) - for c in order_cols: - col = _mapped_column(Model, c) - if col is not None: - stmt = stmt.order_by(col) - - # stmt = stmt.limit(limit).offset(offset) - if limit is not None: - stmt = stmt.limit(limit) if offset: stmt = stmt.offset(offset) - return session.execute(stmt).scalars().all() + if limit > 0: + stmt = stmt.limit(limit) + + return list(session.execute(stmt).scalars().all()) def default_create(session, Model, payload): label = infer_label_attr(Model) @@ -73,16 +87,6 @@ def default_create(session, Model, payload): session.commit() return obj -# def default_update(session, Model, id_, payload): -# obj = session.get(Model, id_) -# if not obj: -# return None -# label = infer_label_attr(Model) -# if (nv := payload.get(label) or payload.get("name")): -# setattr(obj, label, nv) -# session.commit() -# return obj - def default_update(session, Model, id_, payload): obj = session.get(Model, id_) if not obj: