Refactor user route: remove eager loading from list_users function and enhance default_select to support eager loading with skip relations
This commit is contained in:
parent
91e1e5051a
commit
30ec29d497
3 changed files with 124 additions and 58 deletions
|
|
@ -12,8 +12,6 @@ from ..models import User, Room, Inventory, WorkLog
|
|||
|
||||
@main.route("/users")
|
||||
def list_users():
|
||||
query = eager_load_user_relationships(db.session.query(User)).order_by(User.last_name, User.first_name)
|
||||
users = query.all()
|
||||
return render_template(
|
||||
'table.html',
|
||||
header = user_headers,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
|||
from flask import Blueprint, request, render_template, jsonify, abort, make_response
|
||||
from sqlalchemy.engine import ScalarResult
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import class_mapper, load_only, selectinload
|
||||
from sqlalchemy.orm import class_mapper, load_only, selectinload, joinedload, Load
|
||||
from sqlalchemy.sql import Select
|
||||
from typing import Any, List, cast, Iterable, Tuple, Set, Dict
|
||||
|
||||
|
|
@ -15,6 +15,66 @@ from .. import db
|
|||
|
||||
bp = Blueprint("ui", __name__, url_prefix="/ui")
|
||||
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
def _option_targets_rel(opt: Load, Model, rel_name: str) -> bool:
|
||||
"""
|
||||
Return True if this Load option targets Model.rel_name at its root path.
|
||||
Works for joinedload/selectinload/subqueryload options.
|
||||
"""
|
||||
try:
|
||||
# opt.path is a PathRegistry; .path is a tuple of (mapper, prop, mapper, prop, ...)
|
||||
path = tuple(getattr(opt, "path", ()).path) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return False
|
||||
if not path:
|
||||
return False
|
||||
# We only care about the first hop: (Mapper[Model], RelationshipProperty(rel_name))
|
||||
if len(path) < 2:
|
||||
return False
|
||||
first_mapper, first_prop = path[0], path[1]
|
||||
try:
|
||||
is_model = first_mapper.class_ is Model # type: ignore[attr-defined]
|
||||
is_rel = getattr(first_prop, "key", "") == rel_name
|
||||
return bool(is_model and is_rel)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _has_loader_for(stmt: Select, Model, rel_name: str) -> bool:
|
||||
"""
|
||||
True if stmt already has any loader option configured for Model.rel_name.
|
||||
"""
|
||||
opts = getattr(stmt, "_with_options", ()) # SQLAlchemy stores Load options here
|
||||
for opt in opts:
|
||||
if isinstance(opt, Load) and _option_targets_rel(opt, Model, rel_name):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _strategy_for_rel_attr(rel_attr) -> type[Load] | None:
|
||||
# rel_attr is an InstrumentedAttribute (Model.foo)
|
||||
prop = getattr(rel_attr, "property", None)
|
||||
lazy = getattr(prop, "lazy", None)
|
||||
if lazy in ("joined", "subquery"):
|
||||
return joinedload
|
||||
if lazy == "selectin":
|
||||
return selectinload
|
||||
# default if mapper left it None or something exotic like 'raise'
|
||||
return selectinload
|
||||
|
||||
def apply_model_default_eager(stmt: Select, Model, skip_rels: Set[str]) -> Select:
|
||||
# mapper.relationships yields RelationshipProperty objects
|
||||
mapper = class_mapper(Model)
|
||||
for prop in mapper.relationships:
|
||||
if prop.key in skip_rels:
|
||||
continue
|
||||
lazy = getattr(prop, "lazy", None)
|
||||
if lazy in ("joined", "subquery"):
|
||||
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
|
||||
elif lazy == "selectin":
|
||||
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
|
||||
# else: leave it alone (noload/raise/dynamic/etc.)
|
||||
return stmt
|
||||
|
||||
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
|
||||
"""
|
||||
Split requested fields into base model columns and relation->attr sets.
|
||||
|
|
@ -53,33 +113,50 @@ def _load_only_existing(Model, names: Set[str]):
|
|||
return cols
|
||||
|
||||
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
|
||||
"""
|
||||
Given a base Select[Model] and requested fields, attach loader options:
|
||||
- load_only(...) for base scalar columns
|
||||
- selectinload(Model.rel).options(load_only(...)) for each requested relation
|
||||
Only single-depth "rel.attr" is supported, which is exactly what you’re using.
|
||||
"""
|
||||
base_cols, rel_cols = split_fields(Model, fields)
|
||||
|
||||
# Restrict base columns if any were explicitly requested
|
||||
base_only = _load_only_existing(Model, base_cols)
|
||||
if base_only:
|
||||
stmt = stmt.options(load_only(*base_only))
|
||||
|
||||
# Relations: selectinload each requested relation and trim its columns
|
||||
for rel_name, attrs in rel_cols.items():
|
||||
if not hasattr(Model, rel_name):
|
||||
continue
|
||||
|
||||
# If someone already attached a loader for this relation, don't add another
|
||||
if _has_loader_for(stmt, Model, rel_name):
|
||||
# still allow trimming columns on the related entity if we can
|
||||
rel_attr = getattr(Model, rel_name)
|
||||
try:
|
||||
target_cls = rel_attr.property.mapper.class_
|
||||
except Exception:
|
||||
continue # not a relationship; ignore
|
||||
continue
|
||||
rel_only = _load_only_existing(target_cls, attrs)
|
||||
if rel_only:
|
||||
# attach a Load that only applies load_only to that path,
|
||||
# without picking a different strategy
|
||||
# This relies on SQLA merging load_only onto existing Load for the same path.
|
||||
stmt = stmt.options(
|
||||
getattr(Load(Model), rel_name).load_only(*rel_only)
|
||||
)
|
||||
continue
|
||||
|
||||
opt = selectinload(rel_attr)
|
||||
# Otherwise choose a strategy and add it
|
||||
rel_attr = getattr(Model, rel_name)
|
||||
strategy = _strategy_for_rel_attr(rel_attr)
|
||||
if not strategy:
|
||||
continue
|
||||
opt = strategy(rel_attr)
|
||||
|
||||
# Trim columns on the related entity if requested
|
||||
try:
|
||||
target_cls = rel_attr.property.mapper.class_
|
||||
except Exception:
|
||||
continue
|
||||
rel_only = _load_only_existing(target_cls, attrs)
|
||||
if rel_only:
|
||||
opt = opt.options(load_only(*rel_only))
|
||||
|
||||
stmt = stmt.options(opt)
|
||||
|
||||
return stmt
|
||||
|
|
@ -159,26 +236,38 @@ def list_items(model_name):
|
|||
|
||||
qkwargs: dict[str, Any] = {
|
||||
"text": text,
|
||||
# these are irrelevant for stmt-building; keep for ui_query compatibility
|
||||
"limit": 0 if unlimited else per_page,
|
||||
"offset": 0 if unlimited else (page - 1) * per_page if per_page else 0,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
}
|
||||
|
||||
# 1) Try per-model override first
|
||||
# compute requested relations once
|
||||
base_cols, rel_cols = split_fields(Model, fields)
|
||||
skip_rels = set(rel_cols.keys()) if fields else set()
|
||||
|
||||
# 1) per-model override first
|
||||
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
|
||||
|
||||
stmt: Select | None = None
|
||||
total: int
|
||||
|
||||
if rows_any is None:
|
||||
# 2) default: build a Select
|
||||
stmt = default_select(Model, text=text, sort=sort, direction=direction)
|
||||
stmt = default_select(Model, text=text, sort=sort, direction=direction, eager=False)
|
||||
|
||||
if not fields:
|
||||
stmt = apply_model_default_eager(stmt, Model, skip_rels=set())
|
||||
else:
|
||||
stmt = apply_field_loaders(stmt, Model, fields)
|
||||
|
||||
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
||||
|
||||
elif isinstance(rows_any, Select):
|
||||
stmt = rows_any
|
||||
# TRUST ui_query; don't add loaders on top
|
||||
stmt = ensure_order_by(rows_any, Model, sort=sort, direction=direction)
|
||||
|
||||
elif isinstance(rows_any, list):
|
||||
# Someone returned a materialized list. Paginate in Python.
|
||||
# materialized list; paginate in python
|
||||
total = len(rows_any)
|
||||
if unlimited:
|
||||
rows = rows_any
|
||||
|
|
@ -186,21 +275,15 @@ def list_items(model_name):
|
|||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
rows = rows_any[start:end]
|
||||
# serialize and return at the bottom like usual
|
||||
else:
|
||||
# SQLAlchemy Result-like?
|
||||
# SQLAlchemy Result-like or generic iterable
|
||||
scalars = getattr(rows_any, "scalars", None)
|
||||
if callable(scalars):
|
||||
# execute now, then paginate in Python
|
||||
all_rows = list(cast(ScalarResult[Any], scalars()))
|
||||
total = len(all_rows)
|
||||
if unlimited:
|
||||
rows = all_rows
|
||||
rows = all_rows if unlimited else all_rows[(page - 1) * per_page : (page * per_page)]
|
||||
else:
|
||||
start = (page - 1) * per_page
|
||||
end = start + per_page
|
||||
rows = all_rows[start:end]
|
||||
else:
|
||||
# single object or generic iterable
|
||||
try:
|
||||
all_rows = list(rows_any)
|
||||
total = len(all_rows)
|
||||
|
|
@ -209,34 +292,15 @@ def list_items(model_name):
|
|||
total = 1
|
||||
rows = [rows_any]
|
||||
|
||||
# If we have a real Select, use db.paginate for proper COUNT and slicing
|
||||
# If we have a real Select, run it once (unlimited) or paginate once.
|
||||
if stmt is not None:
|
||||
if unlimited:
|
||||
rows = list(db.session.execute(stmt).scalars())
|
||||
total = count_for(db.session, stmt)
|
||||
else:
|
||||
stmt = default_select(Model, text=text, sort=sort, direction=direction)
|
||||
if fields:
|
||||
stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier
|
||||
|
||||
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
||||
|
||||
if unlimited:
|
||||
rows = list(db.session.execute(stmt).scalars())
|
||||
total = count_for(db.session, stmt) # uses session, not stmt.bind
|
||||
else:
|
||||
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
|
||||
rows = pagination.items
|
||||
total = pagination.total
|
||||
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
||||
pagination = db.paginate(
|
||||
stmt,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
error_out=False
|
||||
)
|
||||
rows = pagination.items
|
||||
total = pagination.total
|
||||
|
||||
# Serialize
|
||||
if fields:
|
||||
|
|
@ -267,14 +331,6 @@ def list_items(model_name):
|
|||
if want_list:
|
||||
return render_template("fragments/_list_fragment.html", options=items)
|
||||
if want_table:
|
||||
# return render_template("fragments/_table_data_fragment.html",
|
||||
# rows=items,
|
||||
# model_name=model_name,
|
||||
# total=total,
|
||||
# page=page,
|
||||
# per_page=per_page,
|
||||
# pages=(0 if unlimited else ((total + per_page - 1) // per_page)),
|
||||
# )
|
||||
resp = make_response(render_template("fragments/_table_data_fragment.html",
|
||||
rows=items, model_name=model_name))
|
||||
resp.headers['X-Total'] = str(total)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import class_mapper, joinedload, selectinload
|
||||
from sqlalchemy.sql import Select
|
||||
from sqlalchemy.sql.sqltypes import String, Unicode, Text
|
||||
from typing import Any, Optional, cast, Iterable
|
||||
|
|
@ -76,6 +77,8 @@ def default_select(
|
|||
text: Optional[str] = None,
|
||||
sort: Optional[str] = None,
|
||||
direction: str = "asc",
|
||||
eager = False,
|
||||
skip_rels=frozenset()
|
||||
) -> Select[Any]:
|
||||
stmt: Select[Any] = select(Model)
|
||||
|
||||
|
|
@ -119,6 +122,15 @@ def default_select(
|
|||
for opt in opts:
|
||||
stmt = stmt.options(opt)
|
||||
|
||||
if eager:
|
||||
for prop in class_mapper(Model).relationships:
|
||||
if prop.key in skip_rels:
|
||||
continue
|
||||
lazy = getattr(prop, "lazy", None)
|
||||
if lazy in ("joined", "subquery"):
|
||||
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
|
||||
elif lazy == "selectin":
|
||||
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
|
||||
return stmt
|
||||
|
||||
def default_query(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue