Update refresh URL handling in combobox and dropdown fragments to include 'per_page' parameter; modify count_for function to accept session for improved efficiency.

This commit is contained in:
Yaro Kasear 2025-08-22 15:00:45 -05:00
parent f3f4493698
commit 91e1e5051a
4 changed files with 109 additions and 7 deletions

View file

@ -40,7 +40,7 @@ create_url = none, edit_url = none, delete_url = none, refresh_url = none
</select>
{% if refresh_url %}
{% set url = refresh_url ~ ('&' if '?' in refresh_url else '?') ~ 'view=option&limit=0' %}
{% set url = refresh_url ~ ('&' if '?' in refresh_url else '?') ~ 'view=option&limit=0&per_page=0' %}
<div id="{{ id }}-htmx-refresh" class="d-none" hx-get="{{ url }}"
hx-trigger="revealed, combobox:refresh from:#{{ id }}-container" hx-target="#{{ id }}-list" hx-swap="innerHTML">
</div>

View file

@ -49,7 +49,7 @@
</ul>
</div>
{% if refresh_url %}
{% set url = refresh_url ~ ('&' if '?' in refresh_url else '?') ~ 'view=list&limit=0' %}
{% set url = refresh_url ~ ('&' if '?' in refresh_url else '?') ~ 'view=list&limit=0&per_page=0' %}
<div id="{{ id }}-htmx-refresh" class="d-none" hx-get="{{ url }}"
hx-trigger="revealed, combobox:refresh from:#{{ id }}-dropdown" hx-target="#{{ id }}DropdownContent" hx-swap="innerHTML"></div>
{% endif %}

View file

@ -1,8 +1,11 @@
from collections import defaultdict
from flask import Blueprint, request, render_template, jsonify, abort, make_response
from sqlalchemy.engine import ScalarResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import class_mapper, load_only, selectinload
from sqlalchemy.sql import Select
from typing import Any, List, cast
from typing import Any, List, cast, Iterable, Tuple, Set, Dict
from .defaults import (
default_query, default_create, default_update, default_delete, default_serialize, default_values, default_value, default_select, ensure_order_by, count_for
@ -12,6 +15,75 @@ from .. import db
bp = Blueprint("ui", __name__, url_prefix="/ui")
def split_fields(Model, fields: Iterable[str]) -> Tuple[Set[str], Dict[str, Set[str]]]:
"""
Split requested fields into base model columns and relation->attr sets.
Example: ["name", "brand.name", "owner.identifier"] =>
base_cols = {"name"}
rel_cols = {"brand": {"name"}, "owner": {"identifier"}}
"""
base_cols: Set[str] = set()
rel_cols: Dict[str, Set[str]] = defaultdict(set)
for f in fields:
f = f.strip()
if not f:
continue
if "." in f:
rel, attr = f.split(".", 1)
rel_cols[rel].add(attr)
else:
base_cols.add(f)
return base_cols, rel_cols
def _load_only_existing(Model, names: Set[str]):
"""
Return a list of mapped column attributes present on Model for load_only(...).
Skips relationships and unmapped/hybrid attributes so SQLA doesnt scream.
"""
cols = []
mapper = class_mapper(Model)
mapped_attr_names = set(mapper.attrs.keys())
for n in names:
if n in mapped_attr_names:
attr = getattr(Model, n)
prop = getattr(attr, "property", None)
if prop is not None and hasattr(prop, "columns"):
cols.append(attr)
return cols
def apply_field_loaders(stmt: Select, Model, fields: Iterable[str]) -> Select:
"""
Given a base Select[Model] and requested fields, attach loader options:
- load_only(...) for base scalar columns
- selectinload(Model.rel).options(load_only(...)) for each requested relation
Only single-depth "rel.attr" is supported, which is exactly what youre using.
"""
base_cols, rel_cols = split_fields(Model, fields)
# Restrict base columns if any were explicitly requested
base_only = _load_only_existing(Model, base_cols)
if base_only:
stmt = stmt.options(load_only(*base_only))
# Relations: selectinload each requested relation and trim its columns
for rel_name, attrs in rel_cols.items():
if not hasattr(Model, rel_name):
continue
rel_attr = getattr(Model, rel_name)
try:
target_cls = rel_attr.property.mapper.class_
except Exception:
continue # not a relationship; ignore
opt = selectinload(rel_attr)
rel_only = _load_only_existing(target_cls, attrs)
if rel_only:
opt = opt.options(load_only(*rel_only))
stmt = stmt.options(opt)
return stmt
def _normalize(s: str) -> str:
return s.replace("_", "").replace("-", "").lower()
@ -144,6 +216,18 @@ def list_items(model_name):
total = count_for(db.session, stmt)
else:
stmt = default_select(Model, text=text, sort=sort, direction=direction)
if fields:
stmt = apply_field_loaders(stmt, Model, fields) # the helper I gave you earlier
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
if unlimited:
rows = list(db.session.execute(stmt).scalars())
total = count_for(db.session, stmt) # uses session, not stmt.bind
else:
pagination = db.paginate(stmt, page=page, per_page=per_page, error_out=False)
rows = pagination.items
total = pagination.total
stmt = ensure_order_by(stmt, Model, sort=sort, direction=direction)
pagination = db.paginate(
stmt,

View file

@ -37,9 +37,11 @@ def infer_label_attr(Model):
return a
raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})")
def count_for(stmt: Select) -> int:
def count_for(session, stmt: Select) -> int:
# strip ORDER BY for efficiency
subq = stmt.order_by(None).subquery()
return stmt.bind.execute(select(func.count()).select_from(subq)).scalar_one()
count_stmt = select(func.count()).select_from(subq)
return session.execute(count_stmt).scalar_one()
def ensure_order_by(stmt, Model, sort=None, direction="asc"):
try:
@ -73,14 +75,22 @@ def default_select(
*,
text: Optional[str] = None,
sort: Optional[str] = None,
direction: str = "asc"
direction: str = "asc",
) -> Select[Any]:
stmt: Select[Any] = select(Model)
# search
ui_search = getattr(Model, "ui_search", None)
if callable(ui_search) and text:
stmt = cast(Select[Any], ui_search(stmt, text))
elif text:
# optional generic search fallback if you used this in default_query
t = f"%{text}%"
text_cols = _columns_for_text_search(Model) # your existing helper
if text_cols:
stmt = stmt.where(or_(*(col.ilike(t) for col in text_cols)))
# sorting
if sort:
ui_sort = getattr(Model, "ui_sort", None)
if callable(ui_sort):
@ -89,7 +99,6 @@ def default_select(
col = getattr(Model, sort, None)
if col is not None:
stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col))
else:
ui_order_cols = getattr(Model, "ui_order_cols", ())
if ui_order_cols:
@ -101,6 +110,15 @@ def default_select(
if order_cols:
stmt = stmt.order_by(*order_cols)
# eagerload defaults
opts_attr = getattr(Model, "ui_eagerload", ())
if callable(opts_attr):
opts = cast(Iterable[Any], opts_attr()) # if you prefer, pass Model in
else:
opts = cast(Iterable[Any], opts_attr)
for opt in opts:
stmt = stmt.options(opt)
return stmt
def default_query(