326 lines
10 KiB
Python
326 lines
10 KiB
Python
from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
|
|
from sqlalchemy.inspection import inspect
|
|
from sqlalchemy.sql import Select
|
|
from sqlalchemy.sql.sqltypes import String, Unicode, Text
|
|
from typing import Any, Optional, cast, Iterable
|
|
|
|
# Attribute names tried, in priority order, as a model's display label when the
# model does not name one explicitly through ``ui_label_attr``.
PREFERRED_LABELS = (
    "identifier",
    "name",
    "first_name",
    "last_name",
    "description",
)
|
|
|
|
def _columns_for_text_search(Model):
    """Return the class attributes of *Model*'s string-typed mapped columns.

    These are the fallback ILIKE targets used by ``default_query`` when the
    model has no ``ui_search`` hook. A column qualifies when its SQL type is an
    instance of ``String``, ``Unicode`` or ``Text``.

    NOTE(review): SQLAlchemy's ``Enum`` type subclasses ``String`` and is
    therefore included; ILIKE against a native PostgreSQL enum column may fail
    — confirm whether enum columns should be excluded.
    """
    mapper = inspect(Model)
    # ``mapper.columns`` is keyed by the mapped *attribute* name, which can
    # differ from ``Column.key`` (e.g. ``name = Column("user_name", String)``
    # maps attribute ``name`` to a column whose key is ``user_name``). Look the
    # attribute up by the collection key, not by the column's own key.
    return [
        getattr(Model, attr_name)
        for attr_name, column in mapper.columns.items()
        if isinstance(column.type, (String, Unicode, Text))
    ]
|
|
|
|
def _mapped_column(Model, attr):
    """Return ``getattr(Model, attr)`` if *attr* is a mapped column attribute, else None.

    *attr* is checked first against the mapper's column keys, then against the
    keys of the mapper's column properties.
    """
    mapper = inspect(Model)
    if attr in mapper.columns.keys():
        return getattr(Model, attr)
    matching_key = next(
        (prop.key for prop in mapper.column_attrs if prop.key == attr),
        None,
    )
    if matching_key is None:
        return None
    return getattr(Model, matching_key)
|
|
|
|
def infer_label_attr(Model):
    """Return the name of the mapped column used as *Model*'s display label.

    ``Model.ui_label_attr`` wins when set, but it must name a mapped column.
    Otherwise the first entry of ``PREFERRED_LABELS`` that is a mapped column
    is used.

    Raises:
        RuntimeError: ``ui_label_attr`` is not a mapped column, or no preferred
            label column exists on *Model*.
    """
    explicit = getattr(Model, 'ui_label_attr', None)
    if explicit:
        if _mapped_column(Model, explicit) is None:
            raise RuntimeError(f"ui_label_attr '{explicit}' on {Model.__name__} is not a mapped column")
        return explicit

    fallback = next(
        (candidate for candidate in PREFERRED_LABELS if _mapped_column(Model, candidate) is not None),
        None,
    )
    if fallback is None:
        raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})")
    return fallback
|
|
|
|
def count_for(stmt: Select, session=None) -> int:
    """Return the number of rows *stmt* would produce.

    The statement's ORDER BY is dropped (irrelevant for counting) and it is
    wrapped as a subquery for ``SELECT count(*) FROM (...)``.

    Args:
        stmt: the SELECT to count.
        session: Session or Connection used to execute the count. SQLAlchemy
            2.x removed bound metadata, so a ``Select`` no longer carries a
            ``.bind`` to execute itself on; pass the session explicitly. When
            omitted, a legacy ``stmt.bind`` is used if one exists.

    Raises:
        ValueError: no *session* was given and *stmt* has no bind.
    """
    count_stmt = select(func.count()).select_from(stmt.order_by(None).subquery())
    executor = session if session is not None else getattr(stmt, "bind", None)
    if executor is None:
        raise ValueError("count_for() needs a session: SQLAlchemy 2.x statements have no .bind")
    return executor.execute(count_stmt).scalar_one()
|
|
|
|
def ensure_order_by(stmt, Model, sort=None, direction="asc"):
    """Return *stmt* unchanged if it already has an ORDER BY, else add one.

    The ordering is chosen from the first source that yields columns:
      1. ``sort`` — an attribute of *Model*, descending when *direction* is
         ``"desc"``, ascending otherwise;
      2. ``Model.ui_order_cols`` — attribute names, each ascending (names
         missing on the model are skipped);
      3. the mapper's primary-key columns, ascending.
    """
    # Inspecting a private attribute; treat any failure as "no ordering yet".
    try:
        already_ordered = bool(getattr(stmt, '_order_by_clauses', None))
    except Exception:
        already_ordered = False
    if already_ordered:
        return stmt

    def _directed(column):
        return column.desc() if direction == "desc" else column.asc()

    if sort and hasattr(Model, sort):
        ordering = [_directed(getattr(Model, sort))]
    else:
        configured = getattr(Model, 'ui_order_cols', ()) or ()
        candidates = (getattr(Model, name, None) for name in configured)
        ordering = [candidate.asc() for candidate in candidates if candidate is not None]

    if not ordering:
        ordering = [pk_column.asc() for pk_column in inspect(Model).primary_key]

    return stmt.order_by(*ordering)
|
|
|
|
def default_select(
    Model,
    *,
    text: Optional[str] = None,
    sort: Optional[str] = None,
    direction: str = "asc"
) -> Select[Any]:
    """Build (without executing) a ``select(Model)`` with UI hooks applied.

    - If *text* is given and ``Model.ui_search`` is callable, the statement is
      passed through ``ui_search(stmt, text)``. No fallback search is applied.
    - If *sort* is given, ``Model.ui_sort(stmt, sort, direction)`` is used when
      callable; otherwise ``Model.<sort>`` is ordered descending for
      ``direction == "desc"`` and ascending otherwise (unknown names are ignored).
    - Without *sort*, the attributes listed in ``Model.ui_order_cols`` that
      exist on the model are applied in ascending order.
    """
    stmt: Select[Any] = select(Model)

    search_hook = getattr(Model, "ui_search", None)
    if callable(search_hook) and text:
        stmt = cast(Select[Any], search_hook(stmt, text))

    if not sort:
        configured = getattr(Model, "ui_order_cols", ()) or ()
        defaults = [
            sa_asc(attribute)
            for attribute in (getattr(Model, name, None) for name in configured)
            if attribute is not None
        ]
        if defaults:
            stmt = stmt.order_by(*defaults)
        return stmt

    sort_hook = getattr(Model, "ui_sort", None)
    if callable(sort_hook):
        return cast(Select[Any], sort_hook(stmt, sort, direction))

    target = getattr(Model, sort, None)
    if target is None:
        return stmt
    ordering = sa_desc(target) if direction == "desc" else sa_asc(target)
    return stmt.order_by(ordering)
|
|
|
|
def default_query(
    session,
    Model,
    *,
    text: Optional[str] = None,
    limit: int = 0,
    offset: int = 0,
    sort: Optional[str] = None,
    direction: str = "asc",
) -> list[Any]:
    """
    SA 2.x ONLY. Returns list[Model].

    Hooks:
    - ui_search(stmt: Select, text: str) -> Select
    - ui_sort(stmt: Select, sort: str, direction: str) -> Select
    - ui_order_cols: tuple[str, ...]  # default ordering columns
    - ui_eagerload: iterable of loader options, or a zero-argument callable
      returning one

    Without a ``ui_search`` hook, *text* is matched as a literal,
    case-insensitive substring (ILIKE) against every string-typed mapped
    column. ``%``, ``_`` and ``/`` in *text* are escaped so user input is never
    interpreted as LIKE wildcards.

    When paginating (``limit > 0`` or a non-zero *offset*) and nothing has
    ordered the statement, ``ensure_order_by`` adds a fallback ordering so that
    pages are stable (LIMIT/OFFSET without ORDER BY has no defined row order).
    """
    stmt: Select[Any] = select(Model)

    ui_search = getattr(Model, "ui_search", None)
    if callable(ui_search) and text:
        stmt = cast(Select[Any], ui_search(stmt, text))
    elif text:
        text_cols = _columns_for_text_search(Model)
        if text_cols:
            # Escape the escape character first, then the LIKE metacharacters.
            escaped = text.replace("/", "//").replace("%", "/%").replace("_", "/_")
            pattern = f"%{escaped}%"
            stmt = stmt.where(or_(*(col.ilike(pattern, escape="/") for col in text_cols)))

    if sort:
        ui_sort = getattr(Model, "ui_sort", None)
        if callable(ui_sort):
            stmt = cast(Select[Any], ui_sort(stmt, sort, direction))
        else:
            col = getattr(Model, sort, None)
            if col is not None:
                stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col))
    else:
        default_order = []
        for colname in getattr(Model, "ui_order_cols", ()) or ():
            col = getattr(Model, colname, None)
            if col is not None:
                default_order.append(sa_asc(col))
        if default_order:
            stmt = stmt.order_by(*default_order)

    if limit > 0 or offset:
        # No-op when the statement is already ordered.
        stmt = ensure_order_by(stmt, Model)
    if offset:
        stmt = stmt.offset(offset)
    if limit > 0:
        stmt = stmt.limit(limit)

    opts_attr = getattr(Model, "ui_eagerload", ())
    # A callable hook is invoked with no arguments (not passed Model).
    opts = tuple(cast(Iterable[Any], opts_attr() if callable(opts_attr) else opts_attr))
    if opts:
        stmt = stmt.options(*opts)

    return list(session.execute(stmt).scalars().all())
|
|
|
|
def _resolve_column(Model, path: str):
    """Return ``(column, joins)`` for a ``'col'`` or ``'rel.col'`` path.

    ``column`` is the mapped class attribute to select. ``joins`` is a list of
    ``(parent_class, relationship_name)`` pairs the caller must join through:
    empty for a plain column, one entry for a ``'rel.col'`` path (only the
    first ``.`` splits; the remainder must be a column name on the related
    class).

    Raises:
        ValueError: the column is not mapped, or the part before the first
            ``.`` is not a relationship of *Model*.
    """
    if '.' not in path:
        col = _mapped_column(Model, path)
        if col is None:
            raise ValueError(f"Column '{path}' is not a mapped column on {Model.__name__}")
        return col, []

    rel_name, rel_field = path.split('.', 1)
    # Check the mapper's relationship collection. A plain column attribute also
    # exposes ``.property`` (a ColumnProperty), so testing for ``.property``
    # alone would accept ``'some_column.x'`` and then fail when it is treated
    # as a relationship.
    relationships = inspect(Model).relationships
    if rel_name not in relationships:
        raise ValueError(f"Column '{path}' is not a valid relationship on {Model.__name__}")

    Rel = relationships[rel_name].mapper.class_
    col = _mapped_column(Rel, rel_field)
    if col is None:
        raise ValueError(f"Column '{path}' is not a mapped column on {Rel.__name__}")
    return col, [(Model, rel_name)]
|
|
|
|
def default_values(session, Model, *, id_: int, fields: Iterable[str]) -> dict[str, Any]:
    """Fetch several values for one *Model* row in a single SELECT.

    Args:
        session: used to execute the query.
        Model: mapped class; its first primary-key column is compared to *id_*.
        id_: primary-key value of the row to read.
        fields: column names or ``'relationship.column'`` paths. Surrounding
            whitespace is stripped; blank entries and duplicates are ignored.

    Returns:
        ``{field: value}`` in request order, or ``{}`` when no row matches.

    Raises:
        ValueError: no usable fields; a field not in ``Model.ui_value_allow``
            (when that attribute is set and non-empty); or a field that
            ``_resolve_column`` cannot resolve.
    """
    fields = list(dict.fromkeys(f.strip() for f in fields if f.strip()))
    if not fields:
        raise ValueError("No fields provided for default_values")

    # Enforce the allow-list before querying, so disallowed fields are rejected
    # regardless of whether the row exists.
    allow = getattr(Model, "ui_value_allow", None)
    if allow:
        for f in fields:
            if f not in allow:
                raise ValueError(f"Field '{f}' not allowed")

    mapper = inspect(Model)
    pk = mapper.primary_key[0]

    selects = []
    joins = []
    for f in fields:
        col, j = _resolve_column(Model, f)
        # Labels keep the SQL readable; values are read back by position below.
        selects.append(col.label(f.replace('.', '_')))
        joins.extend(j)

    stmt = select(*selects).where(pk == id_)
    # dict.fromkeys dedupes the (parent, relationship) pairs while keeping order.
    for parent, attr_name in dict.fromkeys(joins):
        stmt = stmt.join(getattr(parent, attr_name))

    # NOTE(review): joining a to-many relationship can yield several rows, in
    # which case one_or_none() raises; default_value uses LIMIT 1 instead.
    row = session.execute(stmt).one_or_none()
    if row is None:
        return {}

    # Positional read: label names can collide (e.g. 'a.b' and a column named
    # 'a_b' both become label 'a_b'), so attribute access on the row could
    # return the wrong value.
    return dict(zip(fields, row))
|
|
|
|
def default_value(session, Model, *, id_: int, field: str) -> Any:
    """Fetch one value for the *Model* row whose first primary-key column equals *id_*.

    *field* is either a mapped column name, or a ``'relationship.column'`` path.
    For a path, the query joins through the relationship and returns only the
    first matching related value (``LIMIT 1``). ``session.scalar`` yields None
    when no row matches.

    Raises:
        ValueError: the column is not mapped, or the part before the first
            ``.`` is not a relationship of *Model*.
    """
    if '.' not in field:
        col = _mapped_column(Model, field)
        if col is None:
            raise ValueError(f"Field '{field}' is not a mapped column on {Model.__name__}")
        pk = inspect(Model).primary_key[0]
        return session.scalar(select(col).where(pk == id_))

    rel_name, rel_field = field.split('.', 1)
    # Validate against the mapper's relationships: a plain column attribute
    # also has ``.property`` (a ColumnProperty), so a ``.property`` check alone
    # lets ``'some_column.x'`` through to fail later.
    relationships = inspect(Model).relationships
    if rel_name not in relationships:
        raise ValueError(f"Field '{field}' is not a valid relationship on {Model.__name__}")

    Rel = relationships[rel_name].mapper.class_
    rel_col = _mapped_column(Rel, rel_field)
    if rel_col is None:
        raise ValueError(f"Field '{field}' is not a mapped column on {Rel.__name__}")

    pk = inspect(Model).primary_key[0]
    stmt = select(rel_col).join(getattr(Model, rel_name)).where(pk == id_).limit(1)
    return session.scalar(stmt)
|
|
|
|
def default_create(session, Model, payload):
    """Create a *Model* instance holding only its label column, commit, and return it.

    The label column comes from ``infer_label_attr``. Its value is
    ``payload[label]`` when that is truthy, otherwise ``payload.get("name")``.
    """
    label_attr = infer_label_attr(Model)
    label_value = payload.get(label_attr) or payload.get("name")
    instance = Model(**{label_attr: label_value})
    session.add(instance)
    session.commit()
    return instance
|
|
|
|
def default_update(session, Model, id_, payload):
    """Apply *payload* to the *Model* row with primary key *id_* and commit.

    Rules for each ``key: value`` in *payload*:
      - primary-key attributes (and the literal key ``'id'``) are never written;
      - keys that are not mapped columns are ignored;
      - if ``Model.ui_editable_cols`` is set and non-empty, keys outside it
        are ignored;
      - ``''`` and None are stored as None; for columns whose
        ``type.python_type`` is ``int`` the value is converted with ``int()``
        when possible, otherwise it is stored as given.

    The session is committed only if at least one attribute was set.

    Returns:
        The instance, or None if ``session.get`` finds no row for *id_*.
    """
    obj = session.get(Model, id_)
    if obj is None:
        return None

    mapper = inspect(Model)
    # Generalizes the former hard-coded 'id' skip to the model's actual
    # primary-key attribute names (e.g. a PK attribute called 'uuid').
    protected = {'id'} | {mapper.get_property_by_column(c).key for c in mapper.primary_key}
    editable = getattr(Model, 'ui_editable_cols', None)

    changed = False
    for k, v in payload.items():
        if k in protected:
            continue

        col = _mapped_column(Model, k)
        if col is None:
            continue

        if editable and k not in editable:
            continue

        if v == '' or v is None:
            nv = None
        else:
            # Best-effort coercion: python_type may be unimplemented for custom
            # types and int() may reject the value; fall back to the raw value.
            try:
                nv = int(v) if col.type.python_type is int else v
            except Exception:
                nv = v

        setattr(obj, k, nv)
        changed = True

    if changed:
        session.commit()
    return obj
|
|
|
|
def default_delete(session, Model, ids):
    """Delete the *Model* instances identified by *ids*, then commit once.

    Ids that ``session.get`` cannot find are skipped. An id repeated in *ids*
    (or several ids resolving to the same instance) is deleted and counted
    only once.

    Returns:
        The number of distinct instances passed to ``session.delete``.
    """
    # id(instance) -> instance; holding the instance keeps its id() stable.
    deleted = {}
    for ident in ids:
        obj = session.get(Model, ident)
        if obj is None or id(obj) in deleted:
            continue
        session.delete(obj)
        deleted[id(obj)] = obj
    session.commit()
    return len(deleted)
|
|
|
|
def default_serialize(Model, obj, *, view=None):
    """Return a small ``{'id', 'name', ...}`` dict describing *obj*.

    ``'name'`` is read from the attribute named by ``Model.ui_label_attr``.
    When that is unset, the first ``PREFERRED_LABELS`` entry that *obj* has is
    used (a plain attribute or a property both count). If no label attribute
    is found, or reading it raises, ``str(obj)`` is used instead. Attributes
    listed in ``Model.ui_extra_attrs`` that *obj* has are copied in after
    ``'id'`` and ``'name'``. *view* is accepted but currently unused.
    """
    label_attr = getattr(Model, 'ui_label_attr', None) or next(
        (candidate for candidate in PREFERRED_LABELS if hasattr(obj, candidate)),
        None,
    )

    if label_attr:
        try:
            name_val = getattr(obj, label_attr)
        except Exception:
            name_val = str(obj)
    else:
        name_val = str(obj)

    data = {'id': obj.id, 'name': name_val}
    extras = getattr(Model, 'ui_extra_attrs', ())
    data.update((attr, getattr(obj, attr)) for attr in extras if hasattr(obj, attr))
    return data
|