Compare commits

main..Redesign1

No commits in common. "main" and "Redesign1" have entirely different histories.

15 changed files with 671 additions and 2651 deletions

View file

@@ -1,135 +1,21 @@
# crudkit/api/_cursor.py
import base64, json
from typing import Any
from __future__ import annotations
import base64
import dataclasses
import datetime as _dt
import decimal as _dec
import hmac
import json
import typing as _t
from hashlib import sha256
Any = _t.Any
@dataclasses.dataclass(frozen=True)
class Cursor:
values: list[Any]
desc_flags: list[bool]
backward: bool
version: int = 1
def _json_default(o: Any) -> Any:
# Keep it boring and predictable.
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
return o.isoformat()
if isinstance(o, _dec.Decimal):
return str(o)
# UUIDs, Enums, etc.
if hasattr(o, "__str__"):
return str(o)
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
def _b64url_nopad_encode(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
def _b64url_nopad_decode(s: str) -> bytes:
# Restore padding for strict decode
pad = "=" * (-len(s) % 4)
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
def encode_cursor(
values: list[Any] | None,
desc_flags: list[bool],
backward: bool,
*,
secret: bytes | None = None,
) -> str | None:
"""
Create an opaque, optionally signed cursor token.
- values: keyset values for the last/first row
- desc_flags: per-key descending flags (same length/order as sort keys)
- backward: whether the window direction is backward
- secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig')
"""
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
if not values:
return None
payload = {"v": values, "d": desc_flags, "b": backward}
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
payload = {
"ver": 1,
"v": values,
"d": desc_flags,
"b": bool(backward),
}
body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8")
token = _b64url_nopad_encode(body)
if secret:
sig = hmac.new(secret, body, sha256).digest()
token = f"{token}.{_b64url_nopad_encode(sig)}"
return token
def decode_cursor(
token: str | None,
*,
secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
"""
Parse a cursor token. Returns (values, desc_flags, backward).
- Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case.
- If secret is provided, verifies HMAC and rejects tampered tokens.
- On any parse failure, returns (None, None, False).
"""
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
if not token:
return None, None, False
return None, False
try:
# Split payload.sig if signed
if "." in token:
body_b64, sig_b64 = token.split(".", 1)
body = _b64url_nopad_decode(body_b64)
if secret is None:
# Caller didn't ask for verification; still parse but don't trust.
pass
else:
expected = hmac.new(secret, body, sha256).digest()
actual = _b64url_nopad_decode(sig_b64)
if not hmac.compare_digest(expected, actual):
return None, None, False
else:
body = _b64url_nopad_decode(token)
obj = json.loads(body.decode("utf-8"))
# Versioning. If we ever change fields, branch here.
ver = int(obj.get("ver", 0))
if ver not in (0, 1):
return None, None, False
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
vals = obj.get("v")
backward = bool(obj.get("b", False))
# desc_flags may be absent in legacy payloads (ver 0)
desc = obj.get("d")
if not isinstance(vals, list):
return None, None, False
if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
# Ignore weird 'd' types rather than crashing
desc = None
return vals, desc, backward
if isinstance(vals, list):
return vals, backward
except Exception:
# Be tolerant on decode: treat as no-cursor.
return None, None, False
pass
return None, False
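# A minimal round-trip sketch of the signed cursor helpers above; the values
# and secret are illustrative, not part of the diff.
secret = b"rotate-me"
token = encode_cursor([42, "2024-05-01T00:00:00"], [False, True], backward=False, secret=secret)
# token is "<b64url payload>.<b64url hmac>"; unsigned tokens omit the ".sig" part.
vals, desc, backward = decode_cursor(token, secret=secret)
assert (vals, desc, backward) == ([42, "2024-05-01T00:00:00"], [False, True], False)
# Garbage or tampered tokens degrade to "no cursor" rather than raising:
assert decode_cursor("garbage", secret=secret) == (None, None, False)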

View file

@@ -1,195 +1,90 @@
# crudkit/api/flask_api.py
from __future__ import annotations
from flask import Blueprint, jsonify, request, abort, current_app, url_for
from hashlib import md5
from urllib.parse import urlencode
from werkzeug.exceptions import HTTPException
from flask import Blueprint, jsonify, request
from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy
MAX_JSON = 1_000_000
def generate_crud_blueprint(model, service):
bp = Blueprint(model.__name__.lower(), __name__)
def _etag_for(obj) -> str:
v = getattr(obj, "updated_at", None) or obj.id
return md5(str(v).encode()).hexdigest()
@bp.get('/')
def list_items():
args = request.args.to_dict(flat=True)
def _json_payload() -> dict:
if request.content_length and request.content_length > MAX_JSON:
abort(413)
if not request.is_json:
abort(415)
payload = request.get_json(silent=False)
if not isinstance(payload, dict):
abort(400)
return payload
# legacy detection
legacy_offset = "offset" in args or "page" in args
def _args_flat() -> dict[str, str]:
return request.args.to_dict(flat=True) # type: ignore[arg-type]
# sane limit default
try:
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit
def _json_error(e: Exception, status: int = 400):
if isinstance(e, HTTPException):
status = e.code or status
msg = e.description
else:
msg = str(e)
if current_app.debug:
return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status
return jsonify({"status": "error", "error": msg}), status
if legacy_offset:
# Old behavior: honor limit/offset, same CRUDSpec goodies
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
return _is_truthy(d.get(key, "1" if default else "0"))
# New behavior: keyset seek with cursors
key, backward = decode_cursor(args.get("cursor"))
def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True):
"""
REST:
GET /api/<models>/ -> list (filters via ?q=..., sort=..., limit=..., cursor=...)
GET /api/<models>/<id> -> get
POST /api/<models>/ -> create
PATCH /api/<models>/<id> -> update (partial)
DELETE /api/<models>/<id>[?hard=1] -> delete
window = service.seek_window(
args,
key=key,
backward=backward,
include_total=_is_truthy(args.get("include_total", "1")),
)
RPC (legacy):
GET /api/<model>/get?id=123
GET /api/<model>/list
GET /api/<model>/seek_window
GET /api/<model>/page
POST /api/<model>/create
PATCH /api/<model>/update?id=123
DELETE /api/<model>/delete?id=123[&hard=1]
"""
model_name = model.__name__.lower()
# bikeshed if you want pluralization; this is the least-annoying default
collection = (base_prefix or model_name).lower()
plural = collection if collection.endswith('s') else f"{collection}s"
desc_flags = list(window.order.desc)
body = {
"items": [obj.as_dict() for obj in window.items],
"limit": window.limit,
"next_cursor": encode_cursor(window.last_key, desc_flags, backward=False),
"prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True),
"total": window.total,
}
bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}")
resp = jsonify(body)
# Optional Link header
links = []
if body["next_cursor"]:
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
if body["prev_cursor"]:
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
if links:
resp.headers["Link"] = ", ".join(links)
return resp
@bp.errorhandler(Exception)
def _handle_any(e: Exception):
return _json_error(e)
@bp.get('/<int:id>')
def get_item(id):
item = service.get(id, request.args)
try:
return jsonify(item.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.errorhandler(404)
def _not_found(_e):
return jsonify({"status": "error", "error": "not found"}), 404
@bp.post('/')
def create_item():
obj = service.create(request.json)
try:
return jsonify(obj.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
# ---------- REST ----------
if rest:
@bp.get("/")
def rest_list():
args = _args_flat()
# support cursor pagination transparently; fall back to limit/offset
try:
items = service.list(args)
return jsonify([o.as_dict() for o in items])
except Exception as e:
return _json_error(e)
@bp.patch('/<int:id>')
def update_item(id):
obj = service.update(id, request.json)
try:
return jsonify(obj.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.get("/<int:obj_id>")
def rest_get(obj_id: int):
item = service.get(obj_id, request.args)
if item is None:
abort(404)
etag = _etag_for(item)
if request.if_none_match and (etag in request.if_none_match):
return "", 304
resp = jsonify(item.as_dict())
resp.set_etag(etag)
return resp
@bp.post("/")
def rest_create():
payload = _json_payload()
try:
obj = service.create(payload)
resp = jsonify(obj.as_dict())
resp.status_code = 201
resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False)
return resp
except Exception as e:
return _json_error(e)
@bp.patch("/<int:obj_id>")
def rest_update(obj_id: int):
payload = _json_payload()
try:
obj = service.update(obj_id, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
@bp.delete("/<int:obj_id>")
def rest_delete(obj_id: int):
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
try:
obj = service.delete(obj_id, hard=hard)
if obj is None:
abort(404)
return ("", 204)
except Exception as e:
return _json_error(e)
# ---------- RPC (your existing routes) ----------
if rpc:
@bp.get("/get")
def rpc_get():
print("⚠️ WARNING: Deprecated RPC call used: /get")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
try:
item = service.get(id_, request.args)
if item is None:
abort(404)
return jsonify(item.as_dict())
except Exception as e:
return _json_error(e)
@bp.get("/list")
def rpc_list():
print("⚠️ WARNING: Deprecated RPC call used: /list")
args = _args_flat()
try:
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
except Exception as e:
return _json_error(e)
@bp.post("/create")
def rpc_create():
print("⚠️ WARNING: Deprecated RPC call used: /create")
payload = _json_payload()
try:
obj = service.create(payload)
return jsonify(obj.as_dict()), 201
except Exception as e:
return _json_error(e)
@bp.patch("/update")
def rpc_update():
print("⚠️ WARNING: Deprecated RPC call used: /update")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
payload = _json_payload()
try:
obj = service.update(id_, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
@bp.delete("/delete")
def rpc_delete():
print("⚠️ WARNING: Deprecated RPC call used: /delete")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
try:
obj = service.delete(id_, hard=hard)
return ("", 204) if obj is not None else abort(404)
except Exception as e:
return _json_error(e)
@bp.delete('/<int:id>')
def delete_item(id):
service.delete(id)
try:
return jsonify({"status": "success"}), 204
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
return bp
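# A hypothetical wiring sketch; `Project` and `project_service` are assumed
# stand-ins for a mapped model and its CRUD service, not part of this module.
from flask import Flask

app = Flask(__name__)
app.register_blueprint(generate_crud_blueprint(Project, project_service))
# REST lives under the pluralized prefix (/api/projects/, /api/projects/<id>),
# while the legacy RPC routes (/api/projects/get, /list, ...) remain mounted.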

View file

@@ -74,32 +74,18 @@ def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page:
per_page = max(1, int(per_page))
offset = (page - 1) * per_page
if backend.requires_order_by_for_offset:
# Avoid private attribute if possible:
has_order = bool(getattr(sel, "_order_by_clauses", ())) # fallback for SA < 2.0.30
try:
has_order = has_order or bool(sel.get_order_by())
except Exception:
pass
if not has_order:
if default_order_by is not None:
sel = sel.order_by(default_order_by)
else:
# Try to find a primary key from the FROMs; fall back to a harmless literal.
try:
first_from = sel.get_final_froms()[0]
pk = next(iter(first_from.primary_key.columns))
sel = sel.order_by(pk)
except Exception:
sel = sel.order_by(text("1"))
if backend.requires_order_by_for_offset and not sel._order_by_clauses:
if default_order_by is None:
sel = sel.order_by(text("1"))
else:
sel = sel.order_by(default_order_by)
return sel.limit(per_page).offset(offset)
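# Worked example of the offset arithmetic above:
#   page=3, per_page=25  ->  offset = (3 - 1) * 25 = 50
#   i.e. SELECT ... ORDER BY <default_order_by or pk> LIMIT 25 OFFSET 50
# (the ORDER BY is injected only on backends where OFFSET requires it, e.g. MSSQL)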
@contextmanager
def maybe_identify_insert(session: Session, table, backend: BackendInfo):
"""
For MSSQL tables with IDENTITY PK when you need to insert explicit IDs.
For MSSQL tables with IDENTITY PK when you need to insert explicit IDs.
No-op elsewhere.
"""
if not backend.is_mssql:
@@ -107,7 +93,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo):
return
full_name = f"{table.schema}.{table.name}" if table.schema else table.name
session.execute(text(f"SET IDENTITY_INSERT {full_name} ON"))
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON"))
try:
yield
finally:
@@ -115,7 +101,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo):
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
"""
Build a safe large IN() filter respecting bind param limits.
Build a safe large IN() filter respecting bind param limits.
Returns a disjunction of chunked IN clauses if needed.
"""
vals = list(values)
@@ -134,12 +120,3 @@ def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optio
for p in parts[1:]:
expr = expr | p
return expr
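# A standalone sketch of the same chunking idea, minus the BackendInfo plumbing;
# the table, column, and 2000-parameter cap are illustrative assumptions.
from sqlalchemy import Column, Integer, MetaData, Table, or_

users = Table("users", MetaData(), Column("id", Integer, primary_key=True))
ids = list(range(5000))
CHUNK = 2000  # stay under the driver's bind-parameter limit
parts = [users.c.id.in_(ids[i:i + CHUNK]) for i in range(0, len(ids), CHUNK)]
expr = parts[0] if len(parts) == 1 else or_(*parts)
# expr renders as: users.id IN (...) OR users.id IN (...) OR users.id IN (...)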
def sql_trim(expr, backend: BackendInfo):
"""
Portable TRIM. SQL Server before compat level 140 lacks TRIM().
Emit LTRIM(RTRIM(...)) there; use TRIM elsewhere
"""
if backend.is_mssql:
return func.ltrim(func.rtrim(expr))
return func.trim(expr)

View file

@@ -187,8 +187,6 @@ class Config:
"synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"),
}
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
@classmethod
def engine_kwargs(cls) -> Dict[str, Any]:
url = cls.DATABASE_URL
@@ -223,18 +221,15 @@ class Config:
class DevConfig(Config):
DEBUG = True
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class TestConfig(Config):
TESTING = True
DATABASE_URL = build_database_url(backend="sqlite", database=":memory:")
SQLALCHEMY_ECHO = False
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class ProdConfig(Config):
DEBUG = False
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0")))
def get_config(name: str | None) -> Type[Config]:
"""

View file

@@ -1,9 +0,0 @@
# crudkit/core/__init__.py
from .utils import (
ISO_DT_FORMATS,
normalize_payload,
deep_diff,
diff_to_patch,
filter_to_columns,
to_jsonable,
)

View file

@@ -1,358 +1,47 @@
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func
from sqlalchemy.orm import declarative_mixin, declarative_base
Base = declarative_base()
@lru_cache(maxsize=512)
def _column_names_for_model(cls: type) -> tuple[str, ...]:
try:
mapper = inspect(cls)
return tuple(prop.key for prop in mapper.column_attrs)
except Exception:
names: list[str] = []
for c in cls.__mro__:
if hasattr(c, "__table__"):
names.extend(col.name for col in c.__table__.columns)
return tuple(dict.fromkeys(names))
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
"""Safely get SQLAlchemy InstanceState (or None)."""
try:
st = inspect(obj)
return cast(Optional[InstanceState[Any]], st)
except Exception:
return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
"""Safely get Mapper for a maooed instance (or None)."""
try:
st = inspect(obj)
mapper = getattr(st, "mapper", None)
return cast(Optional[Mapper], mapper)
except Exception:
return None
def _safe_get_loaded_attr(obj, name):
st = _sa_state(obj)
if st is None:
return None
try:
st_dict = getattr(st, "dict", {})
if name in st_dict:
return st_dict[name]
attrs = getattr(st, "attrs", None)
attr = None
if attrs is not None:
try:
attr = attrs[name]
except Exception:
try:
get = getattr(attrs, "get", None)
if callable(get):
attr = get(name)
except Exception:
attr = None
if attr is not None:
val = attr.loaded_value
return None if val is NO_VALUE else val
return None
except Exception:
return None
def _identity_key(obj) -> Tuple[type, Any]:
try:
st = inspect(obj)
return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
except Exception:
return (type(obj), id(obj))
def _is_collection_rel(prop: RelationshipProperty) -> bool:
try:
return prop.uselist is True
except Exception:
return False
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for name in _column_names_for_model(type(obj)):
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
"""
Serialize relationship 'name' already loaded on obj.
- If in 'embed' (or depth > 0 for depth-based walk), recurse.
- Else, return None (dont lazy-load).
"""
val = _safe_get_loaded_attr(obj, name)
if val is None:
return None
# Decide whether to recurse into this relationship
should_recurse = (depth > 0) or (name in embed)
if isinstance(val, list):
if not should_recurse:
# Emit a light list of child primary data (id + a couple columns) without recursion.
return [_serialize_simple_obj(child) for child in val]
out = []
for child in val:
ik = _identity_key(child)
if ik in seen: # cycle guard
out.append({"id": getattr(child, "id", None)})
continue
seen.add(ik)
out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
return out
# Scalar relationship
child = val
if not should_recurse:
return _serialize_simple_obj(child)
ik = _identity_key(child)
if ik in seen:
return {"id": getattr(child, "id", None)}
seen.add(ik)
return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
"""
Split requested fields into:
- scalars: ["label", "name"]
- collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]}
Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
Bare tokens ("foo") land in scalars.
"""
scalars: List[str] = []
groups: Dict[str, List[str]] = {}
for raw in fields:
f = str(raw).strip()
if not f:
continue
# bare token -> scalar
if "." not in f:
scalars.append(f)
continue
# dotted token -> group under root
root, tail = f.split(".", 1)
if not root or not tail:
continue
groups.setdefault(root, []).append(tail)
return scalars, groups
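# Concretely, the split described in the docstring behaves like this:
#   _split_field_tokens(["label", "owner.label", "updates.id", "updates.timestamp"])
#   -> (["label"], {"owner": ["label"], "updates": ["id", "timestamp"]})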
def _deep_get_loaded(obj: Any, dotted: str) -> Any:
"""
Deep get with no lazy loads:
- For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr).
- For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr()
to allow computed properties/hybrids that rely on already-loaded columns.
"""
parts = dotted.split(".")
if not parts:
return None
cur = obj
# Traverse up to the parent of the last token safely
for part in parts[:-1]:
if cur is None:
return None
cur = _safe_get_loaded_attr(cur, part)
if cur is None:
return None
last = parts[-1]
# Try safe fetch on the last hop first
val = _safe_get_loaded_attr(cur, last)
if val is not None:
return val
# Fall back to getattr for computed/hybrid attributes on an already-loaded object
try:
return getattr(cur, last, None)
except Exception:
return None
def _serialize_leaf(obj: Any) -> Any:
"""
Leaf serialization for values we put into as_dict():
- If object has as_dict(), call as_dict() with no args (caller controls field shapes).
- Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config).
"""
if obj is None:
return None
ad = getattr(obj, "as_dict", None)
if callable(ad):
try:
return ad(None)
except Exception:
return str(obj)
return obj
def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
"""
Turn a collection of ORM objects into list[dict] with exactly requested_tails,
where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load.
"""
out: List[Dict[str, Any]] = []
# Deduplicate while preserving order
uniq_tails = list(dict.fromkeys(requested_tails))
for child in (items or []):
row: Dict[str, Any] = {}
for tail in uniq_tails:
row[tail] = _deep_get_loaded(child, tail)
# ensure id present if exists and not already requested
try:
if "id" not in row and hasattr(child, "id"):
row["id"] = getattr(child, "id")
except Exception:
pass
out.append(row)
return out
@declarative_mixin
class CRUDMixin:
id = Column(Integer, primary_key=True)
created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
def as_tree(
self,
*,
embed_depth: int = 0,
embed: Iterable[str] | None = None,
_seen: Set[Tuple[type, Any]] | None = None,
) -> Dict[str, Any]:
"""
Recursive, NON-LAZY serializer.
- Always includes mapped columns.
- For relationships: only serializes those ALREADY LOADED.
- Recurses either up to embed_depth or for specific names in 'embed'.
- Keeps *_id columns alongside embedded objects.
- Cycle-safe via _seen.
"""
seen = _seen or set()
ik = _identity_key(self)
if ik in seen:
return {"id": getattr(self, "id", None)}
seen.add(ik)
data = _serialize_simple_obj(self)
# Determine which relationships to consider
try:
mapper = _sa_mapper(self)
embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
if mapper is None:
return data
st = _sa_state(self)
if st is None:
return data
for name, prop in mapper.relationships.items():
# Only touch relationships that are already loaded; never lazy-load here.
rel_loaded = getattr(st, "attrs", {}).get(name)
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
continue
data[name] = _serialize_loaded_rel(
self, name, depth=embed_depth, seen=seen, embed=embed_set
)
except Exception:
# If inspection fails, we just return columns.
pass
return data
def as_dict(self, fields: list[str] | None = None):
"""
Serialize the instance.
Behavior:
- If 'fields' (possibly dotted) is provided, emit exactly those keys.
* Bare tokens (e.g., "label", "owner") return the current loaded value.
* Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp")
produce a single "updates" key containing a list of dicts with the requested child keys.
* Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label".
- Else, if '__crudkit_projection__' is set on the instance, use that.
- Else, fall back to all mapped columns on this class hierarchy.
Always includes 'id' when present unless explicitly excluded (i.e., fields explicitly provided without id).
- If 'fields' (possibly dotted) is provided, emit exactly those keys.
- Else, if '__crudkit_projection__' is set on the instance, emit those keys.
- Else, fall back to all mapped columns on this class hierarchy.
Always includes 'id' when present unless explicitly excluded.
"""
req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)
if fields is None:
fields = getattr(self, "__crudkit_projection__", None)
if req:
# Normalize and split into (scalars, groups of dotted by root)
req_list = [p for p in (str(x).strip() for x in req) if p]
scalars, groups = _split_field_tokens(req_list)
out: Dict[str, Any] = {}
# Always include id unless the caller explicitly listed fields containing id
if "id" not in req_list and hasattr(self, "id"):
try:
out["id"] = getattr(self, "id")
except Exception:
pass
# Handle scalar tokens (may be columns, hybrids/properties, or relationships)
for name in scalars:
# Try loaded value first (never lazy-load)
val = _safe_get_loaded_attr(self, name)
# Final-hop getattr for root scalars (hybrids/@property) so they can compute.
if val is None:
try:
val = getattr(self, name)
except Exception:
val = None
# If it's a scalar ORM object (relationship), serialize its columns
mapper = _sa_mapper(val)
if mapper is not None:
out[name] = _serialize_simple_obj(val)
continue
# If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)):
out[name] = [_serialize_leaf(v) for v in val]
else:
out[name] = val
# Handle dotted groups: root -> [tails]
for root, tails in groups.items():
root_val = _safe_get_loaded_attr(self, root)
if isinstance(root_val, (list, tuple)):
# one-to-many collection → list of dicts with the requested tails
out[root] = _serialize_collection(root_val, tails)
else:
# many-to-one or scalar dotted; place each full dotted path as key
for tail in tails:
dotted = f"{root}.{tail}"
out[dotted] = _deep_get_loaded(self, dotted)
# ← This was the placeholder before. We return the dict we just built.
if fields:
out = {}
if "id" not in fields and hasattr(self, "id"):
out["id"] = getattr(self, "id")
for f in fields:
cur = self
for part in f.split("."):
if cur is None:
break
cur = getattr(cur, part, None)
out[f] = cur
return out
# Fallback: all mapped columns on this class hierarchy
result: Dict[str, Any] = {}
result = {}
for cls in self.__class__.__mro__:
if hasattr(cls, "__table__"):
for column in cls.__table__.columns:
name = column.name
try:
result[name] = getattr(self, name)
except Exception:
result[name] = None
result[name] = getattr(self, name)
return result
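# A minimal sketch of the fields-driven shapes described in the as_dict
# docstring; the Owner/Ticket models are assumed for illustration only.
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship

class Owner(CRUDMixin, Base):
    __tablename__ = "owners"
    label = Column(String)

class Ticket(CRUDMixin, Base):
    __tablename__ = "tickets"
    label = Column(String)
    owner_id = Column(Integer, ForeignKey("owners.id"))
    owner = relationship(Owner)

t = Ticket(id=1, label="Fix login", owner=Owner(id=7, label="Ada"))
t.as_dict(["label"])                 # -> {"id": 1, "label": "Fix login"}
t.as_dict(["label", "owner.label"])  # -> {..., "owner.label": "Ada"}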
class Version(Base):
__tablename__ = "versions"

File diff suppressed because it is too large.

View file

@@ -1,17 +1,9 @@
from dataclasses import dataclass
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import and_, asc, desc, or_
from typing import List, Tuple, Set, Dict, Optional
from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
@dataclass(frozen=True)
class CollPred:
table: Any
col_key: str
op: str
value: Any
OPERATORS = {
'eq': lambda col, val: col == val,
'lt': lambda col, val: col < val,
@@ -20,8 +12,6 @@ OPERATORS = {
'gte': lambda col, val: col >= val,
'ne': lambda col, val: col != val,
'icontains': lambda col, val: col.ilike(f"%{val}%"),
'in': lambda col, val: col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
'nin': lambda col, val: ~col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
}
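# For instance, a query-string key like label__icontains=foo resolves through
# this map to an ILIKE clause; the table and column here are illustrative.
from sqlalchemy import Column, MetaData, String, Table

items = Table("items", MetaData(), Column("label", String))
OPERATORS["icontains"](items.c.label, "foo")  # lower(label) LIKE lower('%foo%')
OPERATORS["in"](items.c.label, ["a", "b"])    # label IN ('a', 'b')
OPERATORS["in"](items.c.label, "a")           # scalar coerced to label IN ('a')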
class CRUDSpec:
@@ -30,138 +20,12 @@ class CRUDSpec:
self.params = params
self.root_alias = root_alias
self.eager_paths: Set[Tuple[str, ...]] = set()
# (parent_alias, relationship_attr, alias_for_target)
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = []
# dotted non-collection fields (MANYTOONE etc)
self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
# dotted collection fields (ONETOMANY)
self._collection_field_names: Dict[str, List[str]] = {}
self._rel_field_names: Dict[Tuple[str, ...], object] = {}
self.include_paths: Set[Tuple[str, ...]] = set()
def _split_path_and_op(self, key: str) -> tuple[str, str]:
if '__' in key:
path, op = key.rsplit('__', 1)
else:
path, op = key, 'eq'
return path, op
def _resolve_many_columns(self, path: str) -> list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]]:
"""
Accepts pipe-delimited paths like 'label|owner.label'
Returns a list of (column, join_path) pairs for every resolvable subpath.
"""
cols: list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]] = []
for sub in path.split('|'):
sub = sub.strip()
if not sub:
continue
col, join_path = self._resolve_column(sub)
if col is not None:
cols.append((col, join_path))
return cols
def _build_predicate_for(self, path: str, op: str, value: Any):
"""
Builds a SQLA BooleanClauseList or BinaryExpression for a single key.
If multiple subpaths are provided via pipe, returns an OR of them.
"""
if op not in OPERATORS:
return None
pairs = self._resolve_many_columns(path)
if not pairs:
return None
exprs = []
for col, join_path in pairs:
if join_path:
self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
is_collection = False
for nm in names:
rel_attr = getattr(cur_cls, nm)
prop = rel_attr.property
cur_cls = prop.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None)
and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
is_collection = False
if is_collection:
target_cls = cur_cls
key = getattr(col, "key", None) or getattr(col, "name", None)
if key and hasattr(target_cls, key):
target_tbl = getattr(target_cls, "__table__", None)
if target_tbl is not None:
exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
continue
exprs.append(OPERATORS[op](col, value))
if not exprs:
return None
# If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
if any(isinstance(x, CollPred) for x in exprs):
# If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds,
# keep the first or raise for now (your choice).
if len(exprs) > 1:
# raise NotImplementedError("OR across collection paths not supported yet")
exprs = [next(x for x in exprs if isinstance(x, CollPred))]
return exprs[0]
# Otherwise, standard SQLA clause(s)
return exprs[0] if len(exprs) == 1 else or_(*exprs)
def _collect_filters(self, params: dict) -> list:
"""
Recursively parse filters from 'param' into a flat list of SQLA expressions.
Supports $or / $and groups. Any other keys are parsed as normal filters.
"""
filters: list = []
for key, value in (params or {}).items():
if key in ('sort', 'limit', 'offset', 'fields', 'include'):
continue
if key == '$or':
# value should be a list of dicts
groups = []
for group in value if isinstance(value, (list, tuple)) else []:
sub = self._collect_filters(group)
if not sub:
continue
groups.append(and_(*sub) if len(sub) > 1 else sub[0])
if groups:
filters.append(or_(*groups))
continue
if key == '$and':
# value should be a list of dicts
parts = []
for group in value if isinstance(value, (list, tuple)) else []:
sub = self._collect_filters(group)
if not sub:
continue
parts.append(and_(*sub) if len(sub) > 1 else sub[0])
if parts:
filters.append(and_(*parts))
continue
# Normal key
path, op = self._split_path_and_op(key)
pred = self._build_predicate_for(path, op, value)
if pred is not None:
filters.append(pred)
return filters
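# Example input (shape assumed from the parsing above): plain keys AND
# together, while $or groups OR their sub-filters.
#   params = {
#       "status": "open",
#       "$or": [{"priority__gte": 3}, {"owner.label__icontains": "ada"}],
#   }
#   _collect_filters(params)
#   -> [status == 'open', or_(priority >= 3, owner.label ILIKE '%ada%')]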
def _resolve_column(self, path: str):
current_alias = self.root_alias
parts = path.split('.')
@@ -204,12 +68,24 @@ class CRUDSpec:
if maybe:
self.eager_paths.add(maybe)
def parse_filters(self, params: dict | None = None):
"""
Public entry: parse filters from given params or self.params.
Returns a list of SQLAlchemy filter expressions
"""
return self._collect_filters(params if params is not None else self.params)
def parse_filters(self):
filters = []
for key, value in self.params.items():
if key in ('sort', 'limit', 'offset'):
continue
if '__' in key:
path_op = key.rsplit('__', 1)
if len(path_op) != 2:
continue
path, op = path_op
else:
path, op = key, 'eq'
col, join_path = self._resolve_column(path)
if col and op in OPERATORS:
filters.append(OPERATORS[op](col, value))
if join_path:
self.eager_paths.add(join_path)
return filters
def parse_sort(self):
sort_args = self.params.get('sort', '')
@@ -241,12 +117,11 @@ class CRUDSpec:
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
- Collection (uselist=True) relationships record child names by relationship key.
Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel).
Returns (root_fields, rel_field_names).
"""
raw = self.params.get('fields')
if not raw:
return [], {}, {}, {}
return [], {}, {}
if isinstance(raw, list):
tokens = []
@@ -258,36 +133,14 @@ class CRUDSpec:
root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
collection_field_names: Dict[str, List[str]] = {}
for token in tokens:
col, join_path = self._resolve_column(token)
if not col:
continue
if join_path:
# rel_field_names.setdefault(join_path, []).append(col.key)
# self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
for nm in names:
rel_attr = getattr(cur_cls, nm)
cur_cls = rel_attr.property.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
# Fallback: inspect the InstrumentedAttribute we recorded on join_paths
is_collection = False
for _pa, rel_attr, _al in self.join_paths:
if rel_attr.key == (join_path[-1] if join_path else ""):
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
break
if is_collection:
collection_field_names.setdefault(join_path[-1], []).append(col.key)
else:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else:
root_fields.append(col)
root_field_names.append(getattr(col, "key", token))
@@ -300,11 +153,7 @@ class CRUDSpec:
self._root_fields = root_fields
self._rel_field_names = rel_field_names
# return root_fields, rel_field_names, root_field_names
for r, names in collection_field_names.items():
seen3 = set()
collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
return root_fields, rel_field_names, root_field_names, collection_field_names
return root_fields, rel_field_names, root_field_names
def get_eager_loads(self, root_alias, *, fields_map=None):
loads = []

View file

@@ -1,176 +0,0 @@
from __future__ import annotations
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, Optional, Callable
from sqlalchemy import inspect
ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%d")
def to_jsonable(obj: Any):
"""Recursively convert values into JSON-serializable forms."""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (datetime, date)):
return obj.isoformat()
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, dict):
return {str(k): to_jsonable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple, set)):
return [to_jsonable(v) for v in obj]
# fallback: stringify weird objects (UUID, ORM instances, etc.)
try:
return str(obj)
except Exception:
return None
def filter_to_columns(data: dict, model_cls):
cols = {c.key for c in inspect(model_cls).mapper.columns}
return {k: v for k, v in data.items() if k in cols}
def _parse_dt_maybe(x: Any) -> Any:
if isinstance(x, (datetime, date)):
return x
if isinstance(x, str):
s = x.strip().replace("Z", "+00:00") # tolerate Zulu
for fmt in ISO_DT_FORMATS:
try:
return datetime.strptime(s, fmt)
except ValueError:
pass
try:
return datetime.fromisoformat(s)
except Exception:
return x
return x
def _normalize_for_compare(x: Any) -> Any:
if isinstance(x, (str, datetime, date)):
return _parse_dt_maybe(x)
return x
def deep_diff(
old: Any,
new: Any,
*,
path: str = "",
ignore_keys: Optional[set] = None,
list_mode: str = "index", # "index" or "set"
custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
if ignore_keys is None:
ignore_keys = set()
out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}
def mark_changed(p, a, b):
out["changed"][p] = {"from": a, "to": b}
def rec(o, n, pfx):
if custom_equal and custom_equal(pfx.rstrip("."), o, n):
return
if isinstance(o, dict) and isinstance(n, dict):
o_keys = set(o.keys())
n_keys = set(n.keys())
for k in sorted(o_keys - n_keys):
if k not in ignore_keys:
out["removed"][f"{pfx}{k}"] = o[k]
for k in sorted(n_keys - o_keys):
if k not in ignore_keys:
out["added"][f"{pfx}{k}"] = n[k]
for k in sorted(o_keys & n_keys):
if k not in ignore_keys:
rec(o[k], n[k], f"{pfx}{k}.")
return
if isinstance(o, list) and isinstance(n, list):
if list_mode == "set":
if set(o) != set(n):
mark_changed(pfx.rstrip("."), o, n)
else:
max_len = max(len(o), len(n))
for i in range(max_len):
key = f"{pfx}[{i}]"
if i >= len(o):
out["added"][key] = n[i]
elif i >= len(n):
out["removed"][key] = o[i]
else:
rec(o[i], n[i], f"{key}.")
return
a = _normalize_for_compare(o)
b = _normalize_for_compare(n)
if a != b:
mark_changed(pfx.rstrip("."), o, n)
rec(old, new, path)
return out
def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
Produce a shallow patch of changed/added top-level fields.
Only includes leaf paths without dots/brackets; useful for simple UPDATEs.
"""
patch: Dict[str, Any] = {}
for k, v in diff["added"].items():
if "." not in k and "[" not in k:
patch[k] = v
for k, v in diff["changed"].items():
if "." not in k and "[" not in k:
patch[k] = v["to"]
return patch
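# A quick round trip through deep_diff and diff_to_patch:
old = {"a": 1, "b": [1, 2], "note": "x"}
new = {"a": 2, "b": [1, 2, 3]}
d = deep_diff(old, new)
# d == {"added": {"b.[2]": 3}, "removed": {"note": "x"},
#       "changed": {"a": {"from": 1, "to": 2}}}
diff_to_patch(d)
# -> {"a": 2}  (dotted/bracketed paths are skipped; only top-level leaves survive)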
def normalize_payload(payload: dict, model):
"""
Coerce incoming JSON into SQLAlchemy column types for the given model.
- "" or None -> None
- Integer/Boolean/Date/DateTime handled by column type
"""
from sqlalchemy import Integer, Boolean, DateTime, Date
out: Dict[str, Any] = {}
mapper = inspect(model).mapper
cols = {c.key: c.type for c in mapper.columns}
for field, value in payload.items():
if value == "" or value is None:
out[field] = None
continue
coltype = cols.get(field)
if coltype is None:
out[field] = value
continue
tname = coltype.__class__.__name__.lower()
if "integer" in tname:
out[field] = int(value)
elif "boolean" in tname:
out[field] = value if isinstance(value, bool) else str(value).lower() in ("1", "true", "yes", "on")
elif "datetime" in tname:
out[field] = value if isinstance(value, datetime) else _parse_dt_maybe(value)
elif "date" in tname:
v = _parse_dt_maybe(value)
out[field] = v.date() if isinstance(v, datetime) else v
else:
out[field] = value
return out
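# Minimal sketch of the coercions above; the Task model is assumed.
from sqlalchemy import Boolean, Column, DateTime, Integer
from sqlalchemy.orm import declarative_base

_DemoBase = declarative_base()

class Task(_DemoBase):
    __tablename__ = "tasks"
    id = Column(Integer, primary_key=True)
    done = Column(Boolean)
    due = Column(DateTime)

normalize_payload({"id": "7", "done": "yes", "due": "2024-05-01T12:00:00", "note": ""}, Task)
# -> {"id": 7, "done": True, "due": datetime(2024, 5, 1, 12, 0), "note": None}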

View file

@@ -1,8 +1,7 @@
# engines.py
from __future__ import annotations
from typing import Type, Optional
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, raiseload, Mapper, RelationshipProperty
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .backend import make_backend_info, BackendInfo
from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas
@@ -13,31 +12,15 @@ def build_engine(config_cls: Type[Config] | None = None):
apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS)
return engine
def _install_nplus1_guards(SessionMaker, *, strict: bool):
if not strict:
return
@event.listens_for(SessionMaker, "do_orm_execute")
def _add_global_raiseload(execute_state):
stmt = execute_state.statement
# Only touch ORM statements (have column_descriptions)
if getattr(stmt, "column_descriptions", None):
execute_state.statement = stmt.options(raiseload("*"))
def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None):
config_cls = config_cls or get_config(None)
engine = engine or build_engine(config_cls)
SessionMaker = sessionmaker(bind=engine, **config_cls.session_kwargs())
# Toggle with a config flag; default off so you can turn it on when ready
strict = bool(getattr(config_cls, "STRICT_NPLUS1", False))
_install_nplus1_guards(SessionMaker, strict=strict)
return SessionMaker
return sessionmaker(bind=engine, **config_cls.session_kwargs())
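# Effect of the guard, sketched (SomeModel is hypothetical): with
# CRUDKIT_STRICT_NPLUS1=1 every ORM SELECT gets raiseload("*") attached, so
# touching an unloaded relationship raises instead of silently lazy-loading.
#
#   SessionMaker = build_sessionmaker()
#   with SessionMaker() as s:
#       obj = s.get(SomeModel, 1)
#       obj.children   # sqlalchemy.exc.InvalidRequestError, not a hidden SELECT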
class CRUDKitRuntime:
"""
Lightweight container so CRUDKit can be given either:
- prebuilt engine/sessionmaker, or
- prebuilt engine/sessionmaker, or
- a Config to build them lazily
"""
def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None):

View file

@@ -1,32 +1,20 @@
# crudkit/integrations/flask.py
from __future__ import annotations
from flask import Flask
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm import scoped_session
from ..engines import CRUDKitRuntime
from ..config import Config
def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None = None):
def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None = None):
"""
Initializes CRUDKit for a Flask app. Provides `app.extensions['crudkit']`
Initializes CRUDKit for a Flask app. Provides `app.extensions['crudkit']`
with a runtime (engine + session_factory). Caller manages session lifecycle.
"""
runtime = runtime or CRUDKitRuntime(config=config)
app.extensions.setdefault("crudkit", {})
app.extensions["crudkit"]["runtime"] = runtime
# Build ONE sessionmaker bound to the ONE true engine object
# so engine id == sessionmaker.bind id, always.
engine = runtime.engine
SessionFactory = runtime.session_factory or sessionmaker(bind=engine, **runtime._config.session_kwargs())
app.extensions["crudkit"]["SessionFactory"] = SessionFactory
app.extensions["crudkit"]["Session"] = scoped_session(SessionFactory)
# Attach pool listeners to the *same* engine the SessionFactory is bound to.
# Don't guess. Don't hope. Inspect.
try:
bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine
pool = bound_engine.pool
except Exception as e:
print(f"[crudkit.init_app] Failed to attach pool listeners: {e}")
Session = runtime.session_factory
if Session is not None:
app.extensions["crudkit"]["Session"] = scoped_session(Session)
return runtime
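# Typical wiring per the docstring above; session teardown stays with the caller.
from flask import Flask

app = Flask(__name__)
runtime = init_app(app)  # builds a default CRUDKitRuntime when none is passed
Session = app.extensions["crudkit"]["Session"]  # scoped_session (when a factory exists)
# e.g. register Session.remove in an app teardown hook.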

View file

@@ -1,236 +0,0 @@
# crudkit/projection.py
from __future__ import annotations
from typing import Iterable, List, Tuple, Dict, Set
from sqlalchemy.orm import selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
from sqlalchemy import inspect
# ----------------------
# small utilities
# ----------------------
def _is_column_attr(a) -> bool:
try:
return isinstance(a, InstrumentedAttribute) and isinstance(a.property, ColumnProperty)
except Exception:
return False
def _is_relationship_attr(a) -> bool:
try:
return isinstance(a, InstrumentedAttribute) and isinstance(a.property, RelationshipProperty)
except Exception:
return False
def _split_path(field: str) -> List[str]:
return [p for p in str(field).split(".") if p]
def _model_requires_map(model_cls) -> Dict[str, List[str]]:
# apps declare per-model deps, e.g. {"label": ["first_name","last_name","title"]}
return getattr(model_cls, "__crudkit_field_requires__", {}) or {}
def _relationships_of(model_cls) -> Dict[str, RelationshipProperty]:
try:
return dict(model_cls.__mapper__.relationships)
except Exception:
return {}
def _attr_on(model_cls, name: str):
return getattr(model_cls, name, None)
# ----------------------
# EXPAND: add required deps for leaf attributes at the correct class
# ----------------------
def _expand_requires_for_field(model_cls, pieces: List[str]) -> List[str]:
"""
Given a dotted path like ["owner","label"], walk relationships to the leaf *container* class,
pull its __crudkit_field_requires__ for that leaf attr ("label"), and yield prefixed deps:
owner.label -> ["owner.first_name", "owner.last_name", ...] if User requires so.
If leaf is a column (or has no requires), returns [].
"""
if not pieces:
return []
# walk relationships to the leaf container (class that owns the leaf attr)
container_cls = model_cls
prefix_parts: List[str] = []
for part in pieces[:-1]:
a = _attr_on(container_cls, part)
if not _is_relationship_attr(a):
return [] # can't descend; invalid or scalar in the middle
container_cls = a.property.mapper.class_
prefix_parts.append(part)
leaf = pieces[-1]
requires = _model_requires_map(container_cls).get(leaf) or []
if not requires:
return []
prefix = ".".join(prefix_parts)
out: List[str] = []
for dep in requires:
# dep may itself be dotted relative to container (e.g. "room_function.description")
if prefix:
out.append(f"{prefix}.{dep}")
else:
out.append(dep)
return out
def _expand_requires(model_cls, fields: Iterable[str]) -> List[str]:
"""
Dedup + stable expansion of requires for all fields.
"""
seen: Set[str] = set()
out: List[str] = []
def add(f: str):
if f not in seen:
seen.add(f)
out.append(f)
# first pass: add original
queue: List[str] = []
for f in fields:
f = str(f)
if f not in seen:
seen.add(f)
out.append(f)
queue.append(f)
# BFS-ish: when we add deps, they may trigger further deps downstream
while queue:
f = queue.pop(0)
deps = _expand_requires_for_field(model_cls, _split_path(f))
for d in deps:
if d not in seen:
seen.add(d)
out.append(d)
queue.append(d)
return out
# ----------------------
# BUILD loader options tree with selectinload + load_only on real columns
# ----------------------
def _insert_leaf(loader_tree: dict, path: List[str]):
"""
Build nested dict structure keyed by relationship names.
Each node holds:
{
"__cols__": set(column_names_to_load_only),
"<child_rel>": { ... }
}
"""
node = loader_tree
for rel in path[:-1]: # only relationship hops
node = node.setdefault(rel, {"__cols__": set()})
# leaf may be a column or a virtual/hybrid; only columns go to __cols__
node.setdefault("__cols__", set())
def _attach_column(loader_tree: dict, path: List[str], model_cls):
"""
If the leaf is a real column on the target class, record its name into __cols__ at that level.
"""
# descend to target class to test column-ness
container_cls = model_cls
node = loader_tree
for rel in path[:-1]:
a = _attr_on(container_cls, rel)
if not _is_relationship_attr(a):
return # invalid path, ignore
container_cls = a.property.mapper.class_
node = node.setdefault(rel, {"__cols__": set()})
leaf = path[-1]
a_leaf = _attr_on(container_cls, leaf)
node.setdefault("__cols__", set())
if _is_column_attr(a_leaf):
node["__cols__"].add(leaf)
def _build_loader_tree(model_cls, fields: Iterable[str]) -> dict:
"""
For each dotted field:
- walk relationships -> create nodes
- if leaf is a column: record it for load_only
- if leaf is not a column (hybrid/descriptor): no load_only; still ensure rel hops exist
"""
tree: Dict[str, dict] = {"__cols__": set()}
for f in fields:
parts = _split_path(f)
if not parts:
continue
# ensure relationship nodes exist
_insert_leaf(tree, parts)
# attach column if applicable
_attach_column(tree, parts, model_cls)
return tree
def _loader_options_from_tree(model_cls, tree: dict):
"""
Convert the loader tree into SQLAlchemy loader options:
selectinload(<rel>)[.load_only(cols)] recursively
"""
opts = []
rels = _relationships_of(model_cls)
for rel_name, child in tree.items():
if rel_name == "__cols__":
continue
rel_prop = rels.get(rel_name)
if not rel_prop:
continue
rel_attr = getattr(model_cls, rel_name)
opt = selectinload(rel_attr)
# apply load_only on the related class (only real columns recorded at child["__cols__"])
cols = list(child.get("__cols__", []))
if cols:
rel_model = rel_prop.mapper.class_
# map column names to attributes
col_attrs = []
for c in cols:
a = getattr(rel_model, c, None)
if _is_column_attr(a):
col_attrs.append(a)
if col_attrs:
opt = opt.load_only(*col_attrs)
# recurse to grandchildren
sub_opts = _loader_options_from_tree(rel_prop.mapper.class_, child)
for so in sub_opts:
opt = opt.options(so)
opts.append(opt)
# root-level columns (rare in our compile; kept for completeness)
root_cols = list(tree.get("__cols__", []))
if root_cols:
# NOTE: call-site can add a root load_only(...) if desired;
# we purposely return only relationship options here to keep
# the API simple and avoid mixing Load(model_cls) contexts.
pass
return opts
# ----------------------
# PUBLIC API
# ----------------------
def compile_projection(model_cls, fields: Iterable[str]) -> Tuple[List[str], List]:
"""
Returns:
expanded_fields: List[str] # original + declared dependencies
loader_options: List[Load] # apply via query = query.options(*loader_options)
Behavior:
- Expands __crudkit_field_requires__ at the leaf container class for every field.
- Builds a selectinload tree; load_only only includes real columns (no hybrids).
- Safe for nested paths: e.g. "owner.label" pulls owner deps from User.__crudkit_field_requires__.
"""
fields = list(fields or [])
expanded = _expand_requires(model_cls, fields)
tree = _build_loader_tree(model_cls, expanded)
options = _loader_options_from_tree(model_cls, tree)
return expanded, options
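# An assumed two-model setup showing the dependency expansion; User/Room and
# the __crudkit_field_requires__ values here are illustrative.
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    first_name = Column(String)
    last_name = Column(String)
    __crudkit_field_requires__ = {"label": ["first_name", "last_name"]}

    @property
    def label(self):
        return f"{self.first_name} {self.last_name}"

class Room(Base):
    __tablename__ = "rooms"
    id = Column(Integer, primary_key=True)
    owner_id = Column(Integer, ForeignKey("users.id"))
    owner = relationship(User)

expanded, options = compile_projection(Room, ["owner.label"])
# expanded == ["owner.label", "owner.first_name", "owner.last_name"]
# options  == [selectinload(Room.owner).load_only(User.first_name, User.last_name)]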

View file

@@ -6,279 +6,25 @@ from flask import current_app, url_for
from jinja2 import Environment, FileSystemLoader, ChoiceLoader
from sqlalchemy import inspect
from sqlalchemy.orm import Load, RelationshipProperty, class_mapper, load_only, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
from typing import Any, Dict, List, Optional, Tuple
import crudkit
_ALLOWED_ATTRS = {
"class", "placeholder", "autocomplete", "inputmode", "pattern",
"min", "max", "step", "maxlength", "minlength",
"required", "readonly", "disabled",
"multiple", "size", "rows",
"multiple", "size",
"id", "name", "value",
}
_SAFE_CSS_PROPS = {
# spacing / sizing
"margin","margin-top","margin-right","margin-bottom","margin-left",
"padding","padding-top","padding-right","padding-bottom","padding-left",
"width","height","min-width","min-height","max-width","max-height", "resize",
# layout
"display","flex","flex-direction","flex-wrap","justify-content","align-items","gap",
# text
"font-size","font-weight","line-height","text-align","white-space",
# colors / background
"color","background-color",
# borders / radius
"border","border-top","border-right","border-bottom","border-left",
"border-width","border-style","border-color","border-radius",
# misc (safe-ish)
"opacity","overflow","overflow-x","overflow-y",
}
_num_unit = r"-?\d+(?:\.\d+)?"
_len_unit = r"(?:px|em|rem|%)"
P_LEN = re.compile(rf"^{_num_unit}(?:{_len_unit})?$") # 12, 12px, 1.2rem, 50%
P_GAP = P_LEN
P_INT = re.compile(r"^\d+$")
P_COLOR = re.compile(
r"^(#[0-9a-fA-F]{3,8}|"
r"rgb\(\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*\d{1,3}\s*\)|"
r"rgba\(\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*(?:0|1|0?\.\d+)\s*\)|"
r"[a-zA-Z]+)$"
)
_ENUMS = {
"display": {"block","inline","inline-block","flex","grid","none"},
"flex-direction": {"row","row-reverse","column","column-reverse"},
"flex-wrap": {"nowrap","wrap","wrap-reverse"},
"justify-content": {"flex-start","flex-end","center","space-between","space-around","space-evenly"},
"align-items": {"stretch","flex-start","flex-end","center","baseline"},
"text-align": {"left","right","center","justify","start","end"},
"white-space": {"normal","nowrap","pre","pre-wrap","pre-line","break-spaces"},
"border-style": {"none","solid","dashed","dotted","double","groove","ridge","inset","outset"},
"overflow": {"visible","hidden","scroll","auto","clip"},
"overflow-x": {"visible","hidden","scroll","auto","clip"},
"overflow-y": {"visible","hidden","scroll","auto","clip"},
"font-weight": {"normal","bold","bolder","lighter","100","200","300","400","500","600","700","800","900"},
"resize": {"none", "both", "horizontal", "vertical"},
}
def get_env():
"""
Return an overlay Jinja Environment that knows how to load crudkit templates
and has our helper functions available as globals.
"""
app = current_app
default_path = os.path.join(os.path.dirname(__file__), 'templates')
fallback_loader = FileSystemLoader(default_path)
env = app.jinja_env.overlay(loader=ChoiceLoader([app.jinja_loader, fallback_loader]))
# Ensure helpers are available even when we render via this overlay env.
# These names are resolved at *call time* (not at def time), so it's safe.
try:
env.globals.setdefault("render_table", render_table)
env.globals.setdefault("render_form", render_form)
env.globals.setdefault("render_field", render_field)
except NameError:
# Functions may not be defined yet at import time; later calls will set them.
pass
return env
def register_template_globals(app=None):
"""
Register crudkit helpers as app-wide Jinja globals so they can be used
directly in any template via {{ render_table(...) }}, {{ render_form(...) }},
and {{ render_field(...) }}.
"""
if app is None:
app = current_app
# Idempotent install using an extension flag
installed = app.extensions.setdefault("crudkit_ui_helpers", set())
to_register = {
"render_table": render_table,
"render_form": render_form,
"render_field": render_field,
}
for name, fn in to_register.items():
if name not in installed:
app.add_template_global(fn, name)
installed.add(name)
def _fields_for_label_params(label_spec, related_model):
"""
Build a 'fields' list suitable for CRUDService.list() so labels render
without triggering lazy loads. Always includes 'id'.
"""
simple_cols, rel_paths = _extract_label_requirements(label_spec, related_model)
fields = set(["id"])
for c in simple_cols:
fields.add(c)
for rel_name, col_name in rel_paths:
if col_name == "__all__":
# just ensure relationship object is present; ask for rel.id
fields.add(f"{rel_name}.id")
else:
fields.add(f"{rel_name}.{col_name}")
return list(fields)
def _fk_options_via_service(related_model, label_spec, *, options_params: dict | None = None):
svc = crudkit.crud.get_service(related_model)
# default to unlimited results for dropdowns
params = {"limit": 0}
if options_params:
params.update(options_params) # caller can override limit if needed
# ensure fields needed to render the label are present (avoid lazy loads)
fields = _fields_for_label_params(label_spec, related_model)
if fields:
existing = params.get("fields")
if isinstance(existing, str):
existing = [s.strip() for s in existing.split(",") if s.strip()]
if isinstance(existing, (list, tuple)):
params["fields"] = list(dict.fromkeys(list(existing) + fields))
else:
params["fields"] = fields
# only set a default sort if caller didn't supply one
if "sort" not in params:
simple_cols, _ = _extract_label_requirements(label_spec, related_model)
params["sort"] = (simple_cols[0] if simple_cols else "id")
rows = svc.list(params)
return [
{"value": str(r.id), "label": _label_from_obj(r, label_spec)}
for r in rows
]
def expand_projection(model_cls, fields):
req = getattr(model_cls, "__crudkit_field_requires__", {}) or {}
out = set(fields)
for f in list(fields):
for dep in req.get(f, ()):
out.add(dep)
return list(out)
def _clean_css_value(prop: str, raw: str) -> str | None:
v = raw.strip()
v = v.replace("!important", "")
low = v.lower()
if any(bad in low for bad in ("url(", "expression(", "javascript:", "var(")):
return None
if prop in {"width","height","min-width","min-height","max-width","max-height",
"margin","margin-top","margin-right","margin-bottom","margin-left",
"padding","padding-top","padding-right","padding-bottom","padding-left",
"border-width","border-top","border-right","border-bottom","border-left","border-radius",
"line-height","font-size"}:
return v if P_LEN.match(v) else None
if prop in {"gap"}:
parts = [p.strip() for p in v.split()]
if 1 <= len(parts) <= 2 and all(P_GAP.match(p) for p in parts):
return " ".join(parts)
return None
if prop in {"color", "background-color", "border-color"}:
return v if P_COLOR.match(v) else None
if prop in _ENUMS:
return v if v.lower() in _ENUMS[prop] else None
if prop == "flex":
toks = v.split()
if len(toks) == 1 and (toks[0].isdigit() or toks[0] in {"auto", "none"}):
return v
if len(toks) == 2 and toks[0].isdigit() and (toks[1].isdigit() or toks[1] == "auto"):
return v
if len(toks) == 3 and toks[0].isdigit() and toks[1].isdigit() and (P_LEN.match(toks[2]) or toks[2] == "auto"):
return " ".join(toks)
return None
if prop == "border":
parts = v.split()
bw = next((p for p in parts if P_LEN.match(p)), None)
bs = next((p for p in parts if p in _ENUMS["border-style"]), None)
bc = next((p for p in parts if P_COLOR.match(p)), None)
chosen = [x for x in (bw, bs, bc) if x]
return " ".join(chosen) if chosen else None
return None
def _sanitize_style(style: str | None) -> str | None:
if not style or not isinstance(style, str):
return None
safe_decls = []
for chunk in style.split(";"):
if not chunk.strip():
continue
if ":" not in chunk:
continue
prop, val = chunk.split(":", 1)
prop = prop.strip().lower()
if prop not in _SAFE_CSS_PROPS:
continue
clean = _clean_css_value(prop, val)
if clean is not None and clean != "":
safe_decls.append(f"{prop}: {clean}")
return "; ".join(safe_decls) if safe_decls else None
def _is_column_attr(attr) -> bool:
try:
return isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, ColumnProperty)
except Exception:
return False
def _is_relationship_attr(attr) -> bool:
try:
return isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, RelationshipProperty)
except Exception:
return False
def _get_attr_deps(model_cls, attr_name: str, extra_deps: Optional[dict] = None) -> list[str]:
"""Merge model-level and per-field declared deps for a computed attr."""
model_deps = getattr(model_cls, "__crudkit_field_requires__", {}) or {}
field_deps = (extra_deps or {})
return list(model_deps.get(attr_name, [])) + list(field_deps.get(attr_name, []))
def _get_loaded_attr(obj: Any, name: str) -> Any:
"""
Return obj.<name> only if it is already loaded.
Never triggers a lazy load. Returns None if missing/unloaded.
Works for both column and relationship attributes.
"""
try:
st = inspect(obj)
# 1) Mapped attribute?
attr = st.attrs.get(name)
if attr is not None:
val = attr.loaded_value
return None if val is NO_VALUE else val
# 2) Already present value (e.g., eager-loaded or set on the dict)?
if name in st.dict:
return st.dict.get(name)
# 3) If object is detached or attr is not mapped, DO NOT eval hybrids
# or descriptors that could lazy-load. That would explode.
if st.session is None:
return None
# 4) As a last resort on attached instances only, try simple getattr,
# but guard against DetachedInstanceError anyway.
try:
return getattr(obj, name, None)
except Exception:
return None
except Exception:
# If we can't even inspect it, be conservative
try:
return getattr(obj, name, None)
except Exception:
return None
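# Intended contract, sketched against a SQLAlchemy session (deferred or
# unloaded attributes come back as None rather than emitting SQL):
#   user = session.query(User).options(load_only(User.id)).first()
#   _get_loaded_attr(user, "id")     -> the loaded id
#   _get_loaded_attr(user, "email")  -> None (deferred column, no lazy load fired)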
return app.jinja_env.overlay(
loader=ChoiceLoader([app.jinja_loader, fallback_loader])
)
def _normalize_rows_layout(layout: Optional[List[dict]]) -> Dict[str, dict]:
"""
@ -418,11 +164,6 @@ def _sanitize_attrs(attrs: Any) -> dict[str, Any]:
elif isinstance(v, str):
if len(v) > 512:
v = v[:512]
if k == "style":
sv = _sanitize_style(v)
if sv:
out["style"] = sv
continue
if k.startswith("data-") or k.startswith("aria-") or k in _ALLOWED_ATTRS:
if isinstance(v, bool):
if v:
@ -432,105 +173,6 @@ def _sanitize_attrs(attrs: Any) -> dict[str, Any]:
return out
def _resolve_rel_obj(values: dict, instance, base: str):
rel = None
if isinstance(values, dict) and base in values:
rel = values[base]
if isinstance(rel, dict):
class _DictObj:
def __init__(self, d): self._d = d
def __getattr__(self, k): return self._d.get(k)
rel = _DictObj(rel)
if rel is None and instance is not None:
try:
st = inspect(instance)
ra = st.attrs.get(base)
if ra is not None and ra.loaded_value is not NO_VALUE:
rel = ra.loaded_value
except Exception:
pass
return rel
def _value_label_for_field(field: dict, mapper, values_map: dict, instance, session):
"""
If field targets a MANYTOONE (foo or foo_id), compute a human-readable label.
No lazy loads. Optional single-row lean fetch if we only have the id.
"""
base, rel_prop = _rel_for_id_name(mapper, field["name"])
if not rel_prop:
return None
rid = _coerce_fk_value(values_map, instance, base, rel_prop)
rel_obj = _resolve_rel_obj(values_map, instance, base)
label_spec = (
field.get("label_spec")
or getattr(rel_prop.mapper.class_, "__crud_label__", None)
or "id"
)
if rel_obj is not None and session is not None and rid is not None:
mdl = rel_prop.mapper.class_
# Work out exactly what the label needs (columns + rel paths),
# expanding model-level and per-field deps (for hybrids etc.)
simple_cols, rel_paths = _extract_label_requirements(
label_spec,
model_cls=mdl,
extra_deps=field.get("label_deps")
)
# If the currently-attached object doesn't have what we need, do one lean requery
if not _has_label_bits_loaded(rel_obj, label_spec):
q = session.query(mdl)
# only real columns in load_only
cols = []
id_attr = getattr(mdl, "id", None)
if _is_column_attr(id_attr):
cols.append(id_attr)
for c in simple_cols:
a = getattr(mdl, c, None)
if _is_column_attr(a):
cols.append(a)
if cols:
q = q.options(load_only(*cols))
# selectinload relationships; "__all__" means just eager the relationship object
for rel_name, col_name in rel_paths:
rel_ia = getattr(mdl, rel_name, None)
if rel_ia is None:
continue
opt = selectinload(rel_ia)
if col_name == "__all__":
q = q.options(opt)
else:
t_cls = mdl.__mapper__.relationships[rel_name].mapper.class_
t_attr = getattr(t_cls, col_name, None)
q = q.options(opt.load_only(t_attr) if _is_column_attr(t_attr) else opt)
rel_obj = q.get(rid)
if rel_obj is not None:
try:
s = _label_from_obj(rel_obj, label_spec)
except Exception:
s = None
# If we couldn't safely render and we have a session+id, do one lean retry.
if (s is None or s == "") and session is not None and rid is not None:
mdl = rel_prop.mapper.class_
try:
rel_obj2 = session.get(mdl, rid) # attached instance
s2 = _label_from_obj(rel_obj2, label_spec)
if s2:
return s2
except Exception:
pass
return s
return str(rid) if rid is not None else None
class _SafeObj:
"""Attribute access that returns '' for missing/None instead of exploding."""
__slots__ = ("_obj",)
@ -539,10 +181,12 @@ class _SafeObj:
def __getattr__(self, name):
if self._obj is None:
return ""
val = _get_loaded_attr(self._obj, name)
return "" if val is None else _SafeObj(val)
val = getattr(self._obj, name, None)
if val is None:
return ""
return _SafeObj(val)
def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Optional[RelationshipProperty] = None):
def _coerce_fk_value(values: dict | None, instance: Any, base: str):
"""
Resolve current selection for relationship 'base':
1) values['<base>_id']
@ -589,25 +233,6 @@ def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Op
except Exception:
pass
# Fallback: if we know the relationship, try its local FK column names
if rel_prop is not None:
try:
st = inspect(instance) if instance is not None else None
except Exception:
st = None
# Try values[...] first
for col in getattr(rel_prop, "local_columns", []) or []:
key = getattr(col, "key", None) or getattr(col, "name", None)
if not key:
continue
if isinstance(values, dict) and key in values and values[key] not in (None, ""):
return values[key]
if st is not None:
attr = st.attrs.get(key) if hasattr(st, "attrs") else None
if attr is not None and attr.loaded_value is not NO_VALUE:
return attr.loaded_value
return None
def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]:
@ -629,42 +254,43 @@ def _rel_for_id_name(mapper, name: str) -> tuple[Optional[str], Optional[Relatio
return (name, prop) if prop else (None, None)
def _fk_options(session, related_model, label_spec):
simple_cols, rel_paths = _extract_label_requirements(label_spec, related_model)
simple_cols, rel_paths = _extract_label_requirements(label_spec)
q = session.query(related_model)
col_attrs = []
if hasattr(related_model, "id"):
id_attr = getattr(related_model, "id")
if _is_column_attr(id_attr):
col_attrs.append(id_attr)
col_attrs.append(getattr(related_model, "id"))
for name in simple_cols:
attr = getattr(related_model, name, None)
if _is_column_attr(attr):
col_attrs.append(attr)
if hasattr(related_model, name):
col_attrs.append(getattr(related_model, name))
if col_attrs:
q = q.options(load_only(*col_attrs))
for rel_name, col_name in rel_paths:
rel_attr = getattr(related_model, rel_name, None)
if rel_attr is None:
rel_prop = getattr(related_model, rel_name, None)
if rel_prop is None:
continue
opt = selectinload(rel_attr)
if col_name == "__all__":
q = q.options(opt)
else:
try:
target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_
col_attr = getattr(target_cls, col_name, None)
q = q.options(opt.load_only(col_attr) if _is_column_attr(col_attr) else opt)
if col_attr is None:
q = q.options(selectinload(rel_prop))
else:
q = q.options(selectinload(rel_prop).load_only(col_attr))
except Exception:
q = q.options(selectinload(rel_prop))
if simple_cols:
first = simple_cols[0]
if hasattr(related_model, first):
q = q.order_by(None).order_by(getattr(related_model, first))
q = q.order_by(getattr(related_model, first))
rows = q.all()
return [
{'value': getattr(opt, 'id'), 'label': _label_from_obj(opt, label_spec)}
{
'value': getattr(opt, 'id'),
'label': _label_from_obj(opt, label_spec),
}
for opt in rows
]
@ -688,20 +314,12 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
"template": spec.get("template"),
"template_name": spec.get("template_name"),
"template_ctx": spec.get("template_ctx"),
"label_spec": spec.get("label_spec"),
}
if "link" in spec:
field["link"] = spec["link"]
if "label_deps" in spec:
field["label_deps"] = spec["label_deps"]
opts_params = spec.get("options_params") or spec.get("options_filter") or spec.get("options_where")
if rel_prop:
if field["type"] is None:
field["type"] = "select"
if field["type"] == "select" and field.get("options") is None:
if field["type"] == "select" and field.get("options") is None and session is not None:
related_model = rel_prop.mapper.class_
label_spec = (
spec.get("label_spec")
@ -709,11 +327,7 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
or getattr(related_model, "__crud_label__", None)
or "id"
)
field["options"] = _fk_options_via_service(
related_model,
label_spec,
options_params=opts_params
)
field["options"] = _fk_options(session, related_model, label_spec)
return field
col = mapper.columns.get(name)
@ -733,86 +347,81 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
return field
def _extract_label_requirements(
spec: Any,
model_cls: Any = None,
extra_deps: Optional[Dict[str, List[str]]] = None
) -> tuple[list[str], list[tuple[str, str]]]:
def _extract_label_requirements(spec: Any) -> tuple[list[str], list[tuple[str, str]]]:
"""
Returns:
simple_cols: ["name", "code", "label", ...] (non-dotted names; may include non-columns)
rel_paths: [("room_function", "description"), ("brand", "__all__"), ...]
- ("rel", "__all__") means: just eager the relationship (no specific column)
Also expands dependencies declared by the model or the field (extra_deps).
From a label spec, return:
- simple_cols: ["name", "code"]
- rel_paths: [("room_function", "description"), ("owner", "last_name")]
"""
simple_cols: list[str] = []
rel_paths: list[tuple[str, str]] = []
seen: set[str] = set()
def add_dep_token(token: str) -> None:
"""Add a concrete dependency token (column or 'rel' or 'rel.col')."""
if not token or token in seen:
def ingest(token: str) -> None:
token = str(token).strip()
if not token:
return
seen.add(token)
if "." in token:
rel, col = token.split(".", 1)
if rel and col:
rel_paths.append((rel, col))
return
# bare token: could be column, relationship, or computed
simple_cols.append(token)
# If this is not obviously a column, try pulling declared deps.
if model_cls is not None:
attr = getattr(model_cls, token, None)
if _is_column_attr(attr):
return
# If it's a relationship, we want to eager the relationship itself.
if _is_relationship_attr(attr):
rel_paths.append((token, "__all__"))
return
# Not a column/relationship => computed (hybrid/descriptor/etc.)
for dep in _get_attr_deps(model_cls, token, extra_deps):
add_dep_token(dep)
def add_from_spec(piece: Any) -> None:
if piece is None or callable(piece):
return
if isinstance(piece, (list, tuple)):
for a in piece:
add_from_spec(a)
return
s = str(piece)
if "{" in s and "}" in s:
for n in re.findall(r"{\s*([^}:\s]+)", s):
add_dep_token(n)
else:
add_dep_token(s)
simple_cols.append(token)
if spec is None or callable(spec):
return simple_cols, rel_paths
if isinstance(spec, (list, tuple)):
for a in spec:
ingest(a)
return simple_cols, rel_paths
if isinstance(spec, str):
# format string like "{first} {last}" or "{room_function.description} · {name}"
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
for n in names:
ingest(n)
else:
ingest(spec)
return simple_cols, rel_paths
add_from_spec(spec)
return simple_cols, rel_paths
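# For example, with the one-argument signature used on this side of the diff:
#   _extract_label_requirements("{room_function.description} · {name}")
#   -> (["name"], [("room_function", "description")])
#   _extract_label_requirements(["code", "owner.last_name"])
#   -> (["code"], [("owner", "last_name")])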
def _attrs_from_label_spec(spec: Any) -> list[str]:
"""
Return a list of attribute names needed from the related model to compute the label.
Only simple attribute names are returned; dotted paths return just the first segment.
"""
if spec is None:
return []
if callable(spec):
return []
if isinstance(spec, (list, tuple)):
return [str(a).split(".", 1)[0] for a in spec]
if isinstance(spec, str):
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
return [n.split(".", 1)[0] for n in names]
return [spec.split(".", 1)[0]]
return []
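# e.g. _attrs_from_label_spec("{room_function.description} · {name}")
#      -> ["room_function", "name"]  (dotted paths keep only the first segment)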
def _label_from_obj(obj: Any, spec: Any) -> str:
if obj is None:
return ""
if spec is None:
for attr in ("label", "name", "title", "description"):
val = _get_loaded_attr(obj, attr)
if val is not None:
return str(val)
vid = _get_loaded_attr(obj, "id")
return str(vid) if vid is not None else object.__repr__(obj)
if hasattr(obj, attr):
val = getattr(obj, attr)
if not callable(val) and val is not None:
return str(val)
if hasattr(obj, "id"):
return str(getattr(obj, "id"))
return object.__repr__(obj)
if isinstance(spec, (list, tuple)):
parts = []
for a in spec:
cur = obj
for part in str(a).split("."):
cur = _get_loaded_attr(cur, part) if cur is not None else None
cur = getattr(cur, part, None)
if cur is None:
break
parts.append("" if cur is None else str(cur))
@ -824,10 +433,9 @@ def _label_from_obj(obj: Any, spec: Any) -> str:
for f in fields:
root = f.split(".", 1)[0]
if root not in data:
try:
data[root] = _SafeObj(_get_loaded_attr(obj, root))
except Exception:
data[root] = _SafeObj(None)
val = getattr(obj, root, None)
data[root] = _SafeObj(val)
try:
return spec.format(**data)
except Exception:
@ -835,7 +443,7 @@ def _label_from_obj(obj: Any, spec: Any) -> str:
cur = obj
for part in str(spec).split("."):
cur = _get_loaded_attr(cur, part) if cur is not None else None
cur = getattr(cur, part, None)
if cur is None:
return ""
return str(cur)
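# Behavior sketch against a hypothetical room object (the list branch's join is
# truncated by the hunk above, so its output is assumed space-separated):
#   _label_from_obj(room, None)            -> room.name, e.g. "Kitchen"
#   _label_from_obj(room, "{name} ({id})") -> "Kitchen (7)" (via _SafeObj-wrapped roots)
#   _label_from_obj(room, "building.code") -> "B1" (dotted path walked attribute by attribute)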
@ -952,9 +560,6 @@ def _format_value(val: Any, fmt: Optional[str]) -> Any:
if fmt is None:
return val
try:
if callable(fmt):
return fmt(val)
if fmt == "yesno":
return "Yes" if bool(val) else "No"
if fmt == "date":
@ -967,68 +572,12 @@ def _format_value(val: Any, fmt: Optional[str]) -> Any:
return val
return val
def _has_label_bits_loaded(obj, label_spec) -> bool:
try:
st = inspect(obj)
except Exception:
return True
simple_cols, rel_paths = _extract_label_requirements(label_spec, type(obj))
# concrete columns on the object
for name in simple_cols:
a = getattr(type(obj), name, None)
if _is_column_attr(a) and name not in st.dict:
return False
# non-column tokens (hybrids/descriptors) are satisfied by their deps above
# relationships
for rel_name, col_name in rel_paths:
ra = st.attrs.get(rel_name)
if ra is None or ra.loaded_value in (NO_VALUE, None):
return False
if col_name == "__all__":
continue # relationship object present is enough
try:
t_st = inspect(ra.loaded_value)
except Exception:
return False
t_attr = getattr(type(ra.loaded_value), col_name, None)
if _is_column_attr(t_attr) and col_name not in t_st.dict:
return False
return True
def _class_for(val: Any, classes: Optional[Dict[str, str]]) -> Optional[str]:
if not classes:
return None
key = "none" if val is None else str(val).lower()
return classes.get(key, classes.get("default"))
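# e.g. _class_for(True, {"true": "badge-ok", "none": "badge-muted", "default": "badge"})
#      -> "badge-ok"  (keys are matched against str(val).lower(); None maps to "none")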
def _format_label_from_values(spec: Any, values: dict) -> Optional[str]:
if not spec:
return None
if isinstance(spec, (list, tuple)):
parts = []
for a in spec:
v = _deep_get(values, str(a))
parts.append("" if v is None else str(v))
return " ".join(p for p in parts if p)
s = str(spec)
if "{" in s and "}" in s:
names = re.findall(r"{\s*([^}:\s]+)", s)
data = {n: _deep_get(values, n) for n in names}
# substitute tokens directly; str.format(**data) would KeyError on dotted names like {room.name}
data = {k: ("" if v is None else v) for k, v in data.items()}
try:
return re.sub(r"{\s*([^}:\s]+)\s*}", lambda m: str(data.get(m.group(1), "")), s)
except Exception:
return None
# simple field name
v = _deep_get(values, s)
return "" if v is None else str(v)
def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]:
if not spec:
return None
@ -1045,9 +594,8 @@ def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]
if any(v is None for v in params.values()):
return None
try:
return url_for(spec["endpoint"], **params)
return url_for('crudkit.' + spec["endpoint"], **params)
except Exception as e:
print(f"Cannot create endpoint for {spec['endpoint']}: {str(e)}")
return None
def _humanize(field: str) -> str:
@ -1120,8 +668,6 @@ def render_field(field, value):
attrs=_sanitize_attrs(field.get('attrs') or {}),
label_attrs=_sanitize_attrs(field.get('label_attrs') or {}),
help=field.get('help'),
value_label=field.get('value_label'),
link_href=field.get("link_href"),
)
@ -1216,7 +762,7 @@ def render_form(
base = name[:-3]
rel_prop = mapper.relationships.get(base)
if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE":
values_map[name] = _coerce_fk_value(values, instance, base, rel_prop) # add rel_prop
values_map[name] = _coerce_fk_value(values, instance, base)
else:
# Auto-generate path (your original behavior)
@ -1249,7 +795,7 @@ def render_form(
fk_fields.add(f"{base}_id")
# NEW: set the current selection for this dropdown
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base, prop)
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base)
# Then plain columns
for col in model_cls.__table__.columns:
@ -1269,37 +815,15 @@ def render_form(
field["wrap"] = _sanitize_attrs(field["wrap"])
fields.append(field)
if submit_attrs:
submit_attrs = _sanitize_attrs(submit_attrs)
common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session}
for f in fields:
if f.get("type") == "template":
base = dict(common_ctx)
base.update(f.get("template_ctx") or {})
f["template_ctx"] = base
for f in fields:
# existing FK label resolution
vl = _value_label_for_field(f, mapper, values_map, instance, session)
if vl is not None:
f["value_label"] = vl
# NEW: if not a relationship but a label_spec is provided, format from values
elif f.get("label_spec"):
base, rel_prop = _rel_for_id_name(mapper, f["name"])
if not rel_prop: # scalar field
vl2 = _format_label_from_values(f["label_spec"], values_map)
if vl2 is not None:
f["value_label"] = vl2
link_spec = f.get("link")
if link_spec:
try:
href = _build_href(link_spec, values_map, instance)
except Exception:
href = None
if href:
f["link_href"] = href
common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session}
for f in fields:
if f.get("type") == "template":
base = dict(common_ctx)
base.update(f.get("template_ctx") or {})
f["template_ctx"] = base
# Build rows (supports nested layout with parents)
rows_map = _normalize_rows_layout(layout)
@ -1311,6 +835,5 @@ def render_form(
values=values_map,
render_field=render_field,
submit_attrs=submit_attrs,
submit_label=submit_label,
model_name=model_cls.__name__
submit_label=submit_label
)

View file

@ -4,13 +4,7 @@
{% if label_attrs %}{% for k,v in label_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% if link_href %}
<a href="{{ link_href }}">
{% endif %}
{{ field_label }}
{% if link_href %}
</a>
{% endif %}
{{ field_label }}
</label>
{% endif %}
@ -36,7 +30,7 @@
<textarea name="{{ field_name }}" id="{{ field_name }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value if value else "" }}</textarea>
{% endfor %}{% endif %}>{{ value }}</textarea>
{% elif field_type == 'checkbox' %}
<input type="checkbox" name="{{ field_name }}" id="{{ field_name }}" value="1"
@ -46,33 +40,15 @@
{% endfor %}{% endif %}>
{% elif field_type == 'hidden' %}
<input type="hidden" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}">
<input type="hidden" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}">
{% elif field_type == 'display' %}
<div {% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value_label if value_label else (value if value else "") }}</div>
{% elif field_type == "date" %}
<input type="date" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% elif field_type == "time" %}
<input type="time" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% elif field_type == "datetime" %}
<input type="datetime-local" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% endfor %}{% endif %}>{{ value }}</div>
{% else %}
<input type="text" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
<input type="text" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>

View file

@ -1,5 +1,6 @@
<form method="POST" id="{{ model_name|lower }}_form">
<form method="POST">
{% macro render_row(row) %}
<!-- {{ row.name }} -->
{% if row.fields or row.children or row.legend %}
{% if row.legend %}<legend>{{ row.legend }}</legend>{% endif %}
<fieldset
@ -35,5 +36,5 @@
{% if submit_attrs %}{% for k,v in submit_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}
>{{ submit_label if submit_label else 'Save' }}</button>
</form>