Refactor user route: remove eager loading from list_users function and enhance default_select to support eager loading with skip relations

This commit is contained in:
Yaro Kasear 2025-08-25 10:11:18 -05:00
parent 91e1e5051a
commit 30ec29d497
3 changed files with 124 additions and 58 deletions

View file

@ -12,8 +12,6 @@ from ..models import User, Room, Inventory, WorkLog
@main.route("/users")
def list_users():
query = eager_load_user_relationships(db.session.query(User)).order_by(User.last_name, User.first_name)
users = query.all()
return render_template(
'table.html',
header = user_headers,

View file

@ -3,7 +3,7 @@ from collections import defaultdict
from flask import Blueprint, request, render_template, jsonify, abort, make_response
from sqlalchemy.engine import ScalarResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import class_mapper, load_only, selectinload
from sqlalchemy.orm import class_mapper, load_only, selectinload, joinedload, Load
from sqlalchemy.sql import Select
from typing import Any, List, cast, Iterable, Tuple, Set, Dict
@ -15,6 +15,66 @@ from .. import db
bp = Blueprint("ui", __name__, url_prefix="/ui")
from sqlalchemy.orm import Load
def _option_targets_rel(opt: Load, Model, rel_name: str) -> bool:
"""
Return True if this Load option targets Model.rel_name at its root path.
Works for joinedload/selectinload/subqueryload options.
"""
try:
# opt.path is a PathRegistry; .path is a tuple of (mapper, prop, mapper, prop, ...)
path = tuple(getattr(opt, "path", ()).path) # type: ignore[attr-defined]
except Exception:
return False
if not path:
return False
# We only care about the first hop: (Mapper[Model], RelationshipProperty(rel_name))
if len(path) < 2:
return False
first_mapper, first_prop = path[0], path[1]
try:
is_model = first_mapper.class_ is Model # type: ignore[attr-defined]
is_rel = getattr(first_prop, "key", "") == rel_name
return bool(is_model and is_rel)
except Exception:
return False
def _has_loader_for(stmt: Select, Model, rel_name: str) -> bool:
"""
True if stmt already has any loader option configured for Model.rel_name.
"""
opts = getattr(stmt, "_with_options", ()) # SQLAlchemy stores Load options here
for opt in opts:
if isinstance(opt, Load) and _option_targets_rel(opt, Model, rel_name):
return True
return False
def _strategy_for_rel_attr(rel_attr) -> type[Load] | None:
# rel_attr is an InstrumentedAttribute (Model.foo)
prop = getattr(rel_attr, "property", None)
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
return joinedload
if lazy == "selectin":
return selectinload
# default if mapper left it None or something exotic like 'raise'
return selectinload
def apply_model_default_eager(stmt: Select, Model, skip_rels: Set[str]) -> Select:
# mapper.relationships yields RelationshipProperty objects
mapper = class_mapper(Model)
for prop in mapper.relationships:
if prop.key in skip_rels:
continue
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
elif lazy == "selectin":
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
# else: leave it alone (noload/raise/dynamic/etc.)
return stmt
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
"""
Split requested fields into base model columns and relation->attr sets.
@ -53,33 +113,50 @@ def _load_only_existing(Model, names: Set[str]):
return cols
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
"""
Given a base Select[Model] and requested fields, attach loader options:
- load_only(...) for base scalar columns
- selectinload(Model.rel).options(load_only(...)) for each requested relation
Only single-depth "rel.attr" is supported, which is exactly what youre using.
"""
base_cols, rel_cols = split_fields(Model, fields)
# Restrict base columns if any were explicitly requested
base_only = _load_only_existing(Model, base_cols)
if base_only:
stmt = stmt.options(load_only(*base_only))
# Relations: selectinload each requested relation and trim its columns
for rel_name, attrs in rel_cols.items():
if not hasattr(Model, rel_name):
continue
# If someone already attached a loader for this relation, don't add another
if _has_loader_for(stmt, Model, rel_name):
# still allow trimming columns on the related entity if we can
rel_attr = getattr(Model, rel_name)
try:
target_cls = rel_attr.property.mapper.class_
except Exception:
continue # not a relationship; ignore
continue
rel_only = _load_only_existing(target_cls, attrs)
if rel_only:
# attach a Load that only applies load_only to that path,
# without picking a different strategy
# This relies on SQLA merging load_only onto existing Load for the same path.
stmt = stmt.options(
getattr(Load(Model), rel_name).load_only(*rel_only)
)
continue
opt = selectinload(rel_attr)
# Otherwise choose a strategy and add it
rel_attr = getattr(Model, rel_name)
strategy = _strategy_for_rel_attr(rel_attr)
if not strategy:
continue
opt = strategy(rel_attr)
# Trim columns on the related entity if requested
try:
target_cls = rel_attr.property.mapper.class_
except Exception:
continue
rel_only = _load_only_existing(target_cls, attrs)
if rel_only:
opt = opt.options(load_only(*rel_only))
stmt = stmt.options(opt)
return stmt
@ -159,26 +236,38 @@ def list_items(model_name):
qkwargs: dict[str, Any] = {
"text": text,
# these are irrelevant for stmt-building; keep for ui_query compatibility
"limit": 0 if unlimited else per_page,
"offset": 0 if unlimited else (page - 1) * per_page if per_page else 0,
"sort": sort,
"direction": direction,
}
# 1) Try per-model override first
# compute requested relations once
base_cols, rel_cols = split_fields(Model, fields)
skip_rels = set(rel_cols.keys()) if fields else set()
# 1) per-model override first
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
stmt: Select | None = None
total: int
if rows_any is None:
# 2) default: build a Select
stmt = default_select(Model, text=text, sort=sort, direction=direction)
stmt = default_select(Model, text=text, sort=sort, direction=direction, eager=False)
if not fields:
stmt = apply_model_default_eager(stmt, Model, skip_rels=set())
else:
stmt = apply_field_loaders(stmt, Model, fields)
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
elif isinstance(rows_any, Select):
stmt = rows_any
# TRUST ui_query; don't add loaders on top
stmt = ensure_order_by(rows_any, Model, sort=sort, direction=direction)
elif isinstance(rows_any, list):
# Someone returned a materialized list. Paginate in Python.
# materialized list; paginate in python
total = len(rows_any)
if unlimited:
rows = rows_any
@ -186,21 +275,15 @@ def list_items(model_name):
start = (page - 1) * per_page
end = start + per_page
rows = rows_any[start:end]
# serialize and return at the bottom like usual
else:
# SQLAlchemy Result-like?
# SQLAlchemy Result-like or generic iterable
scalars = getattr(rows_any, "scalars", None)
if callable(scalars):
# execute now, then paginate in Python
all_rows = list(cast(ScalarResult[Any], scalars()))
total = len(all_rows)
if unlimited:
rows = all_rows
rows = all_rows if unlimited else all_rows[(page - 1) * per_page : (page * per_page)]
else:
start = (page - 1) * per_page
end = start + per_page
rows = all_rows[start:end]
else:
# single object or generic iterable
try:
all_rows = list(rows_any)
total = len(all_rows)
@ -209,34 +292,15 @@ def list_items(model_name):
total = 1
rows = [rows_any]
# If we have a real Select, use db.paginate for proper COUNT and slicing
# If we have a real Select, run it once (unlimited) or paginate once.
if stmt is not None:
if unlimited:
rows = list(db.session.execute(stmt).scalars())
total = count_for(db.session, stmt)
else:
stmt = default_select(Model, text=text, sort=sort, direction=direction)
if fields:
stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
if unlimited:
rows = list(db.session.execute(stmt).scalars())
total = count_for(db.session, stmt) # uses session, not stmt.bind
else:
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
rows = pagination.items
total = pagination.total
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
pagination = db.paginate(
stmt,
page=page,
per_page=per_page,
error_out=False
)
rows = pagination.items
total = pagination.total
# Serialize
if fields:
@ -267,14 +331,6 @@ def list_items(model_name):
if want_list:
return render_template("fragments/_list_fragment.html", options=items)
if want_table:
# return render_template("fragments/_table_data_fragment.html",
# rows=items,
# model_name=model_name,
# total=total,
# page=page,
# per_page=per_page,
# pages=(0 if unlimited else ((total + per_page - 1) // per_page)),
# )
resp = make_response(render_template("fragments/_table_data_fragment.html",
rows=items, model_name=model_name))
resp.headers['X-Total'] = str(total)

View file

@ -1,5 +1,6 @@
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import class_mapper, joinedload, selectinload
from sqlalchemy.sql import Select
from sqlalchemy.sql.sqltypes import String, Unicode, Text
from typing import Any, Optional, cast, Iterable
@ -76,6 +77,8 @@ def default_select(
text: Optional[str] = None,
sort: Optional[str] = None,
direction: str = "asc",
eager = False,
skip_rels=frozenset()
) -> Select[Any]:
stmt: Select[Any] = select(Model)
@ -119,6 +122,15 @@ def default_select(
for opt in opts:
stmt = stmt.options(opt)
if eager:
for prop in class_mapper(Model).relationships:
if prop.key in skip_rels:
continue
lazy = getattr(prop, "lazy", None)
if lazy in ("joined", "subquery"):
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
elif lazy == "selectin":
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
return stmt
def default_query(