Refactor default_query function to improve readability and add eager loading options

This commit is contained in:
Yaro Kasear 2025-08-19 08:11:05 -05:00
parent f3b781a360
commit 7c4958e0cb

View file

@ -1,8 +1,7 @@
from sqlalchemy import select, asc as sa_asc, desc as sa_desc from sqlalchemy import select, asc as sa_asc, desc as sa_desc
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy.sql import Select from sqlalchemy.sql import Select
from sqlalchemy.orm import Query from typing import Any, Optional, cast, Iterable
from typing import Any, Optional, cast
PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description") PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description")
@ -48,15 +47,10 @@ def default_query(
""" """
stmt: Select[Any] = select(Model) stmt: Select[Any] = select(Model)
# Optional per-model search hook
ui_search = getattr(Model, "ui_search", None) ui_search = getattr(Model, "ui_search", None)
if callable(ui_search) and text: if callable(ui_search) and text:
stmt = cast(Select[Any], ui_search(stmt, text)) stmt = cast(Select[Any], ui_search(stmt, text))
# Sorting priority:
# 1. explicit sort param
# 2. per-model ui_sort hook
# 3. per-model ui_order_cols default ordering
if sort: if sort:
ui_sort = getattr(Model, "ui_sort", None) ui_sort = getattr(Model, "ui_sort", None)
if callable(ui_sort): if callable(ui_sort):
@ -78,6 +72,17 @@ def default_query(
if limit > 0: if limit > 0:
stmt = stmt.limit(limit) stmt = stmt.limit(limit)
opts_attr = getattr(Model, "ui_eagerload", ())
opts: Iterable[Any]
if callable(opts_attr):
opts = cast(Iterable[Any], opts_attr()) # if you want, pass Model to it: opts_attr(Model)
else:
opts = cast(Iterable[Any], opts_attr)
for opt in opts:
stmt = stmt.options(opt)
return list(session.execute(stmt).scalars().all()) return list(session.execute(stmt).scalars().all())
def default_create(session, Model, payload): def default_create(session, Model, payload):