Fixed up Pylance issues WITHOUT destroying half the functionality we just did.

This commit is contained in:
Yaro Kasear 2025-08-15 15:00:07 -05:00
parent 4e15972275
commit 247b167377
3 changed files with 89 additions and 54 deletions

View file

@ -32,6 +32,7 @@ class User(db.Model, ImageAttachable):
image: Mapped[Optional['Image']] = relationship('Image', back_populates='user', passive_deletes=True) image: Mapped[Optional['Image']] = relationship('Image', back_populates='user', passive_deletes=True)
ui_eagerload = tuple() ui_eagerload = tuple()
ui_order_cols = ('first_name', 'last_name',)
@property @property
def identifier(self) -> str: def identifier(self) -> str:

View file

@ -1,5 +1,7 @@
from flask import Blueprint, request, render_template, jsonify, abort from flask import Blueprint, request, render_template, jsonify, abort
from sqlalchemy.engine import ScalarResult
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import Select
from typing import Any, Optional, List, cast, Type, Iterable from typing import Any, Optional, List, cast, Type, Iterable
from .defaults import ( from .defaults import (
@ -45,21 +47,49 @@ def call(Model: type, name: str, *args: Any, **kwargs: Any) -> Any:
def list_items(model_name): def list_items(model_name):
Model = get_model_class(model_name) Model = get_model_class(model_name)
text = (request.args.get("q") or "").strip() or None text = (request.args.get("q") or "").strip() or None
limit_param = request.args.get("limit") limit_param = request.args.get("limit")
limit: int | None = None if limit_param in (None, "", "0", "-1") else min(int(limit_param), 500) # 0 / -1 / blank => unlimited (pass 0)
if limit_param in (None, "", "0", "-1"):
effective_limit = 0
else:
effective_limit = min(int(limit_param), 500)
offset = int(request.args.get("offset", 0)) offset = int(request.args.get("offset", 0))
view = (request.args.get("view") or "json").strip() view = (request.args.get("view") or "json").strip()
# Build kwargs so we only include 'limit' when it's an int sort = (request.args.get("sort") or "").strip() or None
qkwargs: dict[str, Any] = {"text": text, "offset": offset} direction = (request.args.get("dir") or request.args.get("direction") or "asc").lower()
if limit is not None: if direction not in ("asc", "desc"):
qkwargs["limit"] = limit direction = "asc"
rows_iter: Iterable[Any] = ( qkwargs: dict[str, Any] = {
call(Model, "ui_query", db.session, **qkwargs) "text": text,
or default_query(db.session, Model, **qkwargs) "limit": effective_limit,
) "offset": offset,
rows = list(rows_iter) "sort": sort,
"direction": direction,
}
# Prefer per-model override. Contract: return list[Model] OR a Select (SA 2.x).
rows_any: Any = call(Model, "ui_query", db.session, **qkwargs)
if rows_any is None:
rows = default_query(db.session, Model, **qkwargs)
elif isinstance(rows_any, list):
rows = rows_any
elif isinstance(rows_any, Select):
rows = list(cast(ScalarResult[Any], db.session.execute(rows_any).scalars()))
else:
# If someone returns a Result or other iterable of models
try:
# Try SQLAlchemy Result-like
scalars = getattr(rows_any, "scalars", None)
if callable(scalars):
rows = list(cast(ScalarResult[Any], scalars()))
else:
rows = list(rows_any)
except TypeError:
rows = [rows_any]
items = [ items = [
(call(Model, "ui_serialize", r, view=view) or default_serialize(Model, r, view=view)) (call(Model, "ui_serialize", r, view=view) or default_serialize(Model, r, view=view))

View file

@ -1,5 +1,8 @@
from sqlalchemy import select, or_ from sqlalchemy import select, asc as sa_asc, desc as sa_desc
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.sql import Select
from sqlalchemy.orm import Query
from typing import Any, Optional, cast
PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description") PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description")
@ -25,46 +28,57 @@ def infer_label_attr(Model):
return a return a
raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})") raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})")
def default_query(session, Model, *, text=None, limit=100, offset=0, filters=None, order=None): def default_query(
label_name = infer_label_attr(Model) session,
label_col = _mapped_column(Model, label_name) # guaranteed not None now Model,
*,
text: Optional[str] = None,
limit: int = 0,
offset: int = 0,
sort: Optional[str] = None,
direction: str = "asc",
) -> list[Any]:
"""
SA 2.x ONLY. Returns list[Model].
stmt = select(Model) Hooks:
- ui_search(stmt: Select, text: str) -> Select
- ui_sort(stmt: Select, sort: str, direction: str) -> Select
- ui_order_cols: tuple[str, ...] # default ordering columns
"""
stmt: Select[Any] = select(Model)
# Eager loads if class defines them (expects loader options like selectinload(...)) # Optional per-model search hook
for opt in getattr(Model, "ui_eagerload", ()) or (): ui_search = getattr(Model, "ui_search", None)
stmt = stmt.options(opt) if callable(ui_search) and text:
stmt = cast(Select[Any], ui_search(stmt, text))
# Text search across mapped columns only # Sorting priority:
if text: # 1. explicit sort param
cols = getattr(Model, "ui_search_cols", None) or (label_name,) # 2. per-model ui_sort hook
mapped = [ _mapped_column(Model, c) for c in cols ] # 3. per-model ui_order_cols default ordering
mapped = [ c for c in mapped if c is not None ] if sort:
if mapped: ui_sort = getattr(Model, "ui_sort", None)
stmt = stmt.where(or_(*[ c.ilike(f"%{text}%") for c in mapped ])) if callable(ui_sort):
stmt = cast(Select[Any], ui_sort(stmt, sort, direction))
# Filters (exact-match) across mapped columns only else:
if filters: col = getattr(Model, sort, None)
for k, v in filters.items():
if v is None:
continue
col = _mapped_column(Model, k)
if col is not None: if col is not None:
stmt = stmt.where(col == v) stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col))
else:
# Order by mapped columns (fallback to label) order_cols = getattr(Model, "ui_order_cols", ())
order_cols = order or getattr(Model, "ui_order_cols", None) or (label_name,) if order_cols:
for c in order_cols: for colname in order_cols:
col = _mapped_column(Model, c) col = getattr(Model, colname, None)
if col is not None: if col is not None:
stmt = stmt.order_by(col) stmt = stmt.order_by(sa_asc(col))
# stmt = stmt.limit(limit).offset(offset)
if limit is not None:
stmt = stmt.limit(limit)
if offset: if offset:
stmt = stmt.offset(offset) stmt = stmt.offset(offset)
return session.execute(stmt).scalars().all() if limit > 0:
stmt = stmt.limit(limit)
return list(session.execute(stmt).scalars().all())
def default_create(session, Model, payload): def default_create(session, Model, payload):
label = infer_label_attr(Model) label = infer_label_attr(Model)
@ -73,16 +87,6 @@ def default_create(session, Model, payload):
session.commit() session.commit()
return obj return obj
# def default_update(session, Model, id_, payload):
# obj = session.get(Model, id_)
# if not obj:
# return None
# label = infer_label_attr(Model)
# if (nv := payload.get(label) or payload.get("name")):
# setattr(obj, label, nv)
# session.commit()
# return obj
def default_update(session, Model, id_, payload): def default_update(session, Model, id_, payload):
obj = session.get(Model, id_) obj = session.get(Model, id_)
if not obj: if not obj: