diff --git a/crudkit/ui/fragments.py b/crudkit/ui/fragments.py index 5bfcb3b..4621df2 100644 --- a/crudkit/ui/fragments.py +++ b/crudkit/ui/fragments.py @@ -5,9 +5,17 @@ from flask import current_app, url_for from jinja2 import Environment, FileSystemLoader, ChoiceLoader from sqlalchemy import inspect from sqlalchemy.orm import class_mapper, RelationshipProperty, load_only, selectinload -from sqlalchemy.orm.attributes import NO_VALUE +from sqlalchemy.orm.base import NO_VALUE from typing import Any, Dict, List, Optional, Tuple +_ALLOWED_ATTRS = { + "class", "placeholder", "autocomplete", "inputmode", "pattern", + "min", "max", "step", "maxlength", "minlength", + "required", "readonly", "disabled", + "multiple", "size", + "id", "name", "value", +} + def get_env(): app = current_app default_path = os.path.join(os.path.dirname(__file__), 'templates') @@ -17,6 +25,29 @@ def get_env(): loader=ChoiceLoader([app.jinja_loader, fallback_loader]) ) +def _sanitize_attrs(attrs: Any) -> dict[str, Any]: + """ + Whitelist attributes; allow data-* and aria-*; render True as boolean attr. + Drop False/None and anything not whitelisted. + """ + if not isinstance(attrs, dict): + return {} + out: dict[str, Any] = {} + for k, v in attrs.items(): + if not isinstance(k, str): + continue + elif isinstance(v, str): + if len(v) > 512: + v = v[:512] + if k.startswith("data-") or k.startswith("aria-") or k in _ALLOWED_ATTRS: + if isinstance(v, bool): + if v: + out[k] = True + elif v is not None: + out[k] = str(v) + + return out + class _SafeObj: """Attribute access that returns '' for missing/None instead of exploding.""" __slots__ = ("_obj",) @@ -30,6 +61,153 @@ class _SafeObj: return "" return _SafeObj(val) +def _coerce_fk_value(values: dict | None, instance: Any, base: str): + """ + Resolve the current selection for relationship 'base': + 1) values['_id'] + 2) values['']['id'] or values[''] if scalar + 3) instance. (relationship) if it's already loaded -> use its .id + 4) instance._id if it's already loaded (column) and instance is bound + Never trigger a lazy load. Never touch the DB. + """ + # 1) explicit *_id from values + if isinstance(values, dict): + key = f"{base}_id" + if key in values: + return values.get(key) + rel = values.get(base) + if isinstance(rel, dict): + return rel.get("id") or rel.get(key) + if isinstance(rel, (int, str)): + return rel + + # 3) use loaded relationship object (safe for detached instances) + if instance is not None: + try: + state = inspect(instance) + # relationship attr present? + rel_attr = state.attrs.get(base) + if rel_attr is not None and rel_attr.loaded_value is not NO_VALUE: + rel_obj = rel_attr.loaded_value + if rel_obj is not None: + rid = getattr(rel_obj, "id", None) + if rid is not None: + return rid + # 4) use loaded fk column if the value is present and NOT expired + id_attr = state.attrs.get(f"{base}_id") + if id_attr is not None and id_attr.loaded_value is not NO_VALUE: + return id_attr.loaded_value + except Exception: + pass + + return None + +def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]: + try: + prop = mapper.relationships[name] + except Exception: + return None + if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE': + return prop + return None + +def _rel_for_id_name(mapper, name: str) -> tuple[Optional[str], Optional[RelationshipProperty]]: + if name.endswith("_id"): + base = name[":-3"] + prop = _is_many_to_one(mapper, base) + return (base, prop) if prop else (None, None) + else: + prop = _is_many_to_one(mapper, name) + return (name, prop) if prop else (None, None) + +def _fk_options(session, related_model, label_spec): + simple_cols, rel_paths = _extract_label_requirements(label_spec) + q = session.query(related_model) + + col_attrs = [] + if hasattr(related_model, "id"): + col_attrs.append(getattr(related_model, "id")) + for name in simple_cols: + if hasattr(related_model, name): + col_attrs.append(getattr(related_model, name)) + if col_attrs: + q = q.options(load_only(*col_attrs)) + + for rel_name, col_name in rel_paths: + rel_prop = getattr(related_model, rel_name, None) + if rel_prop is None: + continue + try: + target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_ + col_attr = getattr(target_cls, col_name, None) + if col_attr is None: + q = q.options(selectinload(rel_prop)) + else: + q = q.options(selectinload(rel_prop).load_only(col_attr)) + except Exception: + q = q.options(selectinload(rel_prop)) + + if simple_cols: + first = simple_cols[0] + if hasattr(related_model, first): + q = q.order_by(getattr(related_model, first)) + + rows = q.all() + return [ + { + 'value': getattr(opt, 'id'), + 'label': _label_from_obj(opt, label_spec), + } + for opt in rows + ] + +def _normalize_field_spec(spec, mapper, session, label_specs_model_default): + """ + Turn a user field spec into a concrete field dict the template understands. + """ + name = spec['name'] + base_rel_name, rel_prop = _rel_for_id_name(mapper, name) + + field = { + "name": name if not base_rel_name else f"{base_rel_name}_id", + "label": spec.get("label", name), + "type": spec.get("type"), + "options": spec.get("options"), + "attrs": spec.get("attrs"), + "help": spec.get("help"), + } + + if rel_prop: + if field["type"] is None: + field["type"] = "select" + if field["type"] == "select" and field.get("options") is None and session is not None: + related_model = rel_prop.mapper.class_ + label_spec = ( + spec.get("label_spec") + or label_specs_model_default.get(base_rel_name) + or getattr(related_model, "__crud_label__", None) + or "id" + ) + field["options"] = _fk_options(session, related_model, label_spec) + return field + + col = mapper.columns.get(name) + if field["type"] is None: + if col is not None and hasattr(col.type, "python_type"): + py = None + try: + py = col.type.python_type + except Exception: + pass + if py is bool: + field["type"] = "checkbox" + else: + field["type"] = "text" + else: + field["type"] = "text" + + return field + def _extract_label_requirements(spec: Any) -> tuple[list[str], list[tuple[str, str]]]: """ From a label spec, return: @@ -90,12 +268,14 @@ def _attrs_from_label_spec(spec: Any) -> list[str]: def _label_from_obj(obj: Any, spec: Any) -> str: if spec is None: - return str(obj) - if callable(spec): - try: - return str(spec(obj)) - except Exception: - return str(obj) + for attr in ("label", "name", "title", "description"): + if hasattr(obj, attr): + val = getattr(obj, attr) + if not callable(val) and val is not None: + return str(val) + if hasattr(obj, "id"): + return str(getattr(obj, "id")) + return object.__repr__(obj) if isinstance(spec, (list, tuple)): parts = [] @@ -329,7 +509,9 @@ def render_field(field, value): field_label=field.get('label', field['name']), value=value, field_type=field.get('type', 'text'), - options=field.get('options', None) + options=field.get('options', None), + attrs=_sanitize_attrs(field.get('attrs') or {}), + help=field.get('help') ) def render_table(objects: List[Any], columns: Optional[List[Dict[str, Any]]] = None, **opts): @@ -365,92 +547,108 @@ def render_table(objects: List[Any], columns: Optional[List[Dict[str, Any]]] = N return template.render(columns=cols, rows=disp_rows, kwargs=flat_opts) -def render_form(model_cls, values, session=None, *, label_specs: Optional[Dict[str, Any]] = None): +def render_form( + model_cls, + values, + session=None, + *, + fields_spec: Optional[list[dict]] = None, + label_specs: Optional[Dict[str, Any]] = None, + exclude: Optional[set[str]] = None, + overrides: Optional[Dict[str, Dict[str, Any]]] = None, + instance: Any = None, # NEW: pass the ORM object so we can read *_id +): + """ + fields_spec: list of dicts describing fields in order. Each dict supports: + - name: "first_name" | "location" | "location_id" (required) + - label: override_label + - type: "text" | "textarea" | "checkbox" | "select" | "hidden" | ... + - label_spec: for relationship selects, e.g. "{name} - {room_function.description}" + - options: prebuilt list of {"value","label"}; skips querying if provided + - attrs: dict of arbitrary HTML attributes, e.g. {"required": True, "placeholder": "Jane"} + - help: small help text under the field + label_specs: legacy per-relationship label spec fallback ({"location": "..."}). + exclude: set of field names to hide. + overrides: legacy quick overrides keyed by field name (label/type/etc.) + instance: the ORM object backing the form; used to populate *_id values + """ env = get_env() - template = get_crudkit_template(env, 'form.html') - fields = [] - fk_fields = set() + template = get_crudkit_template(env, "form.html") + exclude = exclude or set() + overrides = overrides or {} label_specs = label_specs or {} mapper = class_mapper(model_cls) - for prop in mapper.iterate_properties: - if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE': - if session is None: + fields: list[dict] = [] + values_map = dict(values or {}) # we'll augment this with *_id selections + + if fields_spec: + # Spec-driven path + for spec in fields_spec: + if spec["name"] in exclude: continue - - related_model = prop.mapper.class_ - rel_label_spec = ( - label_specs.get(prop.key) - or getattr(related_model, "__crud_label__", None) - or None + field = _normalize_field_spec( + {**spec, **overrides.get(spec["name"], {})}, + mapper, session, label_specs ) + fields.append(field) - # Figure out what we must load - simple_cols, rel_paths = _extract_label_requirements(rel_label_spec) + # After building fields, inject current values for any M2O selects + for f in fields: + name = f.get("name") + if isinstance(name, str) and name.endswith("_id"): + base = name[:-3] + rel_prop = mapper.relationships.get(base) + if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE": + values_map[name] = _coerce_fk_value(values, instance, base) - q = session.query(related_model) + else: + # Auto-generate path (your original behavior) + fk_fields = set() - # id is always needed - col_attrs = [] - if hasattr(related_model, "id"): - col_attrs.append(getattr(related_model, "id")) - for name in simple_cols: - if hasattr(related_model, name): - col_attrs.append(getattr(related_model, name)) - if col_attrs: - q = q.options(load_only(*col_attrs)) - - # Load related bits minimally - for rel_name, col_name in rel_paths: - rel_prop = getattr(related_model, rel_name, None) - if rel_prop is None: + # Relationships first + for prop in mapper.iterate_properties: + if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE': + base = prop.key + if base in exclude or f"{base}_id" in exclude: + continue + if session is None: continue - # grab target class to resolve column attr - try: - target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_ - col_attr = getattr(target_cls, col_name, None) - if col_attr is None: - q = q.options(selectinload(rel_prop)) - else: - q = q.options(selectinload(rel_prop).load_only(col_attr)) - except Exception: - # fallback if mapper lookup is weird - q = q.options(selectinload(rel_prop)) - # Gentle ordering: use first simple col if any, else skip - if simple_cols: - first = simple_cols[0] - if hasattr(related_model, first): - q = q.order_by(getattr(related_model, first)) + related_model = prop.mapper.class_ + rel_label_spec = ( + label_specs.get(base) + or getattr(related_model, "__crud_label__", None) + or "id" + ) + options = _fk_options(session, related_model, rel_label_spec) + base_field = { + "name": f"{base}_id", + "label": base, + "type": "select", + "options": options, + } + field = {**base_field, **overrides.get(f"{base}_id", {})} + fields.append(field) + fk_fields.add(f"{base}_id") - options = q.all() + # NEW: set the current selection for this dropdown + values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base) - fields.append({ - 'name': f"{prop.key}_id", - 'label': prop.key, - 'type': 'select', - 'options': [ - { - 'value': getattr(opt, 'id'), - 'label': _label_from_obj(opt, rel_label_spec), - } - for opt in options - ] - }) - fk_fields.add(f"{prop.key}_id") + # Then plain columns + for col in model_cls.__table__.columns: + if col.name in fk_fields or col.name in exclude: + continue + if col.name in ('id', 'created_at', 'updated_at'): + continue + if col.default or col.server_default or col.onupdate: + continue + base_field = { + "name": col.name, + "label": col.name, + "type": "checkbox" if getattr(col.type, "python_type", None) is bool else "text", + } + field = {**base_field, **overrides.get(col.name, {})} + fields.append(field) - # Base columns - for col in model_cls.__table__.columns: - if col.name in fk_fields: - continue - if col.name in ('id', 'created_at', 'updated_at'): - continue - if col.default or col.server_default or col.onupdate: - continue - fields.append({ - 'name': col.name, - 'label': col.name, - 'type': 'text', - }) - - return template.render(fields=fields, values=values, render_field=render_field) + return template.render(fields=fields, values=values_map, render_field=render_field) diff --git a/crudkit/ui/templates/field.html b/crudkit/ui/templates/field.html index 28fcf7e..c60242a 100644 --- a/crudkit/ui/templates/field.html +++ b/crudkit/ui/templates/field.html @@ -1,16 +1,37 @@ {% if field_type == 'select' %} - {% if options %} {% for opt in options %} - + {% endfor %} {% else %} {% endif %} + +{% elif field_type == 'textarea' %} + + +{% elif field_type == 'checkbox' %} + + {% else %} - + +{% endif %} + +{% if help %} +
{{ help }}
{% endif %}