Lots of downstream updates.

Yaro Kasear 2025-10-03 16:27:25 -05:00
parent f5bc0b5a30
commit 10b2843be8
6 changed files with 373 additions and 99 deletions

View file

@@ -0,0 +1,9 @@
# crudkit/core/__init__.py
from .utils import (
    ISO_DT_FORMATS,
    normalize_payload,
    deep_diff,
    diff_to_patch,
    filter_to_columns,
    to_jsonable,
)
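
Editor's note: the new package initializer re-exports the utility helpers, so downstream code (the service changes below do exactly this) can import them from crudkit.core instead of crudkit.core.utils. A minimal sketch, assuming the crudkit package is importable:

# Hypothetical downstream usage of the re-exported helpers (not part of this commit)
from crudkit.core import deep_diff, diff_to_patch

patch = diff_to_patch(deep_diff({"name": "old"}, {"name": "new"}))
# patch == {"name": "new"}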

View file

@@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
+from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
 from crudkit.core.base import Version
 from crudkit.core.spec import CRUDSpec
 from crudkit.core.types import OrderSpec, SeekWindow
@@ -246,32 +247,30 @@ class CRUDService(Generic[T]):
         # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
         nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
-        used_contains_eager = False
+        # IMPORTANT:
+        # - Only attach loader options for first-hop relations from the root.
+        # - Always use selectinload here (avoid contains_eager joins).
+        # - Let compile_projections() supply deep chained options.
         for base_alias, rel_attr, target_alias in join_paths:
-            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
+            is_firsthop_from_root = (base_alias is root_alias)
+            if not is_firsthop_from_root:
+                # Deeper hops are handled by proj_opts below
+                continue
+            prop = getattr(rel_attr, "property", None)
+            is_collection = bool(getattr(prop, "uselist", False))
             is_nested_firsthop = rel_attr.key in nested_first_hops
-            if is_collection or is_nested_firsthop:
-                # Use selectinload so deeper hops can chain cleanly (and to avoid
-                # contains_eager/loader conflicts on nested paths).
-                opt = selectinload(rel_attr)
-
-                # Narrow columns for collections if we know child scalar names
-                if is_collection:
-                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
-                    if child_names:
-                        target_cls = rel_attr.property.mapper.class_
-                        cols = [getattr(target_cls, n, None) for n in child_names]
-                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
-                        if cols:
-                            opt = opt.load_only(*cols)
-                query = query.options(opt)
-            else:
-                # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager
-                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
-                query = query.options(contains_eager(rel_attr, alias=target_alias))
-                used_contains_eager = True
+            opt = selectinload(rel_attr)
+            # Optional narrowing for collections
+            if is_collection:
+                child_names = (collection_field_names or {}).get(rel_attr.key, [])
+                if child_names:
+                    target_cls = prop.mapper.class_
+                    cols = [getattr(target_cls, n, None) for n in child_names]
+                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
+                    if cols:
+                        opt = opt.load_only(*cols)
+            query = query.options(opt)
 
         # Filters AFTER joins → no cartesian products
         if filters:
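
Editor's note: to make the "let compile_projections() supply deep chained options" comment concrete, deeper tails such as "contact.supervisor" can be loaded by chaining selectinload() options, which issue separate SELECT ... IN queries and so do not conflict with other loader options the way contains_eager() joins can on nested paths. A minimal, self-contained sketch with hypothetical Employee/Contact models (not part of this commit):

# Hypothetical models for illustration only; in crudkit the chained options come from compile_projection().
from typing import Optional
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, mapped_column,
                            relationship, selectinload)

class Base(DeclarativeBase):
    pass

class Contact(Base):
    __tablename__ = "contact"
    id: Mapped[int] = mapped_column(primary_key=True)
    supervisor_id: Mapped[Optional[int]] = mapped_column(ForeignKey("contact.id"))
    supervisor: Mapped[Optional["Contact"]] = relationship(remote_side=id)

class Employee(Base):
    __tablename__ = "employee"
    id: Mapped[int] = mapped_column(primary_key=True)
    contact_id: Mapped[Optional[int]] = mapped_column(ForeignKey("contact.id"))
    contact: Mapped[Optional[Contact]] = relationship()

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    # First hop plus deeper tail, all selectinload: contacts and supervisors are
    # fetched with separate IN-loads, no joins added to the root query.
    opt = selectinload(Employee.contact).selectinload(Contact.supervisor)
    employees = session.query(Employee).options(opt).all()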
@@ -364,6 +363,10 @@ class CRUDService(Generic[T]):
         last_key = None
 
+        # Apply deep projection loader options (safe: we avoided contains_eager)
+        if proj_opts:
+            query = query.options(*proj_opts)
+
         # Count DISTINCT ids with mirrored joins
         total = None
         if include_total:
             base = session.query(getattr(root_alias, "id"))
@@ -465,26 +468,25 @@ class CRUDService(Generic[T]):
         nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
-        used_contains_eager = False
+        # First-hop only; use selectinload (no contains_eager)
         for base_alias, rel_attr, target_alias in join_paths:
-            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
-            is_nested_firsthop = rel_attr.key in nested_first_hops
-            if is_collection or is_nested_firsthop:
-                opt = selectinload(rel_attr)
-                if is_collection:
-                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
-                    if child_names:
-                        target_cls = rel_attr.property.mapper.class_
-                        cols = [getattr(target_cls, n, None) for n in child_names]
-                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
-                        if cols:
-                            opt = opt.load_only(*cols)
-                query = query.options(opt)
-            else:
-                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
-                query = query.options(contains_eager(rel_attr, alias=target_alias))
-                used_contains_eager = True
+            is_firsthop_from_root = (base_alias is root_alias)
+            if not is_firsthop_from_root:
+                continue
+            prop = getattr(rel_attr, "property", None)
+            is_collection = bool(getattr(prop, "uselist", False))
+            _is_nested_firsthop = rel_attr.key in nested_first_hops
+            opt = selectinload(rel_attr)
+            if is_collection:
+                child_names = (collection_field_names or {}).get(rel_attr.key, [])
+                if child_names:
+                    target_cls = prop.mapper.class_
+                    cols = [getattr(target_cls, n, None) for n in child_names]
+                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
+                    if cols:
+                        opt = opt.load_only(*cols)
+            query = query.options(opt)
 
         # Apply filters (joins are in place → no cartesian products)
         if filters:
@@ -496,7 +498,7 @@ class CRUDService(Generic[T]):
         # Projection loader options compiled from requested fields.
         # Skip if we used contains_eager to avoid loader-strategy conflicts.
         expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
-        if proj_opts and not used_contains_eager:
+        if proj_opts:
             query = query.options(*proj_opts)
 
         obj = query.first()
@@ -564,26 +566,25 @@ class CRUDService(Generic[T]):
         nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
-        used_contains_eager = False
-        for _base_alias, rel_attr, target_alias in join_paths:
-            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
-            is_nested_firsthop = rel_attr.key in nested_first_hops
-            if is_collection or is_nested_firsthop:
-                opt = selectinload(rel_attr)
-                if is_collection:
-                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
-                    if child_names:
-                        target_cls = rel_attr.property.mapper.class_
-                        cols = [getattr(target_cls, n, None) for n in child_names]
-                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
-                        if cols:
-                            opt = opt.load_only(*cols)
-                query = query.options(opt)
-            else:
-                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
-                query = query.options(contains_eager(rel_attr, alias=target_alias))
-                used_contains_eager = True
+        # First-hop only; use selectinload
+        for base_alias, rel_attr, target_alias in join_paths:
+            is_firsthop_from_root = (base_alias is root_alias)
+            if not is_firsthop_from_root:
+                continue
+            prop = getattr(rel_attr, "property", None)
+            is_collection = bool(getattr(prop, "uselist", False))
+            _is_nested_firsthop = rel_attr.key in nested_first_hops
+            opt = selectinload(rel_attr)
+            if is_collection:
+                child_names = (collection_field_names or {}).get(rel_attr.key, [])
+                if child_names:
+                    target_cls = prop.mapper.class_
+                    cols = [getattr(target_cls, n, None) for n in child_names]
+                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
+                    if cols:
+                        opt = opt.load_only(*cols)
+            query = query.options(opt)
 
         # Filters AFTER joins → no cartesian products
         if filters:
@@ -607,7 +608,7 @@ class CRUDService(Generic[T]):
             # Projection loaders only if we didn't use contains_eager
             expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
-            if proj_opts and not used_contains_eager:
+            if proj_opts:
                 query = query.options(*proj_opts)
 
         else:
@@ -648,31 +649,79 @@ class CRUDService(Generic[T]):
         return rows
 
-    def create(self, data: dict, actor=None) -> T:
+    def create(self, data: dict, actor=None, *, commit: bool = True) -> T:
         session = self.session
         obj = self.model(**data)
         session.add(obj)
-        session.commit()
-        self._log_version("create", obj, actor)
+        session.flush()
+        self._log_version("create", obj, actor, commit=commit)
+        if commit:
+            session.commit()
         return obj
 
-    def update(self, id: int, data: dict, actor=None) -> T:
+    def update(self, id: int, data: dict, actor=None, *, commit: bool = True) -> T:
         session = self.session
         obj = session.get(self.model, id)
         if not obj:
             raise ValueError(f"{self.model.__name__} with ID {id} not found.")
-        valid_fields = {c.name for c in self.model.__table__.columns}
-        unknown = set(data) - valid_fields
-        if unknown:
-            raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
-        for k, v in data.items():
-            if k in valid_fields:
-                setattr(obj, k, v)
-        session.commit()
-        self._log_version("update", obj, actor)
+
+        before = obj.as_dict()
+
+        # Normalize and restrict payload to real columns
+        norm = normalize_payload(data, self.model)
+        incoming = filter_to_columns(norm, self.model)
+
+        # Build a synthetic "desired" state for top-level columns
+        desired = {**before, **incoming}
+
+        # Compute intended change set (before vs intended)
+        proposed = deep_diff(
+            before, desired,
+            ignore_keys={"id", "created_at", "updated_at"},
+            list_mode="index",
+        )
+        patch = diff_to_patch(proposed)
+
+        # Nothing to do
+        if not patch:
+            return obj
+
+        # Apply only what actually changes
+        for k, v in patch.items():
+            setattr(obj, k, v)
+
+        # Optional: skip commit if ORM says no real change (paranoid check)
+        # Note: is_modified can lie if attrs are expired; use history for certainty.
+        dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys())
+        if not dirty:
+            return obj
+
+        # Commit atomically
+        if commit:
+            session.commit()
+
+        # AFTER snapshot for audit
+        after = obj.as_dict()
+
+        # Actual diff (captures triggers/defaults, still ignoring noisy keys)
+        actual = deep_diff(
+            before, after,
+            ignore_keys={"id", "created_at", "updated_at"},
+            list_mode="index",
+        )
+
+        # If truly nothing changed post-commit (rare), skip version spam
+        if not (actual["added"] or actual["removed"] or actual["changed"]):
+            return obj
+
+        # Log both what we *intended* and what *actually* happened
+        self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit)
+
         return obj
 
-    def delete(self, id: int, hard: bool = False, actor = None):
+    def delete(self, id: int, hard: bool = False, actor = None, *, commit: bool = True):
         session = self.session
         obj = session.get(self.model, id)
         if not obj:
@@ -682,23 +731,31 @@ class CRUDService(Generic[T]):
         else:
             soft = cast(_SoftDeletable, obj)
             soft.is_deleted = True
-        session.commit()
-        self._log_version("delete", obj, actor)
+        if commit:
+            session.commit()
+        self._log_version("delete", obj, actor, commit=commit)
         return obj
 
-    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
+    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True):
         session = self.session
         try:
-            data = obj.as_dict()
-        except Exception:
-            data = {"error": "Failed to serialize object."}
-        version = Version(
-            model_name=self.model.__name__,
-            object_id=obj.id,
-            change_type=change_type,
-            data=data,
-            actor=str(actor) if actor else None,
-            meta=metadata
-        )
-        session.add(version)
-        session.commit()
+            snapshot = {}
+            try:
+                snapshot = obj.as_dict()
+            except Exception:
+                snapshot = {"error": "serialize failed"}
+
+            version = Version(
+                model_name=self.model.__name__,
+                object_id=obj.id,
+                change_type=change_type,
+                data=to_jsonable(snapshot),
+                actor=str(actor) if actor else None,
+                meta=to_jsonable(metadata) if metadata else None,
+            )
+            session.add(version)
+            if commit:
+                session.commit()
+        except Exception as e:
+            log.warning(f"Version logging failed for {self.model.__name__} id={getattr(obj, 'id', '?')}: {str(e)}")
+            session.rollback()
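
Editor's note: to make the new update() flow concrete, here is a small worked example of the before/desired diff and the shallow patch derived from it, using the deep_diff() and diff_to_patch() helpers added in crudkit/core/utils.py below. A minimal sketch with plain dicts and made-up field names, no ORM involved:

from crudkit.core import deep_diff, diff_to_patch

before = {"id": 7, "name": "Ada", "email": "ada@example.com", "updated_at": "2025-10-01T12:00:00"}
desired = {"id": 7, "name": "Ada Lovelace", "email": "ada@example.com", "updated_at": "2025-10-03T16:27:25"}

proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index")
# proposed == {"added": {}, "removed": {}, "changed": {"name": {"from": "Ada", "to": "Ada Lovelace"}}}

patch = diff_to_patch(proposed)
# patch == {"name": "Ada Lovelace"}  -> only this key is written back via setattr()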

crudkit/core/utils.py Normal file (+176)
View file

@@ -0,0 +1,176 @@
from __future__ import annotations

from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, Optional, Callable

from sqlalchemy import inspect

ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f",
                  "%Y-%m-%dT%H:%M:%S",
                  "%Y-%m-%d %H:%M",
                  "%Y-%m-%d")


def to_jsonable(obj: Any):
    """Recursively convert values into JSON-serializable forms."""
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, Decimal):
        return float(obj)
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [to_jsonable(v) for v in obj]
    # fallback: stringify weird objects (UUID, ORM instances, etc.)
    try:
        return str(obj)
    except Exception:
        return None


def filter_to_columns(data: dict, model_cls):
    cols = {c.key for c in inspect(model_cls).mapper.columns}
    return {k: v for k, v in data.items() if k in cols}


def _parse_dt_maybe(x: Any) -> Any:
    if isinstance(x, (datetime, date)):
        return x
    if isinstance(x, str):
        s = x.strip().replace("Z", "+00:00")  # tolerate Zulu
        for fmt in ISO_DT_FORMATS:
            try:
                return datetime.strptime(s, fmt)
            except ValueError:
                pass
        try:
            return datetime.fromisoformat(s)
        except Exception:
            return x
    return x


def _normalize_for_compare(x: Any) -> Any:
    if isinstance(x, (str, datetime, date)):
        return _parse_dt_maybe(x)
    return x


def deep_diff(
    old: Any,
    new: Any,
    *,
    path: str = "",
    ignore_keys: Optional[set] = None,
    list_mode: str = "index",  # "index" or "set"
    custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
    if ignore_keys is None:
        ignore_keys = set()

    out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}

    def mark_changed(p, a, b):
        out["changed"][p] = {"from": a, "to": b}

    def rec(o, n, pfx):
        if custom_equal and custom_equal(pfx.rstrip("."), o, n):
            return
        if isinstance(o, dict) and isinstance(n, dict):
            o_keys = set(o.keys())
            n_keys = set(n.keys())
            for k in sorted(o_keys - n_keys):
                if k not in ignore_keys:
                    out["removed"][f"{pfx}{k}"] = o[k]
            for k in sorted(n_keys - o_keys):
                if k not in ignore_keys:
                    out["added"][f"{pfx}{k}"] = n[k]
            for k in sorted(o_keys & n_keys):
                if k not in ignore_keys:
                    rec(o[k], n[k], f"{pfx}{k}.")
            return
        if isinstance(o, list) and isinstance(n, list):
            if list_mode == "set":
                if set(o) != set(n):
                    mark_changed(pfx.rstrip("."), o, n)
            else:
                max_len = max(len(o), len(n))
                for i in range(max_len):
                    key = f"{pfx}[{i}]"
                    if i >= len(o):
                        out["added"][key] = n[i]
                    elif i >= len(n):
                        out["removed"][key] = o[i]
                    else:
                        rec(o[i], n[i], f"{key}.")
            return
        a = _normalize_for_compare(o)
        b = _normalize_for_compare(n)
        if a != b:
            mark_changed(pfx.rstrip("."), o, n)

    rec(old, new, path)
    return out


def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """
    Produce a shallow patch of changed/added top-level fields.
    Only includes leaf paths without dots/brackets; useful for simple UPDATEs.
    """
    patch: Dict[str, Any] = {}
    for k, v in diff["added"].items():
        if "." not in k and "[" not in k:
            patch[k] = v
    for k, v in diff["changed"].items():
        if "." not in k and "[" not in k:
            patch[k] = v["to"]
    return patch


def normalize_payload(payload: dict, model):
    """
    Coerce incoming JSON into SQLAlchemy column types for the given model.
    - "" or None -> None
    - Integer/Boolean/Date/DateTime handled by column type
    """
    from sqlalchemy import Integer, Boolean, DateTime, Date

    out: Dict[str, Any] = {}
    mapper = inspect(model).mapper
    cols = {c.key: c.type for c in mapper.columns}

    for field, value in payload.items():
        if value == "" or value is None:
            out[field] = None
            continue

        coltype = cols.get(field)
        if coltype is None:
            out[field] = value
            continue

        tname = coltype.__class__.__name__.lower()
        if "integer" in tname:
            out[field] = int(value)
        elif "boolean" in tname:
            out[field] = value if isinstance(value, bool) else str(value).lower() in ("1", "true", "yes", "on")
        elif "datetime" in tname:
            out[field] = value if isinstance(value, datetime) else _parse_dt_maybe(value)
        elif "date" in tname:
            v = _parse_dt_maybe(value)
            out[field] = v.date() if isinstance(v, datetime) else v
        else:
            out[field] = value

    return out
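
Editor's note: a quick illustration of normalize_payload() coercing a form-style string payload. The model here is made up purely for the example; only the helper itself comes from this commit:

# Sketch only: a hypothetical model to show the coercion rules above.
from datetime import date
from typing import Optional
from sqlalchemy import Boolean, Date, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from crudkit.core.utils import normalize_payload  # helper defined above

class Base(DeclarativeBase):
    pass

class DemoEmployee(Base):  # not part of crudkit
    __tablename__ = "demo_employee"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    age: Mapped[Optional[int]] = mapped_column(Integer)
    active: Mapped[Optional[bool]] = mapped_column(Boolean)
    hired_on: Mapped[Optional[date]] = mapped_column(Date)
    nickname: Mapped[Optional[str]] = mapped_column(String)

payload = {"age": "42", "active": "yes", "hired_on": "2025-10-03", "nickname": ""}
print(normalize_payload(payload, DemoEmployee))
# {'age': 42, 'active': True, 'hired_on': datetime.date(2025, 10, 3), 'nickname': None}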

View file

@@ -66,13 +66,46 @@ _ENUMS = {
 }
 
 def get_env():
+    """
+    Return an overlay Jinja Environment that knows how to load crudkit templates
+    and has our helper functions available as globals.
+    """
     app = current_app
     default_path = os.path.join(os.path.dirname(__file__), 'templates')
     fallback_loader = FileSystemLoader(default_path)
-    return app.jinja_env.overlay(
-        loader=ChoiceLoader([app.jinja_loader, fallback_loader])
-    )
+    env = app.jinja_env.overlay(loader=ChoiceLoader([app.jinja_loader, fallback_loader]))
+
+    # Ensure helpers are available even when we render via this overlay env.
+    # These names are resolved at *call time* (not at def time), so it's safe.
+    try:
+        env.globals.setdefault("render_table", render_table)
+        env.globals.setdefault("render_form", render_form)
+        env.globals.setdefault("render_field", render_field)
+    except NameError:
+        # Functions may not be defined yet at import time; later calls will set them.
+        pass
+
+    return env
+
+
+def register_template_globals(app=None):
+    """
+    Register crudkit helpers as app-wide Jinja globals so they can be used
+    directly in any template via {{ render_table(...) }}, {{ render_form(...) }},
+    and {{ render_field(...) }}.
+    """
+    if app is None:
+        app = current_app
+
+    # Idempotent install using an extension flag
+    installed = app.extensions.setdefault("crudkit_ui_helpers", set())
+
+    to_register = {
+        "render_table": render_table,
+        "render_form": render_form,
+        "render_field": render_field,
+    }
+    for name, fn in to_register.items():
+        if name not in installed:
+            app.add_template_global(fn, name)
+            installed.add(name)
 
 def expand_projection(model_cls, fields):
     req = getattr(model_cls, "__crudkit_field_requires__", {}) or {}
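
Editor's note: a sketch of how an application might wire the new register_template_globals() helper at startup. The app-factory shape and the import path are assumptions for illustration; only the helper call itself comes from this commit:

# Hypothetical Flask app factory (module path crudkit.ui.helpers is assumed)
from flask import Flask
from crudkit.ui.helpers import register_template_globals

def create_app() -> Flask:
    app = Flask(__name__)
    # Makes render_table/render_form/render_field available as Jinja globals in any template.
    register_template_globals(app)
    return app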
@@ -1189,5 +1222,6 @@ def render_form(
         values=values_map,
         render_field=render_field,
         submit_attrs=submit_attrs,
-        submit_label=submit_label
+        submit_label=submit_label,
+        model_name=model_cls.__name__
     )

View file

@@ -1,5 +1,4 @@
 {# show label unless hidden/custom #}
-<!-- {{ field_name }} (field) -->
 {% if field_type != 'hidden' and field_label %}
 <label for="{{ field_name }}"
     {% if label_attrs %}{% for k,v in label_attrs.items() %}

View file

@@ -1,6 +1,5 @@
-<form method="POST">
+<form method="POST" id="{{ model_name|lower }}_form">
 {% macro render_row(row) %}
-<!-- {{ row.name }} (row) -->
 {% if row.fields or row.children or row.legend %}
 {% if row.legend %}<legend>{{ row.legend }}</legend>{% endif %}
 <fieldset