"""Generic CRUD/listing helpers for SQLAlchemy 2.x ORM models.

Behaviour can be customised per model through optional class attributes:
``ui_label_attr``, ``ui_search``, ``ui_sort``, ``ui_order_cols``,
``ui_eagerload``, ``ui_value_allow``, ``ui_editable_cols`` and
``ui_extra_attrs`` (see the individual helpers).
"""
from typing import Any, Iterable, Optional, cast

from sqlalchemy import asc as sa_asc, desc as sa_desc, func, or_, select
from sqlalchemy.inspection import inspect
from sqlalchemy.sql import Select
from sqlalchemy.sql.sqltypes import String, Text, Unicode

# Attribute names tried, in order, when a model does not set ``ui_label_attr``.
PREFERRED_LABELS = ("identifier", "name", "first_name", "last_name", "description")

# Escape character for LIKE patterns built from user search text
# (the same character SQLAlchemy's own ``autoescape`` option uses).
_LIKE_ESCAPE = "/"


def _columns_for_text_search(Model):
    """Return the class attributes of every string-typed mapped column on *Model*.

    Iterates ``column_attrs`` so the *attribute* key is used: a ``Column.key``
    is the table-level name and can differ from the Python attribute name
    (e.g. ``_name = mapped_column("name", String)``).
    """
    cols = []
    for prop in inspect(Model).column_attrs:
        if isinstance(prop.columns[0].type, (String, Unicode, Text)):
            cols.append(getattr(Model, prop.key))
    return cols


def _mapped_column(Model, attr):
    """Return the mapped column attr on the class (InstrumentedAttribute) or None."""
    mapper = inspect(Model)
    if attr in mapper.columns.keys() or attr in mapper.column_attrs.keys():
        return getattr(Model, attr)
    return None


def _relationship_target(Model, rel_name):
    """Return the class on the far side of relationship *rel_name*, or None.

    Looked up through ``Mapper.relationships``: a plain column attribute also
    exposes ``.property``, so probing for ``.property`` alone is not enough.
    """
    rels = inspect(Model).relationships
    if rel_name not in rels.keys():
        return None
    return rels[rel_name].mapper.class_


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards (and the escape char itself) so *value* matches literally."""
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _order_clause(col, direction):
    """Return ``col DESC`` when *direction* is "desc" (any case), else ``col ASC``."""
    return sa_desc(col) if str(direction).lower() == "desc" else sa_asc(col)


def _sortable_attr(Model, name):
    """Return ``getattr(Model, name)`` if it is an orderable SQL expression, else None.

    Sort keys typically come from request parameters; attributes that are not
    SQL expressions (methods, plain properties, ``metadata``...) are treated
    like unknown names instead of failing inside ``asc()``/``desc()``.
    """
    if not name:
        return None
    attr = getattr(Model, name, None)
    if attr is None or not callable(getattr(attr, "asc", None)):
        return None
    return attr


def infer_label_attr(Model):
    """Return the name of the mapped column used as *Model*'s display label.

    ``Model.ui_label_attr`` wins when set; otherwise the first of
    PREFERRED_LABELS that is a mapped column.

    Raises:
        RuntimeError: if ``ui_label_attr`` is not a mapped column, or no
            preferred label column exists.
    """
    explicit = getattr(Model, "ui_label_attr", None)
    if explicit:
        if _mapped_column(Model, explicit) is not None:
            return explicit
        raise RuntimeError(f"ui_label_attr '{explicit}' on {Model.__name__} is not a mapped column")
    for a in PREFERRED_LABELS:
        if _mapped_column(Model, a) is not None:
            return a
    raise RuntimeError(f"No label-like mapped column on {Model.__name__} (tried {PREFERRED_LABELS})")


def count_for(stmt: Select, session: Any = None) -> int:
    """Return the number of rows *stmt* would produce.

    Wraps *stmt* (with its ORDER BY removed) in a subquery and runs
    ``SELECT count(*)`` over it.

    SQLAlchemy 2.x removed ``Select.bind``, so a statement cannot execute
    itself; pass the *session* (or connection) to run on. Without one, a
    legacy ``stmt.bind`` is used when present.

    Raises:
        TypeError: if no *session* is given and *stmt* has no bind.
    """
    count_stmt = select(func.count()).select_from(stmt.order_by(None).subquery())
    executor = session if session is not None else getattr(stmt, "bind", None)
    if executor is None:
        raise TypeError("count_for() needs a session: Select.bind does not exist in SQLAlchemy 2.x")
    return executor.execute(count_stmt).scalar_one()


def ensure_order_by(stmt, Model, sort=None, direction="asc"):
    """Return *stmt* unchanged if it already has an ORDER BY, else add one.

    Fallback preference:
      1. the *sort* attribute (honouring *direction*), if it is orderable;
      2. ``Model.ui_order_cols`` (ascending);
      3. the primary-key column(s) (ascending).

    Reads SQLAlchemy's private ``Select._order_by_clauses``; there is no
    public accessor for "does this statement already have an ORDER BY".
    """
    if getattr(stmt, "_order_by_clauses", None):
        return stmt

    cols = []
    sort_col = _sortable_attr(Model, sort)
    if sort_col is not None:
        cols.append(_order_clause(sort_col, direction))

    if not cols:
        for name in getattr(Model, "ui_order_cols", ()) or ():
            c = getattr(Model, name, None)
            if c is not None:
                cols.append(c.asc())

    if not cols:
        cols = [pk_col.asc() for pk_col in inspect(Model).primary_key]

    return stmt.order_by(*cols)


def default_select(
    Model,
    *,
    text: Optional[str] = None,
    sort: Optional[str] = None,
    direction: str = "asc",
) -> Select[Any]:
    """Build (without executing) the default filtered and ordered SELECT for *Model*.

    Filtering, only when *text* is non-empty:
      - ``Model.ui_search(stmt, text)`` if the model defines it;
      - otherwise a case-insensitive substring match (ILIKE) over every
        string-typed mapped column, OR-ed together. LIKE wildcards in
        *text* are escaped so ``%`` and ``_`` match literally.

    Ordering:
      - with *sort*: ``Model.ui_sort(stmt, sort, direction)`` if defined,
        otherwise the attribute named *sort* when it is an orderable SQL
        expression (unknown names are ignored);
      - without *sort*: ascending by ``Model.ui_order_cols`` if defined.
    """
    stmt: Select[Any] = select(Model)

    if text:
        ui_search = getattr(Model, "ui_search", None)
        if callable(ui_search):
            stmt = cast(Select[Any], ui_search(stmt, text))
        else:
            text_cols = _columns_for_text_search(Model)
            if text_cols:
                pattern = f"%{_escape_like(text)}%"
                stmt = stmt.where(
                    or_(*(col.ilike(pattern, escape=_LIKE_ESCAPE) for col in text_cols))
                )

    if sort:
        ui_sort = getattr(Model, "ui_sort", None)
        if callable(ui_sort):
            stmt = cast(Select[Any], ui_sort(stmt, sort, direction))
        else:
            col = _sortable_attr(Model, sort)
            if col is not None:
                stmt = stmt.order_by(_order_clause(col, direction))
    else:
        order_cols = []
        for name in getattr(Model, "ui_order_cols", ()) or ():
            col = getattr(Model, name, None)
            if col is not None:
                order_cols.append(sa_asc(col))
        if order_cols:
            stmt = stmt.order_by(*order_cols)

    return stmt


def default_query(
    session,
    Model,
    *,
    text: Optional[str] = None,
    limit: int = 0,
    offset: int = 0,
    sort: Optional[str] = None,
    direction: str = "asc",
) -> list[Any]:
    """
    SA 2.x ONLY. Returns list[Model].

    Filtering/ordering are built by ``default_select``; then OFFSET/LIMIT
    (``limit <= 0`` means no limit) and eager-load options are applied.

    Hooks:
      - ui_search(stmt: Select, text: str) -> Select
      - ui_sort(stmt: Select, sort: str, direction: str) -> Select
      - ui_order_cols: tuple[str, ...]   # default ordering columns
      - ui_eagerload: iterable of loader options, or a zero-argument callable
        returning one (e.g. a staticmethod/classmethod)
    """
    stmt = default_select(Model, text=text, sort=sort, direction=direction)

    if offset or limit > 0:
        # LIMIT/OFFSET over an unordered result gives unstable pages: fall back
        # to ui_order_cols / primary key when nothing above set an ORDER BY.
        stmt = ensure_order_by(stmt, Model)
        if offset:
            stmt = stmt.offset(offset)
        if limit > 0:
            stmt = stmt.limit(limit)

    opts_attr = getattr(Model, "ui_eagerload", ())
    opts: list[Any] = list((opts_attr() if callable(opts_attr) else opts_attr) or ())
    if opts:
        stmt = stmt.options(*opts)

    scalars = session.execute(stmt).scalars()
    if opts:
        # Joined eager loads of collections return one row per child; SA 2.x
        # refuses to hand such results back unless they are uniqued.
        scalars = scalars.unique()
    return list(scalars.all())


def _resolve_column(Model, path: str):
    """Resolve ``'col'`` or ``'rel.col'`` to ``(column_attribute, joins)``.

    *joins* is ``[]`` for a plain column and ``[(Model, rel_name)]`` for a
    one-hop relationship path; deeper paths fail the column lookup.

    Raises:
        ValueError: if the column or relationship is not mapped.
    """
    if "." not in path:
        col = _mapped_column(Model, path)
        if col is None:
            raise ValueError(f"Column '{path}' is not a mapped column on {Model.__name__}")
        return col, []

    rel_name, rel_field = path.split(".", 1)
    Rel = _relationship_target(Model, rel_name)
    if Rel is None:
        raise ValueError(f"Column '{path}' is not a valid relationship on {Model.__name__}")
    col = _mapped_column(Rel, rel_field)
    if col is None:
        raise ValueError(f"Column '{path}' is not a mapped column on {Rel.__name__}")
    return col, [(Model, rel_name)]


def default_values(session, Model, *, id_: int, fields: Iterable[str]) -> dict[str, Any]:
    """Fetch several column values of one *Model* row in a single query.

    *fields* entries are ``'col'`` or ``'rel.col'``; blanks are ignored and
    whitespace is stripped. Returns ``{field: value}``, or ``{}`` when no
    row has primary key *id_*. Related rows are LEFT OUTER JOINed, so a
    missing related object yields ``None`` for its fields. For to-many
    relationships only the first joined row is used.

    Raises:
        ValueError: if *fields* is empty, a field is not in
            ``Model.ui_value_allow`` (when that is set), or a field cannot
            be resolved to a mapped column.
    """
    fields = [f.strip() for f in fields if f.strip()]
    if not fields:
        raise ValueError("No fields provided for default_values")

    # Enforce the allow-list before touching the database, whether or not
    # the row exists.
    allow = getattr(Model, "ui_value_allow", None)
    if allow:
        for f in fields:
            if f not in allow:
                raise ValueError(f"Field '{f}' not allowed")

    pk = inspect(Model).primary_key[0]
    selects = []
    joins: list[tuple[Any, str]] = []
    for i, f in enumerate(fields):
        col, rel_joins = _resolve_column(Model, f)
        # Positional labels cannot collide (unlike 'a.b' -> 'a_b' vs 'a_b').
        selects.append(col.label(f"v{i}"))
        for j in rel_joins:
            if j not in joins:
                joins.append(j)

    stmt = select(*selects).select_from(Model)
    for parent, attr_name in joins:
        stmt = stmt.outerjoin(getattr(parent, attr_name))
    stmt = stmt.where(pk == id_).limit(1)

    row = session.execute(stmt).first()
    if row is None:
        return {}
    return {f: row[i] for i, f in enumerate(fields)}


def default_value(session, Model, *, id_: int, field: str) -> Any:
    """Return one column value (``'col'`` or ``'rel.col'``) of the row with pk *id_*.

    Returns ``None`` when the row (or, for ``'rel.col'``, the related row)
    does not exist.

    Raises:
        ValueError: if *field* is not in ``Model.ui_value_allow`` (when that
            is set), or is not a mapped column / valid relationship path.
    """
    allow = getattr(Model, "ui_value_allow", None)
    if allow and field not in allow:
        raise ValueError(f"Field '{field}' not allowed")

    pk = inspect(Model).primary_key[0]

    if "." not in field:
        col = _mapped_column(Model, field)
        if col is None:
            raise ValueError(f"Field '{field}' is not a mapped column on {Model.__name__}")
        return session.scalar(select(col).where(pk == id_))

    rel_name, rel_field = field.split(".", 1)
    Rel = _relationship_target(Model, rel_name)
    if Rel is None:
        raise ValueError(f"Field '{field}' is not a valid relationship on {Model.__name__}")
    rel_col = _mapped_column(Rel, rel_field)
    if rel_col is None:
        raise ValueError(f"Field '{field}' is not a mapped column on {Rel.__name__}")

    stmt = (
        select(rel_col)
        .select_from(Model)
        .join(getattr(Model, rel_name))
        .where(pk == id_)
        .limit(1)
    )
    return session.scalar(stmt)


def default_create(session, Model, payload):
    """Create, add and commit a *Model* with only its label column set.

    The label column comes from ``infer_label_attr``; its value is
    ``payload[label]``, falling back to ``payload["name"]`` when missing or
    falsy. Returns the new object.
    """
    label = infer_label_attr(Model)
    obj = Model(**{label: payload.get(label) or payload.get("name")})
    session.add(obj)
    session.commit()
    return obj


def default_update(session, Model, id_, payload):
    """Apply *payload* to the *Model* row with primary key *id_* and commit.

    Keys that are not mapped columns, primary-key attributes (and ``'id'``),
    or not in ``Model.ui_editable_cols`` (when that is non-empty) are
    skipped. ``''`` and ``None`` become ``None``; values for int-typed
    columns are converted with ``int()`` when possible. Returns the object,
    or ``None`` if no row has that id. Commits only if something was set.
    """
    obj = session.get(Model, id_)
    if obj is None:
        return None

    mapper = inspect(Model)
    # Never let a generic update rewrite the identity of the row.
    pk_keys = {"id"} | {mapper.get_property_by_column(c).key for c in mapper.primary_key}
    editable = getattr(Model, "ui_editable_cols", None)

    changed = False
    for k, v in payload.items():
        if k in pk_keys:
            continue
        col = _mapped_column(Model, k)
        if col is None:
            continue
        if editable and k not in editable:
            continue
        if v == "" or v is None:
            nv = None
        else:
            try:
                nv = int(v) if col.type.python_type is int else v
            except (NotImplementedError, TypeError, ValueError):
                # python_type unsupported for this type, or v not int-like:
                # store the raw value and let the DB layer decide.
                nv = v
        setattr(obj, k, nv)
        changed = True

    if changed:
        session.commit()
    return obj


def default_delete(session, Model, ids):
    """Delete each *Model* whose primary key is in *ids*, then commit once.

    Ids with no matching row are skipped. Returns the number of objects deleted.
    """
    count = 0
    for i in ids:
        obj = session.get(Model, i)
        if obj is not None:
            session.delete(obj)
            count += 1
    session.commit()
    return count


def default_serialize(Model, obj, *, view=None):
    """Return ``{'id': obj.id, 'name': <label>, ...extras}`` for *obj*.

    The label attribute is ``Model.ui_label_attr`` when set, otherwise the
    first of PREFERRED_LABELS that *obj* has (a ``@property`` counts here,
    unlike ``infer_label_attr``). If none exists or reading it raises,
    ``str(obj)`` is used. Attributes named in ``Model.ui_extra_attrs`` are
    added when *obj* has them. *view* is accepted but currently unused.
    """
    # 1. Explicit config wins
    label_attr = getattr(Model, "ui_label_attr", None)

    # 2. Otherwise, pick the first PREFERRED_LABELS that exists (can be @property or real column)
    if not label_attr:
        for candidate in PREFERRED_LABELS:
            if hasattr(obj, candidate):
                label_attr = candidate
                break

    # 3. Fallback to str(obj) if nothing was found or the attribute can't be read
    if not label_attr:
        name_val = str(obj)
    else:
        try:
            name_val = getattr(obj, label_attr)
        except Exception:
            name_val = str(obj)

    data = {"id": obj.id, "name": name_val}

    # Include extra attrs if defined
    for attr in getattr(Model, "ui_extra_attrs", ()):
        if hasattr(obj, attr):
            data[attr] = getattr(obj, attr)
    return data