Getting more patch logic in.

This commit is contained in:
Yaro Kasear 2025-10-01 09:38:40 -05:00
parent cab35b72ec
commit c040ff74c9
5 changed files with 196 additions and 175 deletions

2
.gitignore vendored
View file

@ -3,7 +3,7 @@ inventory/static/uploads/*
!inventory/static/uploads/.gitkeep
.venv/
.env
*.db
*.db*
*.db-journal
*.sqlite
*.sqlite3

View file

@ -0,0 +1,8 @@
# crudkit/core/__init__.py
from .utils import (
ISO_DT_FORMATS,
normalize_payload,
deep_diff,
diff_to_patch,
filter_to_columns,
)

View file

@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core import deep_diff, diff_to_patch
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
@ -660,16 +661,41 @@ class CRUDService(Generic[T]):
session = self.session
obj = session.get(self.model, id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns}
unknown = set(data) - valid_fields
raise ValueError("f{self.model.__name__} id ID {id} not found.")
# Only touch real columns
valid = {c.name for c in self.model.__table__.columns}
unknown = set(data) - valid
if unknown:
raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
# BEFORE snapshot (non-lazy, columns-only is fine)
before = obj.as_dict()
# Apply patch
for k, v in data.items():
if k in valid_fields:
setattr(obj, k, v)
setattr(obj, k, v)
# If nothing changed at ORM level, bail early
if not session.is_modified(obj, include_collections=False):
return obj
session.commit()
self._log_version("update", obj, actor)
after = obj.as_dict()
diff = deep_diff(
before,
after,
ignore_keys={"updated_at", "created_at", "id"},
list_mode="index",
)
# If somehow no diff, do not spam versions
if not(diff["added"] or diff["removed"] or diff["changed"]):
return obj
self._log_version("update", obj, actor, metadata={"diff": diff})
return obj
def delete(self, id: int, hard: bool = False, actor = None):
@ -688,17 +714,19 @@ class CRUDService(Generic[T]):
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
session = self.session
snapshot = {}
try:
data = obj.as_dict()
snapshot = obj.as_dict()
except Exception:
data = {"error": "Failed to serialize object."}
snapshot = {"error": "serialize failed"}
version = Version(
model_name=self.model.__name__,
object_id=obj.id,
change_type=change_type,
data=data,
data=snapshot,
actor=str(actor) if actor else None,
meta=metadata
meta=metadata or None,
)
session.add(version)
session.commit()

148
crudkit/core/utils.py Normal file
View file

@ -0,0 +1,148 @@
from __future__ import annotations
from datetime import datetime, date
from typing import Any, Dict, Optional, Callable
from sqlalchemy import inspect
ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%d")
def filter_to_columns(data: dict, model_cls):
cols = {c.key for c in inspect(model_cls).mapper.columns}
return {k: v for k, v in data.items() if k in cols}
def _parse_dt_maybe(x: Any) -> Any:
if isinstance(x, (datetime, date)):
return x
if isinstance(x, str):
s = x.strip().replace("Z", "+00:00") # tolerate Zulu
for fmt in ISO_DT_FORMATS:
try:
return datetime.strptime(s, fmt)
except ValueError:
pass
try:
return datetime.fromisoformat(s)
except Exception:
return x
return x
def _normalize_for_compare(x: Any) -> Any:
if isinstance(x, (str, datetime, date)):
return _parse_dt_maybe(x)
return x
def deep_diff(
old: Any,
new: Any,
*,
path: str = "",
ignore_keys: Optional[set] = None,
list_mode: str = "index", # "index" or "set"
custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
if ignore_keys is None:
ignore_keys = set()
out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}
def mark_changed(p, a, b):
out["changed"][p] = {"from": a, "to": b}
def rec(o, n, pfx):
if custom_equal and custom_equal(pfx.rstrip("."), o, n):
return
if isinstance(o, dict) and isinstance(n, dict):
o_keys = set(o.keys())
n_keys = set(n.keys())
for k in sorted(o_keys - n_keys):
if k not in ignore_keys:
out["removed"][f"{pfx}{k}"] = o[k]
for k in sorted(n_keys - o_keys):
if k not in ignore_keys:
out["added"][f"{pfx}{k}"] = n[k]
for k in sorted(o_keys & n_keys):
if k not in ignore_keys:
rec(o[k], n[k], f"{pfx}{k}.")
return
if isinstance(o, list) and isinstance(n, list):
if list_mode == "set":
if set(o) != set(n):
mark_changed(pfx.rstrip("."), o, n)
else:
max_len = max(len(o), len(n))
for i in range(max_len):
key = f"{pfx}[{i}]"
if i >= len(o):
out["added"][key] = n[i]
elif i >= len(n):
out["removed"][key] = o[i]
else:
rec(o[i], n[i], f"{key}.")
return
a = _normalize_for_compare(o)
b = _normalize_for_compare(n)
if a != b:
mark_changed(pfx.rstrip("."), o, n)
rec(old, new, path)
return out
def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
Produce a shallow patch of changed/added top-level fields.
Only includes leaf paths without dots/brackets; useful for simple UPDATEs.
"""
patch: Dict[str, Any] = {}
for k, v in diff["added"].items():
if "." not in k and "[" not in k:
patch[k] = v
for k, v in diff["changed"].items():
if "." not in k and "[" not in k:
patch[k] = v["to"]
return patch
def normalize_payload(payload: dict, model):
"""
Coerce incoming JSON into SQLAlchemy column types for the given model.
- "" or None -> None
- Integer/Boolean/Date/DateTime handled by column type
"""
from sqlalchemy import Integer, Boolean, DateTime, Date
out: Dict[str, Any] = {}
mapper = inspect(model).mapper
cols = {c.key: c.type for c in mapper.columns}
for field, value in payload.items():
if value == "" or value is None:
out[field] = None
continue
coltype = cols.get(field)
if coltype is None:
out[field] = value
continue
tname = coltype.__class__.__name__.lower()
if "integer" in tname:
out[field] = int(value)
elif "boolean" in tname:
out[field] = value if isinstance(value, bool) else str(value).lower() in ("1", "true", "yes", "on")
elif "datetime" in tname:
out[field] = value if isinstance(value, datetime) else _parse_dt_maybe(value)
elif "date" in tname:
v = _parse_dt_maybe(value)
out[field] = v.date() if isinstance(v, datetime) else v
else:
out[field] = value
return out

View file

@ -5,169 +5,10 @@ from typing import Any, Dict, List, Tuple, Callable, Optional
import crudkit
from crudkit.ui.fragments import render_form
ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%d")
from crudkit.core import normalize_payload
bp_entry = Blueprint("entry", __name__)
def filter_to_columns(data: dict, model_cls):
cols = {c.key for c in inspect(model_cls).mapper.columns}
return {k: v for k, v in data.items() if k in cols}
def _parse_dt_maybe(x: Any) -> Any:
if isinstance(x, datetime):
return x
if isinstance(x, str):
s = x.strip()
for fmt in ISO_DT_FORMATS:
try:
return datetime.strptime(s, fmt)
except ValueError:
pass
try:
return datetime.fromisoformat(s)
except Exception:
return x
return x
def _normalize_for_compare(x: Any) -> Any:
if isinstance(x, (str, datetime)):
return _parse_dt_maybe(x)
return x
def deep_diff(
old: Any,
new: Any,
*,
path: str = "",
ignore_keys: Optional[set] = None,
list_mode: str = "index", # "index" or "set"
custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
"""
Returns {'added': {...}, 'removed': {...}, 'changed': {...}}
Paths use dot notation for dicts and [i] for lists.
"""
if ignore_keys is None:
ignore_keys = set()
out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}
def mark_changed(p, a, b):
out["changed"][p] = {"from": a, "to": b}
def rec(o, n, pfx):
# custom equality short-circuit
if custom_equal and custom_equal(pfx.rstrip("."), o, n):
return
# Dict vs Dict
if isinstance(o, dict) and isinstance(n, dict):
o_keys = set(o.keys())
n_keys = set(n.keys())
# removed
for k in sorted(o_keys - n_keys):
if k in ignore_keys:
continue
out["removed"][f"{pfx}{k}"] = o[k]
# added
for k in sorted(n_keys - o_keys):
if k in ignore_keys:
continue
out["added"][f"{pfx}{k}"] = n[k]
# present in both -> recurse
for k in sorted(o_keys & n_keys):
if k in ignore_keys:
continue
rec(o[k], n[k], f"{pfx}{k}.")
return
# List vs List
if isinstance(o, list) and isinstance(n, list):
if list_mode == "set":
if set(o) != set(n):
mark_changed(pfx.rstrip("."), o, n)
else:
max_len = max(len(o), len(n))
for i in range(max_len):
key = f"{pfx}[{i}]"
if i >= len(o):
out["added"][key] = n[i]
elif i >= len(n):
out["removed"][key] = o[i]
else:
rec(o[i], n[i], f"{key}.")
return
# Scalars or type mismatch
a = _normalize_for_compare(o)
b = _normalize_for_compare(n)
if a != b:
mark_changed(pfx.rstrip("."), o, n)
rec(old, new, path)
return out
def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
Produce a shallow patch of changed/added top-level fields.
Onky includes leaf paths without dots/brackets; useful for simple UPDATEs.
"""
patch = {}
for k, v in diff["added"].items():
if "." not in k and "[" not in k:
patch[k] = v
for k, v in diff["changed"].items():
if "." not in k and "[" not in k:
patch[k] = v["to"]
return patch
def normalize_payload(payload: dict, model):
"""
Take raw JSON dict from frontend and coerce valies
into types expected by the SQLAlchemy model.
"""
out = {}
for field, value in payload.items():
if value == "" or value is None:
out[field] = None
continue
# Look up the SQLAlchemy column type if available
col = getattr(model, field, None)
coltype = getattr(col, "type", None)
if coltype is not None:
tname = coltype.__class__.__name__.lower()
if "integer" in tname:
out[field] = int(value)
elif "boolean" in tname:
# frontend may send true/false already,
# or string "true"/"false"
if isinstance(value, bool):
out[field] = value
else:
out[field] = str(value).lower() in ("1", "true", "yes", "on")
elif "datetime" in tname:
out[field] = (
value if isinstance(value, datetime)
else datetime.fromisoformat(value)
)
else:
out[field] = value
else:
out[field] = value
return out
def init_entry_routes(app):
@bp_entry.get("/entry/<model>/<int:id>")
@ -357,10 +198,6 @@ def init_entry_routes(app):
service = crudkit.crud.get_service(cls)
item = service.get(id, params)
d = deep_diff(item.as_dict(), payload, ignore_keys={"id", "created_at", "updated_at"})
patch = diff_to_patch(d)
clean_patch = filter_to_columns(patch, cls)
print(f"OLD = {item.as_dict()}\n\nNEW = {payload}\n\nDIFF = {d}\n\nPATCH = {patch}\n\nCLEAN PATCH = {clean_patch}")
return {"status": "success", "payload": payload}
except Exception as e: