Getting more patch logic in.
This commit is contained in:
parent
cab35b72ec
commit
c040ff74c9
5 changed files with 196 additions and 175 deletions
|
|
@ -0,0 +1,8 @@
|
|||
# crudkit/core/__init__.py
|
||||
from .utils import (
|
||||
ISO_DT_FORMATS,
|
||||
normalize_payload,
|
||||
deep_diff,
|
||||
diff_to_patch,
|
||||
filter_to_columns,
|
||||
)
|
||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||||
|
||||
from crudkit.core import deep_diff, diff_to_patch
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
from crudkit.core.types import OrderSpec, SeekWindow
|
||||
|
|
@ -660,16 +661,41 @@ class CRUDService(Generic[T]):
|
|||
session = self.session
|
||||
obj = session.get(self.model, id)
|
||||
if not obj:
|
||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||
unknown = set(data) - valid_fields
|
||||
raise ValueError("f{self.model.__name__} id ID {id} not found.")
|
||||
|
||||
# Only touch real columns
|
||||
valid = {c.name for c in self.model.__table__.columns}
|
||||
unknown = set(data) - valid
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
|
||||
|
||||
# BEFORE snapshot (non-lazy, columns-only is fine)
|
||||
before = obj.as_dict()
|
||||
|
||||
# Apply patch
|
||||
for k, v in data.items():
|
||||
if k in valid_fields:
|
||||
setattr(obj, k, v)
|
||||
setattr(obj, k, v)
|
||||
|
||||
# If nothing changed at ORM level, bail early
|
||||
if not session.is_modified(obj, include_collections=False):
|
||||
return obj
|
||||
|
||||
session.commit()
|
||||
self._log_version("update", obj, actor)
|
||||
|
||||
after = obj.as_dict()
|
||||
|
||||
diff = deep_diff(
|
||||
before,
|
||||
after,
|
||||
ignore_keys={"updated_at", "created_at", "id"},
|
||||
list_mode="index",
|
||||
)
|
||||
|
||||
# If somehow no diff, do not spam versions
|
||||
if not(diff["added"] or diff["removed"] or diff["changed"]):
|
||||
return obj
|
||||
|
||||
self._log_version("update", obj, actor, metadata={"diff": diff})
|
||||
return obj
|
||||
|
||||
def delete(self, id: int, hard: bool = False, actor = None):
|
||||
|
|
@ -688,17 +714,19 @@ class CRUDService(Generic[T]):
|
|||
|
||||
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
|
||||
session = self.session
|
||||
snapshot = {}
|
||||
try:
|
||||
data = obj.as_dict()
|
||||
snapshot = obj.as_dict()
|
||||
except Exception:
|
||||
data = {"error": "Failed to serialize object."}
|
||||
snapshot = {"error": "serialize failed"}
|
||||
|
||||
version = Version(
|
||||
model_name=self.model.__name__,
|
||||
object_id=obj.id,
|
||||
change_type=change_type,
|
||||
data=data,
|
||||
data=snapshot,
|
||||
actor=str(actor) if actor else None,
|
||||
meta=metadata
|
||||
meta=metadata or None,
|
||||
)
|
||||
session.add(version)
|
||||
session.commit()
|
||||
|
|
|
|||
148
crudkit/core/utils.py
Normal file
148
crudkit/core/utils.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
from __future__ import annotations
|
||||
from datetime import datetime, date
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
from sqlalchemy import inspect
|
||||
|
||||
ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y-%m-%d")
|
||||
|
||||
def filter_to_columns(data: dict, model_cls):
|
||||
cols = {c.key for c in inspect(model_cls).mapper.columns}
|
||||
return {k: v for k, v in data.items() if k in cols}
|
||||
|
||||
def _parse_dt_maybe(x: Any) -> Any:
|
||||
if isinstance(x, (datetime, date)):
|
||||
return x
|
||||
if isinstance(x, str):
|
||||
s = x.strip().replace("Z", "+00:00") # tolerate Zulu
|
||||
for fmt in ISO_DT_FORMATS:
|
||||
try:
|
||||
return datetime.strptime(s, fmt)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return datetime.fromisoformat(s)
|
||||
except Exception:
|
||||
return x
|
||||
return x
|
||||
|
||||
def _normalize_for_compare(x: Any) -> Any:
|
||||
if isinstance(x, (str, datetime, date)):
|
||||
return _parse_dt_maybe(x)
|
||||
return x
|
||||
|
||||
def deep_diff(
|
||||
old: Any,
|
||||
new: Any,
|
||||
*,
|
||||
path: str = "",
|
||||
ignore_keys: Optional[set] = None,
|
||||
list_mode: str = "index", # "index" or "set"
|
||||
custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
if ignore_keys is None:
|
||||
ignore_keys = set()
|
||||
|
||||
out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}
|
||||
|
||||
def mark_changed(p, a, b):
|
||||
out["changed"][p] = {"from": a, "to": b}
|
||||
|
||||
def rec(o, n, pfx):
|
||||
if custom_equal and custom_equal(pfx.rstrip("."), o, n):
|
||||
return
|
||||
|
||||
if isinstance(o, dict) and isinstance(n, dict):
|
||||
o_keys = set(o.keys())
|
||||
n_keys = set(n.keys())
|
||||
|
||||
for k in sorted(o_keys - n_keys):
|
||||
if k not in ignore_keys:
|
||||
out["removed"][f"{pfx}{k}"] = o[k]
|
||||
|
||||
for k in sorted(n_keys - o_keys):
|
||||
if k not in ignore_keys:
|
||||
out["added"][f"{pfx}{k}"] = n[k]
|
||||
|
||||
for k in sorted(o_keys & n_keys):
|
||||
if k not in ignore_keys:
|
||||
rec(o[k], n[k], f"{pfx}{k}.")
|
||||
return
|
||||
|
||||
if isinstance(o, list) and isinstance(n, list):
|
||||
if list_mode == "set":
|
||||
if set(o) != set(n):
|
||||
mark_changed(pfx.rstrip("."), o, n)
|
||||
else:
|
||||
max_len = max(len(o), len(n))
|
||||
for i in range(max_len):
|
||||
key = f"{pfx}[{i}]"
|
||||
if i >= len(o):
|
||||
out["added"][key] = n[i]
|
||||
elif i >= len(n):
|
||||
out["removed"][key] = o[i]
|
||||
else:
|
||||
rec(o[i], n[i], f"{key}.")
|
||||
return
|
||||
|
||||
a = _normalize_for_compare(o)
|
||||
b = _normalize_for_compare(n)
|
||||
if a != b:
|
||||
mark_changed(pfx.rstrip("."), o, n)
|
||||
|
||||
rec(old, new, path)
|
||||
return out
|
||||
|
||||
def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Produce a shallow patch of changed/added top-level fields.
|
||||
Only includes leaf paths without dots/brackets; useful for simple UPDATEs.
|
||||
"""
|
||||
patch: Dict[str, Any] = {}
|
||||
for k, v in diff["added"].items():
|
||||
if "." not in k and "[" not in k:
|
||||
patch[k] = v
|
||||
for k, v in diff["changed"].items():
|
||||
if "." not in k and "[" not in k:
|
||||
patch[k] = v["to"]
|
||||
return patch
|
||||
|
||||
def normalize_payload(payload: dict, model):
|
||||
"""
|
||||
Coerce incoming JSON into SQLAlchemy column types for the given model.
|
||||
- "" or None -> None
|
||||
- Integer/Boolean/Date/DateTime handled by column type
|
||||
"""
|
||||
from sqlalchemy import Integer, Boolean, DateTime, Date
|
||||
out: Dict[str, Any] = {}
|
||||
|
||||
mapper = inspect(model).mapper
|
||||
cols = {c.key: c.type for c in mapper.columns}
|
||||
|
||||
for field, value in payload.items():
|
||||
if value == "" or value is None:
|
||||
out[field] = None
|
||||
continue
|
||||
|
||||
coltype = cols.get(field)
|
||||
if coltype is None:
|
||||
out[field] = value
|
||||
continue
|
||||
|
||||
tname = coltype.__class__.__name__.lower()
|
||||
|
||||
if "integer" in tname:
|
||||
out[field] = int(value)
|
||||
elif "boolean" in tname:
|
||||
out[field] = value if isinstance(value, bool) else str(value).lower() in ("1", "true", "yes", "on")
|
||||
elif "datetime" in tname:
|
||||
out[field] = value if isinstance(value, datetime) else _parse_dt_maybe(value)
|
||||
elif "date" in tname:
|
||||
v = _parse_dt_maybe(value)
|
||||
out[field] = v.date() if isinstance(v, datetime) else v
|
||||
else:
|
||||
out[field] = value
|
||||
|
||||
return out
|
||||
Loading…
Add table
Add a link
Reference in a new issue