diff --git a/inventory/routes/user.py b/inventory/routes/user.py index 5dc6726..0c4cb3f 100644 --- a/inventory/routes/user.py +++ b/inventory/routes/user.py @@ -12,8 +12,6 @@ from ..models import User, Room, Inventory, WorkLog @main.route("/users") def list_users(): - query = eager_load_user_relationships(db.session.query(User)).order_by(User.last_name, User.first_name) - users = query.all() return render_template( 'table.html', header = user_headers, diff --git a/inventory/ui/blueprint.py b/inventory/ui/blueprint.py index b8db54b..683ff47 100644 --- a/inventory/ui/blueprint.py +++ b/inventory/ui/blueprint.py @@ -3,7 +3,7 @@ from collections import defaultdict from flask import Blueprint, request, render_template, jsonify, abort, make_response from sqlalchemy.engine import ScalarResult from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import class_mapper, load_only, selectinload +from sqlalchemy.orm import class_mapper, load_only, selectinload, joinedload, Load from sqlalchemy.sql import Select from typing import Any, List, cast, Iterable, Tuple, Set, Dict @@ -15,6 +15,66 @@ from .. import db bp = Blueprint("ui", __name__, url_prefix="/ui") +from sqlalchemy.orm import Load + +def _option_targets_rel(opt: Load, Model, rel_name: str) -> bool: + """ + Return True if this Load option targets Model.rel_name at its root path. + Works for joinedload/selectinload/subqueryload options. + """ + try: + # opt.path is a PathRegistry; .path is a tuple of (mapper, prop, mapper, prop, ...) + path = tuple(getattr(opt, "path", ()).path) # type: ignore[attr-defined] + except Exception: + return False + if not path: + return False + # We only care about the first hop: (Mapper[Model], RelationshipProperty(rel_name)) + if len(path) < 2: + return False + first_mapper, first_prop = path[0], path[1] + try: + is_model = first_mapper.class_ is Model # type: ignore[attr-defined] + is_rel = getattr(first_prop, "key", "") == rel_name + return bool(is_model and is_rel) + except Exception: + return False + +def _has_loader_for(stmt: Select, Model, rel_name: str) -> bool: + """ + True if stmt already has any loader option configured for Model.rel_name. + """ + opts = getattr(stmt, "_with_options", ()) # SQLAlchemy stores Load options here + for opt in opts: + if isinstance(opt, Load) and _option_targets_rel(opt, Model, rel_name): + return True + return False + +def _strategy_for_rel_attr(rel_attr) -> type[Load] | None: + # rel_attr is an InstrumentedAttribute (Model.foo) + prop = getattr(rel_attr, "property", None) + lazy = getattr(prop, "lazy", None) + if lazy in ("joined", "subquery"): + return joinedload + if lazy == "selectin": + return selectinload + # default if mapper left it None or something exotic like 'raise' + return selectinload + +def apply_model_default_eager(stmt: Select, Model, skip_rels: Set[str]) -> Select: + # mapper.relationships yields RelationshipProperty objects + mapper = class_mapper(Model) + for prop in mapper.relationships: + if prop.key in skip_rels: + continue + lazy = getattr(prop, "lazy", None) + if lazy in ("joined", "subquery"): + stmt = stmt.options(joinedload(getattr(Model, prop.key))) + elif lazy == "selectin": + stmt = stmt.options(selectinload(getattr(Model, prop.key))) + # else: leave it alone (noload/raise/dynamic/etc.) + return stmt + def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]: """ Split requested fields into base model columns and relation->attr sets. @@ -53,33 +113,50 @@ def _load_only_existing(Model, names: Set[str]): return cols def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select: - """ - Given a base Select[Model] and requested fields, attach loader options: - - load_only(...) for base scalar columns - - selectinload(Model.rel).options(load_only(...)) for each requested relation - Only single-depth "rel.attr" is supported, which is exactly what you’re using. - """ base_cols, rel_cols = split_fields(Model, fields) - # Restrict base columns if any were explicitly requested base_only = _load_only_existing(Model, base_cols) if base_only: stmt = stmt.options(load_only(*base_only)) - # Relations: selectinload each requested relation and trim its columns for rel_name, attrs in rel_cols.items(): if not hasattr(Model, rel_name): continue + + # If someone already attached a loader for this relation, don't add another + if _has_loader_for(stmt, Model, rel_name): + # still allow trimming columns on the related entity if we can + rel_attr = getattr(Model, rel_name) + try: + target_cls = rel_attr.property.mapper.class_ + except Exception: + continue + rel_only = _load_only_existing(target_cls, attrs) + if rel_only: + # attach a Load that only applies load_only to that path, + # without picking a different strategy + # This relies on SQLA merging load_only onto existing Load for the same path. + stmt = stmt.options( + getattr(Load(Model), rel_name).load_only(*rel_only) + ) + continue + + # Otherwise choose a strategy and add it rel_attr = getattr(Model, rel_name) + strategy = _strategy_for_rel_attr(rel_attr) + if not strategy: + continue + opt = strategy(rel_attr) + + # Trim columns on the related entity if requested try: target_cls = rel_attr.property.mapper.class_ except Exception: - continue # not a relationship; ignore - - opt = selectinload(rel_attr) + continue rel_only = _load_only_existing(target_cls, attrs) if rel_only: opt = opt.options(load_only(*rel_only)) + stmt = stmt.options(opt) return stmt @@ -159,26 +236,38 @@ def list_items(model_name): qkwargs: dict[str, Any] = { "text": text, - # these are irrelevant for stmt-building; keep for ui_query compatibility "limit": 0 if unlimited else per_page, "offset": 0 if unlimited else (page - 1) * per_page if per_page else 0, "sort": sort, "direction": direction, } - # 1) Try per-model override first + # compute requested relations once + base_cols, rel_cols = split_fields(Model, fields) + skip_rels = set(rel_cols.keys()) if fields else set() + + # 1) per-model override first rows_any: Any = call(Model, "ui_query", db.session, **qkwargs) stmt: Select | None = None total: int if rows_any is None: - # 2) default: build a Select - stmt = default_select(Model, text=text, sort=sort, direction=direction) + stmt = default_select(Model, text=text, sort=sort, direction=direction, eager=False) + + if not fields: + stmt = apply_model_default_eager(stmt, Model, skip_rels=set()) + else: + stmt = apply_field_loaders(stmt, Model, fields) + + stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction) + elif isinstance(rows_any, Select): - stmt = rows_any + # TRUST ui_query; don't add loaders on top + stmt = ensure_order_by(rows_any, Model, sort=sort, direction=direction) + elif isinstance(rows_any, list): - # Someone returned a materialized list. Paginate in Python. + # materialized list; paginate in python total = len(rows_any) if unlimited: rows = rows_any @@ -186,21 +275,15 @@ def list_items(model_name): start = (page - 1) * per_page end = start + per_page rows = rows_any[start:end] + # serialize and return at the bottom like usual else: - # SQLAlchemy Result-like? + # SQLAlchemy Result-like or generic iterable scalars = getattr(rows_any, "scalars", None) if callable(scalars): - # execute now, then paginate in Python all_rows = list(cast(ScalarResult[Any], scalars())) total = len(all_rows) - if unlimited: - rows = all_rows - else: - start = (page - 1) * per_page - end = start + per_page - rows = all_rows[start:end] + rows = all_rows if unlimited else all_rows[(page - 1) * per_page : (page * per_page)] else: - # single object or generic iterable try: all_rows = list(rows_any) total = len(all_rows) @@ -209,32 +292,13 @@ def list_items(model_name): total = 1 rows = [rows_any] - # If we have a real Select, use db.paginate for proper COUNT and slicing + # If we have a real Select, run it once (unlimited) or paginate once. if stmt is not None: if unlimited: rows = list(db.session.execute(stmt).scalars()) total = count_for(db.session, stmt) else: - stmt = default_select(Model, text=text, sort=sort, direction=direction) - if fields: - stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier - - stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction) - - if unlimited: - rows = list(db.session.execute(stmt).scalars()) - total = count_for(db.session, stmt) # uses session, not stmt.bind - else: - pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False) - rows = pagination.items - total = pagination.total - stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction) - pagination = db.paginate( - stmt, - page=page, - per_page=per_page, - error_out=False - ) + pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False) rows = pagination.items total = pagination.total @@ -267,14 +331,6 @@ def list_items(model_name): if want_list: return render_template("fragments/_list_fragment.html", options=items) if want_table: - # return render_template("fragments/_table_data_fragment.html", - # rows=items, - # model_name=model_name, - # total=total, - # page=page, - # per_page=per_page, - # pages=(0 if unlimited else ((total + per_page - 1) // per_page)), - # ) resp = make_response(render_template("fragments/_table_data_fragment.html", rows=items, model_name=model_name)) resp.headers['X-Total'] = str(total) diff --git a/inventory/ui/defaults.py b/inventory/ui/defaults.py index fa250df..ab6be3e 100644 --- a/inventory/ui/defaults.py +++ b/inventory/ui/defaults.py @@ -1,5 +1,6 @@ from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func from sqlalchemy.inspection import inspect +from sqlalchemy.orm import class_mapper, joinedload, selectinload from sqlalchemy.sql import Select from sqlalchemy.sql.sqltypes import String, Unicode, Text from typing import Any, Optional, cast, Iterable @@ -76,6 +77,8 @@ def default_select( text: Optional[str] = None, sort: Optional[str] = None, direction: str = "asc", + eager = False, + skip_rels=frozenset() ) -> Select[Any]: stmt: Select[Any] = select(Model) @@ -119,6 +122,15 @@ def default_select( for opt in opts: stmt = stmt.options(opt) + if eager: + for prop in class_mapper(Model).relationships: + if prop.key in skip_rels: + continue + lazy = getattr(prop, "lazy", None) + if lazy in ("joined", "subquery"): + stmt = stmt.options(joinedload(getattr(Model, prop.key))) + elif lazy == "selectin": + stmt = stmt.options(selectinload(getattr(Model, prop.key))) return stmt def default_query(