Refactor user route: remove eager loading from list_users function and enhance default_select to support eager loading with skip relations
This commit is contained in:
parent
91e1e5051a
commit
30ec29d497
3 changed files with 124 additions and 58 deletions
|
|
@ -12,8 +12,6 @@ from ..models import User, Room, Inventory, WorkLog
|
||||||
|
|
||||||
@main.route("/users")
|
@main.route("/users")
|
||||||
def list_users():
|
def list_users():
|
||||||
query = eager_load_user_relationships(db.session.query(User)).order_by(User.last_name, User.first_name)
|
|
||||||
users = query.all()
|
|
||||||
return render_template(
|
return render_template(
|
||||||
'table.html',
|
'table.html',
|
||||||
header = user_headers,
|
header = user_headers,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||||
from flask import Blueprint, request, render_template, jsonify, abort, make_response
|
from flask import Blueprint, request, render_template, jsonify, abort, make_response
|
||||||
from sqlalchemy.engine import ScalarResult
|
from sqlalchemy.engine import ScalarResult
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import class_mapper, load_only, selectinload
|
from sqlalchemy.orm import class_mapper, load_only, selectinload, joinedload, Load
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
from typing import Any, List, cast, Iterable, Tuple, Set, Dict
|
from typing import Any, List, cast, Iterable, Tuple, Set, Dict
|
||||||
|
|
||||||
|
|
@ -15,6 +15,66 @@ from .. import db
|
||||||
|
|
||||||
bp = Blueprint("ui", __name__, url_prefix="/ui")
|
bp = Blueprint("ui", __name__, url_prefix="/ui")
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Load
|
||||||
|
|
||||||
|
def _option_targets_rel(opt: Load, Model, rel_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Return True if this Load option targets Model.rel_name at its root path.
|
||||||
|
Works for joinedload/selectinload/subqueryload options.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# opt.path is a PathRegistry; .path is a tuple of (mapper, prop, mapper, prop, ...)
|
||||||
|
path = tuple(getattr(opt, "path", ()).path) # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
if not path:
|
||||||
|
return False
|
||||||
|
# We only care about the first hop: (Mapper[Model], RelationshipProperty(rel_name))
|
||||||
|
if len(path) < 2:
|
||||||
|
return False
|
||||||
|
first_mapper, first_prop = path[0], path[1]
|
||||||
|
try:
|
||||||
|
is_model = first_mapper.class_ is Model # type: ignore[attr-defined]
|
||||||
|
is_rel = getattr(first_prop, "key", "") == rel_name
|
||||||
|
return bool(is_model and is_rel)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _has_loader_for(stmt: Select, Model, rel_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
True if stmt already has any loader option configured for Model.rel_name.
|
||||||
|
"""
|
||||||
|
opts = getattr(stmt, "_with_options", ()) # SQLAlchemy stores Load options here
|
||||||
|
for opt in opts:
|
||||||
|
if isinstance(opt, Load) and _option_targets_rel(opt, Model, rel_name):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _strategy_for_rel_attr(rel_attr) -> type[Load] | None:
|
||||||
|
# rel_attr is an InstrumentedAttribute (Model.foo)
|
||||||
|
prop = getattr(rel_attr, "property", None)
|
||||||
|
lazy = getattr(prop, "lazy", None)
|
||||||
|
if lazy in ("joined", "subquery"):
|
||||||
|
return joinedload
|
||||||
|
if lazy == "selectin":
|
||||||
|
return selectinload
|
||||||
|
# default if mapper left it None or something exotic like 'raise'
|
||||||
|
return selectinload
|
||||||
|
|
||||||
|
def apply_model_default_eager(stmt: Select, Model, skip_rels: Set[str]) -> Select:
|
||||||
|
# mapper.relationships yields RelationshipProperty objects
|
||||||
|
mapper = class_mapper(Model)
|
||||||
|
for prop in mapper.relationships:
|
||||||
|
if prop.key in skip_rels:
|
||||||
|
continue
|
||||||
|
lazy = getattr(prop, "lazy", None)
|
||||||
|
if lazy in ("joined", "subquery"):
|
||||||
|
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
|
||||||
|
elif lazy == "selectin":
|
||||||
|
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
|
||||||
|
# else: leave it alone (noload/raise/dynamic/etc.)
|
||||||
|
return stmt
|
||||||
|
|
||||||
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
|
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
|
||||||
"""
|
"""
|
||||||
Split requested fields into base model columns and relation->attr sets.
|
Split requested fields into base model columns and relation->attr sets.
|
||||||
|
|
@ -53,33 +113,50 @@ def _load_only_existing(Model, names: Set[str]):
|
||||||
return cols
|
return cols
|
||||||
|
|
||||||
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
|
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
|
||||||
"""
|
|
||||||
Given a base Select[Model] and requested fields, attach loader options:
|
|
||||||
- load_only(...) for base scalar columns
|
|
||||||
- selectinload(Model.rel).options(load_only(...)) for each requested relation
|
|
||||||
Only single-depth "rel.attr" is supported, which is exactly what you’re using.
|
|
||||||
"""
|
|
||||||
base_cols, rel_cols = split_fields(Model, fields)
|
base_cols, rel_cols = split_fields(Model, fields)
|
||||||
|
|
||||||
# Restrict base columns if any were explicitly requested
|
|
||||||
base_only = _load_only_existing(Model, base_cols)
|
base_only = _load_only_existing(Model, base_cols)
|
||||||
if base_only:
|
if base_only:
|
||||||
stmt = stmt.options(load_only(*base_only))
|
stmt = stmt.options(load_only(*base_only))
|
||||||
|
|
||||||
# Relations: selectinload each requested relation and trim its columns
|
|
||||||
for rel_name, attrs in rel_cols.items():
|
for rel_name, attrs in rel_cols.items():
|
||||||
if not hasattr(Model, rel_name):
|
if not hasattr(Model, rel_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# If someone already attached a loader for this relation, don't add another
|
||||||
|
if _has_loader_for(stmt, Model, rel_name):
|
||||||
|
# still allow trimming columns on the related entity if we can
|
||||||
|
rel_attr = getattr(Model, rel_name)
|
||||||
|
try:
|
||||||
|
target_cls = rel_attr.property.mapper.class_
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
rel_only = _load_only_existing(target_cls, attrs)
|
||||||
|
if rel_only:
|
||||||
|
# attach a Load that only applies load_only to that path,
|
||||||
|
# without picking a different strategy
|
||||||
|
# This relies on SQLA merging load_only onto existing Load for the same path.
|
||||||
|
stmt = stmt.options(
|
||||||
|
getattr(Load(Model), rel_name).load_only(*rel_only)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Otherwise choose a strategy and add it
|
||||||
rel_attr = getattr(Model, rel_name)
|
rel_attr = getattr(Model, rel_name)
|
||||||
|
strategy = _strategy_for_rel_attr(rel_attr)
|
||||||
|
if not strategy:
|
||||||
|
continue
|
||||||
|
opt = strategy(rel_attr)
|
||||||
|
|
||||||
|
# Trim columns on the related entity if requested
|
||||||
try:
|
try:
|
||||||
target_cls = rel_attr.property.mapper.class_
|
target_cls = rel_attr.property.mapper.class_
|
||||||
except Exception:
|
except Exception:
|
||||||
continue # not a relationship; ignore
|
continue
|
||||||
|
|
||||||
opt = selectinload(rel_attr)
|
|
||||||
rel_only = _load_only_existing(target_cls, attrs)
|
rel_only = _load_only_existing(target_cls, attrs)
|
||||||
if rel_only:
|
if rel_only:
|
||||||
opt = opt.options(load_only(*rel_only))
|
opt = opt.options(load_only(*rel_only))
|
||||||
|
|
||||||
stmt = stmt.options(opt)
|
stmt = stmt.options(opt)
|
||||||
|
|
||||||
return stmt
|
return stmt
|
||||||
|
|
@ -159,26 +236,38 @@ def list_items(model_name):
|
||||||
|
|
||||||
qkwargs: dict[str, Any] = {
|
qkwargs: dict[str, Any] = {
|
||||||
"text": text,
|
"text": text,
|
||||||
# these are irrelevant for stmt-building; keep for ui_query compatibility
|
|
||||||
"limit": 0 if unlimited else per_page,
|
"limit": 0 if unlimited else per_page,
|
||||||
"offset": 0 if unlimited else (page - 1) * per_page if per_page else 0,
|
"offset": 0 if unlimited else (page - 1) * per_page if per_page else 0,
|
||||||
"sort": sort,
|
"sort": sort,
|
||||||
"direction": direction,
|
"direction": direction,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1) Try per-model override first
|
# compute requested relations once
|
||||||
|
base_cols, rel_cols = split_fields(Model, fields)
|
||||||
|
skip_rels = set(rel_cols.keys()) if fields else set()
|
||||||
|
|
||||||
|
# 1) per-model override first
|
||||||
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
|
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
|
||||||
|
|
||||||
stmt: Select | None = None
|
stmt: Select | None = None
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
if rows_any is None:
|
if rows_any is None:
|
||||||
# 2) default: build a Select
|
stmt = default_select(Model, text=text, sort=sort, direction=direction, eager=False)
|
||||||
stmt = default_select(Model, text=text, sort=sort, direction=direction)
|
|
||||||
|
if not fields:
|
||||||
|
stmt = apply_model_default_eager(stmt, Model, skip_rels=set())
|
||||||
|
else:
|
||||||
|
stmt = apply_field_loaders(stmt, Model, fields)
|
||||||
|
|
||||||
|
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
||||||
|
|
||||||
elif isinstance(rows_any, Select):
|
elif isinstance(rows_any, Select):
|
||||||
stmt = rows_any
|
# TRUST ui_query; don't add loaders on top
|
||||||
|
stmt = ensure_order_by(rows_any, Model, sort=sort, direction=direction)
|
||||||
|
|
||||||
elif isinstance(rows_any, list):
|
elif isinstance(rows_any, list):
|
||||||
# Someone returned a materialized list. Paginate in Python.
|
# materialized list; paginate in python
|
||||||
total = len(rows_any)
|
total = len(rows_any)
|
||||||
if unlimited:
|
if unlimited:
|
||||||
rows = rows_any
|
rows = rows_any
|
||||||
|
|
@ -186,21 +275,15 @@ def list_items(model_name):
|
||||||
start = (page - 1) * per_page
|
start = (page - 1) * per_page
|
||||||
end = start + per_page
|
end = start + per_page
|
||||||
rows = rows_any[start:end]
|
rows = rows_any[start:end]
|
||||||
|
# serialize and return at the bottom like usual
|
||||||
else:
|
else:
|
||||||
# SQLAlchemy Result-like?
|
# SQLAlchemy Result-like or generic iterable
|
||||||
scalars = getattr(rows_any, "scalars", None)
|
scalars = getattr(rows_any, "scalars", None)
|
||||||
if callable(scalars):
|
if callable(scalars):
|
||||||
# execute now, then paginate in Python
|
|
||||||
all_rows = list(cast(ScalarResult[Any], scalars()))
|
all_rows = list(cast(ScalarResult[Any], scalars()))
|
||||||
total = len(all_rows)
|
total = len(all_rows)
|
||||||
if unlimited:
|
rows = all_rows if unlimited else all_rows[(page - 1) * per_page : (page * per_page)]
|
||||||
rows = all_rows
|
|
||||||
else:
|
|
||||||
start = (page - 1) * per_page
|
|
||||||
end = start + per_page
|
|
||||||
rows = all_rows[start:end]
|
|
||||||
else:
|
else:
|
||||||
# single object or generic iterable
|
|
||||||
try:
|
try:
|
||||||
all_rows = list(rows_any)
|
all_rows = list(rows_any)
|
||||||
total = len(all_rows)
|
total = len(all_rows)
|
||||||
|
|
@ -209,32 +292,13 @@ def list_items(model_name):
|
||||||
total = 1
|
total = 1
|
||||||
rows = [rows_any]
|
rows = [rows_any]
|
||||||
|
|
||||||
# If we have a real Select, use db.paginate for proper COUNT and slicing
|
# If we have a real Select, run it once (unlimited) or paginate once.
|
||||||
if stmt is not None:
|
if stmt is not None:
|
||||||
if unlimited:
|
if unlimited:
|
||||||
rows = list(db.session.execute(stmt).scalars())
|
rows = list(db.session.execute(stmt).scalars())
|
||||||
total = count_for(db.session, stmt)
|
total = count_for(db.session, stmt)
|
||||||
else:
|
else:
|
||||||
stmt = default_select(Model, text=text, sort=sort, direction=direction)
|
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
|
||||||
if fields:
|
|
||||||
stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier
|
|
||||||
|
|
||||||
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
|
||||||
|
|
||||||
if unlimited:
|
|
||||||
rows = list(db.session.execute(stmt).scalars())
|
|
||||||
total = count_for(db.session, stmt) # uses session, not stmt.bind
|
|
||||||
else:
|
|
||||||
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
|
|
||||||
rows = pagination.items
|
|
||||||
total = pagination.total
|
|
||||||
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
|
|
||||||
pagination = db.paginate(
|
|
||||||
stmt,
|
|
||||||
page=page,
|
|
||||||
per_page=per_page,
|
|
||||||
error_out=False
|
|
||||||
)
|
|
||||||
rows = pagination.items
|
rows = pagination.items
|
||||||
total = pagination.total
|
total = pagination.total
|
||||||
|
|
||||||
|
|
@ -267,14 +331,6 @@ def list_items(model_name):
|
||||||
if want_list:
|
if want_list:
|
||||||
return render_template("fragments/_list_fragment.html", options=items)
|
return render_template("fragments/_list_fragment.html", options=items)
|
||||||
if want_table:
|
if want_table:
|
||||||
# return render_template("fragments/_table_data_fragment.html",
|
|
||||||
# rows=items,
|
|
||||||
# model_name=model_name,
|
|
||||||
# total=total,
|
|
||||||
# page=page,
|
|
||||||
# per_page=per_page,
|
|
||||||
# pages=(0 if unlimited else ((total + per_page - 1) // per_page)),
|
|
||||||
# )
|
|
||||||
resp = make_response(render_template("fragments/_table_data_fragment.html",
|
resp = make_response(render_template("fragments/_table_data_fragment.html",
|
||||||
rows=items, model_name=model_name))
|
rows=items, model_name=model_name))
|
||||||
resp.headers['X-Total'] = str(total)
|
resp.headers['X-Total'] = str(total)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
|
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
|
from sqlalchemy.orm import class_mapper, joinedload, selectinload
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
from sqlalchemy.sql.sqltypes import String, Unicode, Text
|
from sqlalchemy.sql.sqltypes import String, Unicode, Text
|
||||||
from typing import Any, Optional, cast, Iterable
|
from typing import Any, Optional, cast, Iterable
|
||||||
|
|
@ -76,6 +77,8 @@ def default_select(
|
||||||
text: Optional[str] = None,
|
text: Optional[str] = None,
|
||||||
sort: Optional[str] = None,
|
sort: Optional[str] = None,
|
||||||
direction: str = "asc",
|
direction: str = "asc",
|
||||||
|
eager = False,
|
||||||
|
skip_rels=frozenset()
|
||||||
) -> Select[Any]:
|
) -> Select[Any]:
|
||||||
stmt: Select[Any] = select(Model)
|
stmt: Select[Any] = select(Model)
|
||||||
|
|
||||||
|
|
@ -119,6 +122,15 @@ def default_select(
|
||||||
for opt in opts:
|
for opt in opts:
|
||||||
stmt = stmt.options(opt)
|
stmt = stmt.options(opt)
|
||||||
|
|
||||||
|
if eager:
|
||||||
|
for prop in class_mapper(Model).relationships:
|
||||||
|
if prop.key in skip_rels:
|
||||||
|
continue
|
||||||
|
lazy = getattr(prop, "lazy", None)
|
||||||
|
if lazy in ("joined", "subquery"):
|
||||||
|
stmt = stmt.options(joinedload(getattr(Model, prop.key)))
|
||||||
|
elif lazy == "selectin":
|
||||||
|
stmt = stmt.options(selectinload(getattr(Model, prop.key)))
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
def default_query(
|
def default_query(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue