Refactor user route: remove eager loading from list_users function and enhance default_select to support eager loading with skip relations

This commit is contained in:
Yaro Kasear 2025-08-25 10:11:18 -05:00
parent 91e1e5051a
commit 30ec29d497
3 changed files with 124 additions and 58 deletions

View file

@ -12,8 +12,6 @@ from ..models import User, Room, Inventory, WorkLog
@main.route("/users") @main.route("/users")
def list_users(): def list_users():
query = eager_load_user_relationships(db.session.query(User)).order_by(User.last_name, User.first_name)
users = query.all()
return render_template( return render_template(
'table.html', 'table.html',
header = user_headers, header = user_headers,

View file

@ -3,7 +3,7 @@ from collections import defaultdict
from flask import Blueprint, request, render_template, jsonify, abort, make_response from flask import Blueprint, request, render_template, jsonify, abort, make_response
from sqlalchemy.engine import ScalarResult from sqlalchemy.engine import ScalarResult
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import class_mapper, load_only, selectinload from sqlalchemy.orm import class_mapper, load_only, selectinload, joinedload, Load
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from typing import Any, List, cast, Iterable, Tuple, Set, Dict from typing import Any, List, cast, Iterable, Tuple, Set, Dict
@ -15,6 +15,66 @@ from .. import db
bp = Blueprint("ui", __name__, url_prefix="/ui") bp = Blueprint("ui", __name__, url_prefix="/ui")
from sqlalchemy.orm import Load
def _option_targets_rel(opt: Load, Model, rel_name: str) -> bool:
"""
Return True if this Load option targets Model.rel_name at its root path.
Works for joinedload/selectinload/subqueryload options.
"""
try:
# opt.path is a PathRegistry; .path is a tuple of (mapper, prop, mapper, prop, ...)
path = tuple(getattr(opt, "path", ()).path) # type: ignore[attr-defined]
except Exception:
return False
if not path:
return False
# We only care about the first hop: (Mapper[Model], RelationshipProperty(rel_name))
if len(path) < 2:
return False
first_mapper, first_prop = path[0], path[1]
try:
is_model = first_mapper.class_ is Model # type: ignore[attr-defined]
is_rel = getattr(first_prop, "key", "") == rel_name
return bool(is_model and is_rel)
except Exception:
return False
def _has_loader_for(stmt: Select, Model, rel_name: str) -> bool:
"""
True if stmt already has any loader option configured for Model.rel_name.
"""
opts = getattr(stmt, "_with_options", ()) # SQLAlchemy stores Load options here
for opt in opts:
if isinstance(opt, Load) and _option_targets_rel(opt, Model, rel_name):
return True
return False
def _strategy_for_rel_attr(rel_attr) -> type[Load] | None:
# rel_attr is an InstrumentedAttribute (Model.foo)
prop = getattr(rel_attr, "property", None)
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
return joinedload
if lazy == "selectin":
return selectinload
# default if mapper left it None or something exotic like 'raise'
return selectinload
def apply_model_default_eager(stmt: Select, Model, skip_rels: Set[str]) -> Select:
# mapper.relationships yields RelationshipProperty objects
mapper = class_mapper(Model)
for prop in mapper.relationships:
if prop.key in skip_rels:
continue
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
elif lazy == "selectin":
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
# else: leave it alone (noload/raise/dynamic/etc.)
return stmt
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]: def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
""" """
Split requested fields into base model columns and relation->attr sets. Split requested fields into base model columns and relation->attr sets.
@ -53,33 +113,50 @@ def _load_only_existing(Model, names: Set[str]):
return cols return cols
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select: def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
"""
Given a base Select[Model] and requested fields, attach loader options:
- load_only(...) for base scalar columns
- selectinload(Model.rel).options(load_only(...)) for each requested relation
Only single-depth "rel.attr" is supported, which is exactly what youre using.
"""
base_cols, rel_cols = split_fields(Model, fields) base_cols, rel_cols = split_fields(Model, fields)
# Restrict base columns if any were explicitly requested
base_only = _load_only_existing(Model, base_cols) base_only = _load_only_existing(Model, base_cols)
if base_only: if base_only:
stmt = stmt.options(load_only(*base_only)) stmt = stmt.options(load_only(*base_only))
# Relations: selectinload each requested relation and trim its columns
for rel_name, attrs in rel_cols.items(): for rel_name, attrs in rel_cols.items():
if not hasattr(Model, rel_name): if not hasattr(Model, rel_name):
continue continue
# If someone already attached a loader for this relation, don't add another
if _has_loader_for(stmt, Model, rel_name):
# still allow trimming columns on the related entity if we can
rel_attr = getattr(Model, rel_name) rel_attr = getattr(Model, rel_name)
try: try:
target_cls = rel_attr.property.mapper.class_ target_cls = rel_attr.property.mapper.class_
except Exception: except Exception:
continue # not a relationship; ignore continue
rel_only = _load_only_existing(target_cls, attrs)
if rel_only:
# attach a Load that only applies load_only to that path,
# without picking a different strategy
# This relies on SQLA merging load_only onto existing Load for the same path.
stmt = stmt.options(
getattr(Load(Model), rel_name).load_only(*rel_only)
)
continue
opt = selectinload(rel_attr) # Otherwise choose a strategy and add it
rel_attr = getattr(Model, rel_name)
strategy = _strategy_for_rel_attr(rel_attr)
if not strategy:
continue
opt = strategy(rel_attr)
# Trim columns on the related entity if requested
try:
target_cls = rel_attr.property.mapper.class_
except Exception:
continue
rel_only = _load_only_existing(target_cls, attrs) rel_only = _load_only_existing(target_cls, attrs)
if rel_only: if rel_only:
opt = opt.options(load_only(*rel_only)) opt = opt.options(load_only(*rel_only))
stmt = stmt.options(opt) stmt = stmt.options(opt)
return stmt return stmt
@ -159,26 +236,38 @@ def list_items(model_name):
qkwargs: dict[str, Any] = { qkwargs: dict[str, Any] = {
"text": text, "text": text,
# these are irrelevant for stmt-building; keep for ui_query compatibility
"limit": 0 if unlimited else per_page, "limit": 0 if unlimited else per_page,
"offset": 0 if unlimited else (page - 1) * per_page if per_page else 0, "offset": 0 if unlimited else (page - 1) * per_page if per_page else 0,
"sort": sort, "sort": sort,
"direction": direction, "direction": direction,
} }
# 1) Try per-model override first # compute requested relations once
base_cols, rel_cols = split_fields(Model, fields)
skip_rels = set(rel_cols.keys()) if fields else set()
# 1) per-model override first
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs) rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
stmt: Select | None = None stmt: Select | None = None
total: int total: int
if rows_any is None: if rows_any is None:
# 2) default: build a Select stmt = default_select(Model, text=text, sort=sort, direction=direction, eager=False)
stmt = default_select(Model, text=text, sort=sort, direction=direction)
if not fields:
stmt = apply_model_default_eager(stmt, Model, skip_rels=set())
else:
stmt = apply_field_loaders(stmt, Model, fields)
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
elif isinstance(rows_any, Select): elif isinstance(rows_any, Select):
stmt = rows_any # TRUST ui_query; don't add loaders on top
stmt = ensure_order_by(rows_any, Model, sort=sort, direction=direction)
elif isinstance(rows_any, list): elif isinstance(rows_any, list):
# Someone returned a materialized list. Paginate in Python. # materialized list; paginate in python
total = len(rows_any) total = len(rows_any)
if unlimited: if unlimited:
rows = rows_any rows = rows_any
@ -186,21 +275,15 @@ def list_items(model_name):
start = (page - 1) * per_page start = (page - 1) * per_page
end = start + per_page end = start + per_page
rows = rows_any[start:end] rows = rows_any[start:end]
# serialize and return at the bottom like usual
else: else:
# SQLAlchemy Result-like? # SQLAlchemy Result-like or generic iterable
scalars = getattr(rows_any, "scalars", None) scalars = getattr(rows_any, "scalars", None)
if callable(scalars): if callable(scalars):
# execute now, then paginate in Python
all_rows = list(cast(ScalarResult[Any], scalars())) all_rows = list(cast(ScalarResult[Any], scalars()))
total = len(all_rows) total = len(all_rows)
if unlimited: rows = all_rows if unlimited else all_rows[(page - 1) * per_page : (page * per_page)]
rows = all_rows
else: else:
start = (page - 1) * per_page
end = start + per_page
rows = all_rows[start:end]
else:
# single object or generic iterable
try: try:
all_rows = list(rows_any) all_rows = list(rows_any)
total = len(all_rows) total = len(all_rows)
@ -209,34 +292,15 @@ def list_items(model_name):
total = 1 total = 1
rows = [rows_any] rows = [rows_any]
# If we have a real Select, use db.paginate for proper COUNT and slicing # If we have a real Select, run it once (unlimited) or paginate once.
if stmt is not None: if stmt is not None:
if unlimited: if unlimited:
rows = list(db.session.execute(stmt).scalars()) rows = list(db.session.execute(stmt).scalars())
total = count_for(db.session, stmt) total = count_for(db.session, stmt)
else:
stmt = default_select(Model, text=text, sort=sort, direction=direction)
if fields:
stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
if unlimited:
rows = list(db.session.execute(stmt).scalars())
total = count_for(db.session, stmt) # uses session, not stmt.bind
else: else:
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False) pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
rows = pagination.items rows = pagination.items
total = pagination.total total = pagination.total
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
pagination = db.paginate(
stmt,
page=page,
per_page=per_page,
error_out=False
)
rows = pagination.items
total = pagination.total
# Serialize # Serialize
if fields: if fields:
@ -267,14 +331,6 @@ def list_items(model_name):
if want_list: if want_list:
return render_template("fragments/_list_fragment.html", options=items) return render_template("fragments/_list_fragment.html", options=items)
if want_table: if want_table:
# return render_template("fragments/_table_data_fragment.html",
# rows=items,
# model_name=model_name,
# total=total,
# page=page,
# per_page=per_page,
# pages=(0 if unlimited else ((total + per_page - 1) // per_page)),
# )
resp = make_response(render_template("fragments/_table_data_fragment.html", resp = make_response(render_template("fragments/_table_data_fragment.html",
rows=items, model_name=model_name)) rows=items, model_name=model_name))
resp.headers['X-Total'] = str(total) resp.headers['X-Total'] = str(total)

View file

@ -1,5 +1,6 @@
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.orm import class_mapper, joinedload, selectinload
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from sqlalchemy.sql.sqltypes import String, Unicode, Text from sqlalchemy.sql.sqltypes import String, Unicode, Text
from typing import Any, Optional, cast, Iterable from typing import Any, Optional, cast, Iterable
@ -76,6 +77,8 @@ def default_select(
text: Optional[str] = None, text: Optional[str] = None,
sort: Optional[str] = None, sort: Optional[str] = None,
direction: str = "asc", direction: str = "asc",
eager = False,
skip_rels=frozenset()
) -> Select[Any]: ) -> Select[Any]:
stmt: Select[Any] = select(Model) stmt: Select[Any] = select(Model)
@ -119,6 +122,15 @@ def default_select(
for opt in opts: for opt in opts:
stmt = stmt.options(opt) stmt = stmt.options(opt)
if eager:
for prop in class_mapper(Model).relationships:
if prop.key in skip_rels:
continue
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
elif lazy == "selectin":
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
return stmt return stmt
def default_query( def default_query(