# inventory/inventory/ui/defaults.py
#
# 326 lines
# 10 KiB
# Python

from sqlalchemy import select, asc as sa_asc, desc as sa_desc, or_, func
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import Select
from sqlalchemy.sql.sqltypes import String, Unicode, Text
from typing import Any, Optional, cast, Iterable
# Attribute names tried, in priority order, when a model does not set
# ``ui_label_attr`` explicitly.
PREFERRED_LABELS = (
    "identifier",
    "name",
    "first_name",
    "last_name",
    "description",
)
def _columns_for_text_search(Model):
    """Return the class-bound attributes of *Model*'s string-typed columns.

    These are the fallback targets for free-text search when the model has no
    ``ui_search`` hook. ``Unicode`` and ``Text`` are included in the type check
    (both are ``String`` subclasses).

    Attributes are looked up by the mapped *property* key, not the underlying
    Column's key: the two differ for e.g. ``name_ = mapped_column("name",
    String)``, where ``getattr(Model, "name")`` would raise AttributeError.
    """
    cols = []
    for prop in inspect(Model).column_attrs:
        column = prop.columns[0]
        if isinstance(column.type, (String, Unicode, Text)):
            cols.append(getattr(Model, prop.key))
    return cols
def _mapped_column(Model, attr):
    """Return ``Model.<attr>`` if *attr* names a mapped column attribute, else None.

    The returned object is the class-level instrumented attribute, usable in
    SELECT/WHERE/ORDER BY. A name is accepted if it is a key of the mapper's
    column collection or of one of its column properties; anything else
    (e.g. a relationship or a plain Python attribute) yields None.
    """
    mapper = inspect(Model)
    is_column = attr in mapper.columns.keys() or any(
        prop.key == attr for prop in mapper.column_attrs
    )
    return getattr(Model, attr) if is_column else None
def infer_label_attr(Model):
    """Pick the mapped column that serves as *Model*'s human-readable label.

    ``Model.ui_label_attr`` wins when set; otherwise the first name in
    ``PREFERRED_LABELS`` that is a mapped column is used.

    Raises:
        RuntimeError: if ``ui_label_attr`` is not a mapped column, or if no
            preferred label column exists.
    """
    explicit = getattr(Model, 'ui_label_attr', None)
    if explicit:
        if _mapped_column(Model, explicit) is None:
            raise RuntimeError(f"ui_label_attr '{explicit}' on {Model.__name__} is not a mapped column")
        return explicit

    found = next(
        (name for name in PREFERRED_LABELS if _mapped_column(Model, name) is not None),
        None,
    )
    if found is None:
        raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})")
    return found
def count_for(stmt: Select, session: Optional[Any] = None) -> int:
    """Return how many rows *stmt* would produce.

    The statement's ORDER BY is stripped (it cannot affect the count) and the
    result is wrapped as ``SELECT count(*) FROM (<stmt>)``. LIMIT/OFFSET, if
    any, are left in place.

    Args:
        stmt: the SELECT to count.
        session: Session or Connection used to run the count. Required on
            SQLAlchemy 2.x, where statements no longer carry a ``bind``.
            When omitted, the legacy ``stmt.bind`` is used.

    Raises:
        RuntimeError: if no *session* is given and *stmt* has no bind.
    """
    count_stmt = select(func.count()).select_from(stmt.order_by(None).subquery())
    if session is not None:
        return session.execute(count_stmt).scalar_one()
    # Legacy (pre-2.0) path: ``Executable.bind`` and ``Engine.execute`` were
    # removed in SQLAlchemy 2.0, so this only works on older versions.
    bind = getattr(stmt, "bind", None)
    if bind is None:
        raise RuntimeError(
            "count_for() needs a session on SQLAlchemy 2.x; "
            "call count_for(stmt, session=session)"
        )
    return bind.execute(count_stmt).scalar_one()
def ensure_order_by(stmt, Model, sort=None, direction="asc"):
    """Give *stmt* a deterministic ORDER BY unless it already has one.

    A statement that is already ordered is returned unchanged. Otherwise the
    first applicable source wins:

    1. *sort*, when it names an attribute of *Model* (descending only when
       *direction* is exactly ``"desc"``);
    2. the names in ``Model.ui_order_cols``, ascending, skipping names that
       are not attributes of *Model*;
    3. the mapper's primary-key columns, ascending.
    """
    # _order_by_clauses is a private Select attribute; the truth test is
    # wrapped defensively so any failure counts as "not ordered".
    try:
        already_ordered = bool(getattr(stmt, '_order_by_clauses', None))
    except Exception:
        already_ordered = False
    if already_ordered:
        return stmt

    if sort and hasattr(Model, sort):
        attr = getattr(Model, sort)
        return stmt.order_by(attr.desc() if direction == "desc" else attr.asc())

    configured = getattr(Model, 'ui_order_cols', ()) or ()
    candidates = (getattr(Model, name, None) for name in configured)
    ordering = [c.asc() for c in candidates if c is not None]
    if not ordering:
        ordering = [pk_col.asc() for pk_col in inspect(Model).primary_key]
    return stmt.order_by(*ordering)
def default_select(
    Model,
    *,
    text: Optional[str] = None,
    sort: Optional[str] = None,
    direction: str = "asc"
) -> Select[Any]:
    """Build the unpaginated ``select(Model)`` used by the UI list views.

    Applies text search and ordering with the same rules as ``default_query``
    but adds no LIMIT/OFFSET and no loader options, so the caller can count
    the statement (e.g. with ``count_for``) or paginate it.

    Hooks (all optional, looked up on *Model*):
    - ui_search(stmt, text) -> Select: custom text filter. Without it, a
      non-empty *text* is matched with ILIKE ``%text%`` against every
      string-typed mapped column.
    - ui_sort(stmt, sort, direction) -> Select: custom ordering for *sort*.
      Without it, *sort* is used as an attribute name (ignored if absent).
    - ui_order_cols: tuple[str, ...]: ascending default ordering, used only
      when *sort* is not given; unknown names are skipped.

    *direction* ``"desc"`` sorts descending; any other value sorts ascending.
    """
    stmt: Select[Any] = select(Model)
    ui_search = getattr(Model, "ui_search", None)
    if callable(ui_search) and text:
        stmt = cast(Select[Any], ui_search(stmt, text))
    elif text:
        # Same fallback as default_query, so a count of this statement matches
        # the rows default_query returns for the same search text.
        pattern = f"%{text}%"
        text_cols = _columns_for_text_search(Model)
        if text_cols:
            stmt = stmt.where(or_(*(col.ilike(pattern) for col in text_cols)))
    if sort:
        ui_sort = getattr(Model, "ui_sort", None)
        if callable(ui_sort):
            stmt = cast(Select[Any], ui_sort(stmt, sort, direction))
        else:
            col = getattr(Model, sort, None)
            if col is not None:
                stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col))
    else:
        order_cols = []
        for name in getattr(Model, "ui_order_cols", ()) or ():
            col = getattr(Model, name, None)
            if col is not None:
                order_cols.append(sa_asc(col))
        if order_cols:
            stmt = stmt.order_by(*order_cols)
    return stmt
def default_query(
    session,
    Model,
    *,
    text: Optional[str] = None,
    limit: int = 0,
    offset: int = 0,
    sort: Optional[str] = None,
    direction: str = "asc",
) -> list[Any]:
    """
    SA 2.x ONLY. Returns list[Model].

    Builds ``select(Model)``, applies text search, ordering, pagination and
    loader options, executes it on *session* and returns the distinct
    instances.

    Hooks (all optional, looked up on *Model*):
    - ui_search(stmt: Select, text: str) -> Select
        Custom text filter. Without it, a non-empty *text* is matched with
        ILIKE ``%text%`` against every string-typed mapped column.
    - ui_sort(stmt: Select, sort: str, direction: str) -> Select
        Custom ordering for *sort*. Without it, *sort* is used as an
        attribute name (ignored if *Model* has no such attribute).
    - ui_order_cols: tuple[str, ...]  # default ordering columns (ascending)
        Used only when *sort* is not given; unknown names are skipped.
    - ui_eagerload: iterable of loader options, or a zero-argument callable
        returning one; each option is applied with ``stmt.options``.

    Args:
        limit: maximum number of rows; 0 or negative means no limit.
        offset: number of rows to skip; 0 means none.
        direction: ``"desc"`` sorts descending; anything else ascending.
    """
    stmt: Select[Any] = select(Model)
    ui_search = getattr(Model, "ui_search", None)
    if callable(ui_search) and text:
        stmt = cast(Select[Any], ui_search(stmt, text))
    elif text:
        pattern = f"%{text}%"
        text_cols = _columns_for_text_search(Model)
        if text_cols:
            stmt = stmt.where(or_(*(col.ilike(pattern) for col in text_cols)))
    if sort:
        ui_sort = getattr(Model, "ui_sort", None)
        if callable(ui_sort):
            stmt = cast(Select[Any], ui_sort(stmt, sort, direction))
        else:
            col = getattr(Model, sort, None)
            if col is not None:
                stmt = stmt.order_by(sa_desc(col) if direction == "desc" else sa_asc(col))
    else:
        for colname in getattr(Model, "ui_order_cols", ()) or ():
            col = getattr(Model, colname, None)
            if col is not None:
                stmt = stmt.order_by(sa_asc(col))
    if offset:
        stmt = stmt.offset(offset)
    if limit > 0:
        stmt = stmt.limit(limit)
    opts_attr = getattr(Model, "ui_eagerload", ())
    opts: Iterable[Any]
    if callable(opts_attr):
        opts = cast(Iterable[Any], opts_attr())
    else:
        opts = cast(Iterable[Any], opts_attr)
    # Tolerate a hook (or attribute) that yields None: no loader options.
    for opt in opts or ():
        stmt = stmt.options(opt)
    # unique() is mandatory in SA 2.x when a loader option joined-eager-loads
    # a collection (e.g. joinedload(Model.items)); without it .all() raises
    # InvalidRequestError. It also collapses duplicate entities produced by
    # joins in ui_search, as legacy Query.all() did.
    return list(session.execute(stmt).scalars().unique().all())
def _resolve_column(Model, path: str):
    """Resolve a field path to a mapped column plus the joins it needs.

    *path* is either ``"col"`` (a mapped column on *Model*) or ``"rel.col"``
    (a mapped column on the class targeted by relationship ``rel``). Only one
    level of relationship traversal is supported.

    Returns:
        ``(column_attr, joins)``, where *joins* is a list of
        ``(parent_class, relationship_name)`` pairs (empty for a plain column).

    Raises:
        ValueError: if the column is not mapped, or ``rel`` is not a
            relationship of *Model*.
    """
    if '.' not in path:
        col = _mapped_column(Model, path)
        if col is None:
            raise ValueError(f"Column '{path}' is not a mapped column on {Model.__name__}")
        return col, []
    rel_name, rel_field = path.split('.', 1)
    # Consult the mapper's relationship collection instead of probing
    # ``.property``: column attributes expose ``.property`` too (a
    # ColumnProperty without ``.mapper``), which escaped as AttributeError.
    relationships = inspect(Model).relationships
    if rel_name not in relationships:
        raise ValueError(f"Column '{path}' is not a valid relationship on {Model.__name__}")
    Rel = relationships[rel_name].mapper.class_
    col = _mapped_column(Rel, rel_field)
    if col is None:
        raise ValueError(f"Column '{path}' is not a mapped column on {Rel.__name__}")
    return col, [(Model, rel_name)]
def default_values(session, Model, *, id_: int, fields: Iterable[str]) -> dict[str, Any]:
    """Fetch several field values of the *Model* row whose primary key is *id_*.

    Each field is ``"col"`` or ``"rel.col"`` (resolved by ``_resolve_column``);
    blank entries are dropped and surrounding whitespace is stripped. All
    values are read with a single SELECT.

    Returns:
        ``{field: value}`` in the order given, or ``{}`` when no row matches.
        Fields reached through a relationship that is unset come back as None.

    Raises:
        ValueError: if no fields remain, a field is not in
            ``Model.ui_value_allow`` (when that is set), or a field does not
            resolve to a mapped column/relationship.
    """
    fields = [f.strip() for f in fields if f.strip()]
    if not fields:
        raise ValueError("No fields provided for default_values")
    # Enforce the allow-list before building or running any query, so a
    # disallowed field is rejected whether or not the row exists.
    allow = getattr(Model, "ui_value_allow", None)
    if allow:
        for f in fields:
            if f not in allow:
                raise ValueError(f"Field '{f}' not allowed")
    pk = inspect(Model).primary_key[0]
    selects = []
    joins = []
    for f in fields:
        col, j = _resolve_column(Model, f)
        selects.append(col.label(f.replace('.', '_')))
        joins.extend(j)
    stmt = select(*selects).where(pk == id_)
    seen = set()
    for parent, attr_name in joins:
        key = (parent, attr_name)
        if key in seen:
            continue
        seen.add(key)
        # Outer join: a NULL foreign key must not hide the base row's values.
        stmt = stmt.outerjoin(getattr(parent, attr_name))
    # A to-many relationship can yield several rows; take the first one
    # (as default_value does) instead of raising MultipleResultsFound.
    row = session.execute(stmt.limit(1)).one_or_none()
    if row is None:
        return {}
    # Read by position: labels such as "count"/"index" collide with Row's
    # tuple methods, and "a.b" and "a_b" produce the same label.
    return {f: row[i] for i, f in enumerate(fields)}
def default_value(session, Model, *, id_: int, field: str) -> Any:
    """Fetch one field value of the *Model* row whose primary key is *id_*.

    *field* is ``"col"`` (a mapped column on *Model*) or ``"rel.col"`` (a
    mapped column on the class targeted by relationship ``rel``; the first
    matching related row is used).

    Returns:
        The value, or None when no row matches.

    Raises:
        ValueError: if *field* is not in ``Model.ui_value_allow`` (when that
            is set), or does not resolve to a mapped column/relationship.
    """
    # Same allow-list as default_values, so it cannot be bypassed by asking
    # for one field at a time.
    allow = getattr(Model, "ui_value_allow", None)
    if allow and field not in allow:
        raise ValueError(f"Field '{field}' not allowed")
    if '.' not in field:
        col = _mapped_column(Model, field)
        if col is None:
            raise ValueError(f"Field '{field}' is not a mapped column on {Model.__name__}")
        pk = inspect(Model).primary_key[0]
        return session.scalar(select(col).where(pk == id_))
    rel_name, rel_field = field.split('.', 1)
    # Column attributes also have ``.property`` (without ``.mapper``), so check
    # the mapper's relationships explicitly to get a ValueError, not AttributeError.
    relationships = inspect(Model).relationships
    if rel_name not in relationships:
        raise ValueError(f"Field '{field}' is not a valid relationship on {Model.__name__}")
    Rel = relationships[rel_name].mapper.class_
    rel_col = _mapped_column(Rel, rel_field)
    if rel_col is None:
        raise ValueError(f"Field '{field}' is not a mapped column on {Rel.__name__}")
    pk = inspect(Model).primary_key[0]
    stmt = select(rel_col).join(getattr(Model, rel_name)).where(pk == id_).limit(1)
    return session.scalar(stmt)
def default_create(session, Model, payload):
    """Create a *Model* row whose only initial value is its label column.

    The label column is chosen by ``infer_label_attr``. Its value is
    ``payload[label]``, or ``payload["name"]`` when the former is missing or
    falsy. The new object is added and the session committed.

    Returns:
        The newly created instance.
    """
    label_key = infer_label_attr(Model)
    label_value = payload.get(label_key) or payload.get("name")
    record = Model(**{label_key: label_value})
    session.add(record)
    session.commit()
    return record
def default_update(session, Model, id_, payload):
    """Apply *payload* to the *Model* row with primary key *id_* and commit.

    Keys are skipped when they are ``'id'``, are not mapped columns, or (if
    ``Model.ui_editable_cols`` is set and non-empty) are not listed there.
    ``''`` and None become None. For columns whose Python type is ``int`` the
    value is converted with ``int()``. If that conversion or the type lookup
    fails, the raw value is kept. The session is committed only if at least
    one attribute was set.

    Returns:
        The (possibly updated) instance, or None when no row was found.
    """
    obj = session.get(Model, id_)
    if not obj:
        return None
    editable = getattr(Model, 'ui_editable_cols', None)
    dirty = False
    for key, raw in payload.items():
        if key == 'id':
            continue
        column = _mapped_column(Model, key)
        if column is None or (editable and key not in editable):
            continue
        if raw is None or raw == '':
            value = None
        else:
            try:
                value = int(raw) if column.type.python_type is int else raw
            except Exception:
                # Best-effort coercion: fall back to the value as given.
                value = raw
        setattr(obj, key, value)
        dirty = True
    if dirty:
        session.commit()
    return obj
def default_delete(session, Model, ids):
    """Delete every *Model* row whose primary key is in *ids*, then commit once.

    Each row is loaded with ``session.get``. Keys with no matching row are
    skipped silently. The commit happens even if nothing was deleted.

    Returns:
        The number of rows passed to ``session.delete``.
    """
    removed = 0
    for pk_value in ids:
        instance = session.get(Model, pk_value)
        if not instance:
            continue
        session.delete(instance)
        removed += 1
    session.commit()
    return removed
def default_serialize(Model, obj, *, view=None):
    """Serialize *obj* into a small ``{'id': ..., 'name': ..., ...}`` dict.

    ``'name'`` comes from, in order of preference:
    ``Model.ui_label_attr``, the first name in ``PREFERRED_LABELS`` that
    *obj* has (plain attributes and ``@property`` both count), or
    ``str(obj)``. If reading the chosen attribute raises, ``str(obj)`` is used
    instead. Each name in ``Model.ui_extra_attrs`` that *obj* has is copied
    through under the same key. *view* is accepted but not used.
    """
    label_attr = getattr(Model, 'ui_label_attr', None)
    if not label_attr:
        label_attr = next(
            (candidate for candidate in PREFERRED_LABELS if hasattr(obj, candidate)),
            label_attr,
        )

    if label_attr:
        try:
            display_name = getattr(obj, label_attr)
        except Exception:
            display_name = str(obj)
    else:
        display_name = str(obj)

    data = {'id': obj.id, 'name': display_name}
    data.update(
        (attr, getattr(obj, attr))
        for attr in getattr(Model, 'ui_extra_attrs', ())
        if hasattr(obj, attr)
    )
    return data