Fixing render_form's sins.

This commit is contained in:
Yaro Kasear 2025-09-16 15:59:52 -05:00
parent 27431a7150
commit 7f6cbf66fb
7 changed files with 211 additions and 11 deletions

View file

@ -1,9 +1,10 @@
import os
import re
from flask import current_app, url_for
from jinja2 import Environment, FileSystemLoader, ChoiceLoader
from sqlalchemy import inspect
from sqlalchemy.orm import class_mapper, RelationshipProperty
from sqlalchemy.orm import class_mapper, RelationshipProperty, load_only, selectinload
from sqlalchemy.orm.attributes import NO_VALUE
from typing import Any, Dict, List, Optional, Tuple
@ -16,6 +17,118 @@ def get_env():
loader=ChoiceLoader([app.jinja_loader, fallback_loader])
)
class _SafeObj:
"""Attribute access that returns '' for missing/None instead of exploding."""
__slots__ = ("_obj",)
def __init__(self, obj): self._obj = obj
def __str__(self): return "" if self._obj is None else str(self._obj)
def __getattr__(self, name):
if self._obj is None:
return ""
val = getattr(self._obj, name, None)
if val is None:
return ""
return _SafeObj(val)
def _extract_label_requirements(spec: Any) -> tuple[list[str], list[tuple[str, str]]]:
"""
From a label spec, return:
- simple_cols: ["name", "code"]
- rel_paths: [("room_function", "description"), ("owner", "last_name")]
"""
simple_cols: list[str] = []
rel_paths: list[tuple[str, str]] = []
def ingest(token: str) -> None:
token = str(token).strip()
if not token:
return
if "." in token:
rel, col = token.split(".", 1)
if rel and col:
rel_paths.append((rel, col))
else:
simple_cols.append(token)
if spec is None or callable(spec):
return simple_cols, rel_paths
if isinstance(spec, (list, tuple)):
for a in spec:
ingest(a)
return simple_cols, rel_paths
if isinstance(spec, str):
# format string like "{first} {last}" or "{room_function.description} · {name}"
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
for n in names:
ingest(n)
else:
ingest(spec)
return simple_cols, rel_paths
return simple_cols, rel_paths
def _attrs_from_label_spec(spec: Any) -> list[str]:
"""
Return a list of attribute names needed from the related model to compute the label.
Only simple attribute names are returned; dotted paths return just the first segment.
"""
if spec is None:
return []
if callable(spec):
return []
if isinstance(spec, (list, tuple)):
return [str(a).split(".", 1)[0] for a in spec]
if isinstance(spec, str):
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
return [n.split(".", 1)[0] for n in names]
return [spec.split(".", 1)[0]]
return []
def _label_from_obj(obj: Any, spec: Any) -> str:
if spec is None:
return str(obj)
if callable(spec):
try:
return str(spec(obj))
except Exception:
return str(obj)
if isinstance(spec, (list, tuple)):
parts = []
for a in spec:
cur = obj
for part in str(a).split("."):
cur = getattr(cur, part, None)
if cur is None:
break
parts.append("" if cur is None else str(cur))
return " ".join(p for p in parts if p)
if isinstance(spec, str) and "{" in spec and "}" in spec:
fields = re.findall(r"{\s*([^}:\s]+)", spec)
data: dict[str, Any] = {}
for f in fields:
root = f.split(".", 1)[0]
if root not in data:
val = getattr(obj, root, None)
data[root] = _SafeObj(val)
try:
return spec.format(**data)
except Exception:
return str(obj)
cur = obj
for part in str(spec).split("."):
cur = getattr(cur, part, None)
if cur is None:
return ""
return str(cur)
def _val_from_row_or_obj(row: Dict[str, Any], obj: Any, dotted: str) -> Any:
"""Best-effort deep get: try the projected row first, then the ORM object."""
val = _deep_get(row, dotted)
@ -252,33 +365,81 @@ def render_table(objects: List[Any], columns: Optional[List[Dict[str, Any]]] = N
return template.render(columns=cols, rows=disp_rows, kwargs=flat_opts)
def render_form(model_cls, values, session=None):
def render_form(model_cls, values, session=None, *, label_specs: Optional[Dict[str, Any]] = None):
env = get_env()
template = get_crudkit_template(env, 'form.html')
fields = []
fk_fields = set()
label_specs = label_specs or {}
mapper = class_mapper(model_cls)
for prop in mapper.iterate_properties:
# FK Relationship fields (many-to-one)
if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE':
if session is None:
continue
related_model = prop.mapper.class_
options = session.query(related_model).all()
rel_label_spec = (
label_specs.get(prop.key)
or getattr(related_model, "__crud_label__", None)
or None
)
# Figure out what we must load
simple_cols, rel_paths = _extract_label_requirements(rel_label_spec)
q = session.query(related_model)
# id is always needed
col_attrs = []
if hasattr(related_model, "id"):
col_attrs.append(getattr(related_model, "id"))
for name in simple_cols:
if hasattr(related_model, name):
col_attrs.append(getattr(related_model, name))
if col_attrs:
q = q.options(load_only(*col_attrs))
# Load related bits minimally
for rel_name, col_name in rel_paths:
rel_prop = getattr(related_model, rel_name, None)
if rel_prop is None:
continue
# grab target class to resolve column attr
try:
target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_
col_attr = getattr(target_cls, col_name, None)
if col_attr is None:
q = q.options(selectinload(rel_prop))
else:
q = q.options(selectinload(rel_prop).load_only(col_attr))
except Exception:
# fallback if mapper lookup is weird
q = q.options(selectinload(rel_prop))
# Gentle ordering: use first simple col if any, else skip
if simple_cols:
first = simple_cols[0]
if hasattr(related_model, first):
q = q.order_by(getattr(related_model, first))
options = q.all()
fields.append({
'name': f"{prop.key}_id",
'label': prop.key,
'type': 'select',
'options': [
{'value': getattr(obj, 'id'), 'label': str(obj)}
for obj in options
{
'value': getattr(opt, 'id'),
'label': _label_from_obj(opt, rel_label_spec),
}
for opt in options
]
})
fk_fields.add(f"{prop.key}_id")
# Now add basic columns — excluding FKs already covered
# Base columns
for col in model_cls.__table__.columns:
if col.name in fk_fields:
continue
@ -293,4 +454,3 @@ def render_form(model_cls, values, session=None):
})
return template.render(fields=fields, values=values, render_field=render_field)

View file

@ -1,6 +1,6 @@
<form method="POST">
{% for field in fields %}
{{ render_field(field, values.get(field.name, '')) }}
{{ render_field(field, values.get(field.name, '')) | safe }}
{% endfor %}
<button type="submit">Create</button>
</form>

View file

@ -9,7 +9,7 @@
<tbody>
{% if rows %}
{% for row in rows %}
<tr>
<tr class="{{ row.class or '' }}">
{% for cell in row.cells %}
{% if cell.href %}
<td class="{{ cell.class or '' }}"><a href="{{ cell.href }}">{{ cell.text if cell.text is not none else '-' }}</a></td>