Compare commits

..

11 commits

Author SHA1 Message Date
Yaro Kasear
87f7108f2d Downstream patch. 2025-10-23 16:00:57 -05:00
Yaro Kasear
ec82ca2394 Downstream changes. 2025-10-22 12:44:00 -05:00
Yaro Kasear
f956e09e2b Lots of downstream fixes. 2025-10-10 09:23:45 -05:00
Yaro Kasear
90dd16baf4 Downstream updated. 2025-10-07 13:39:18 -05:00
Yaro Kasear
10b2843be8 Lots of downstream updates. 2025-10-03 16:27:25 -05:00
Yaro Kasear
f5bc0b5a30 Updates to spec and fragments. 2025-09-29 15:57:41 -05:00
Yaro Kasear
d4e51affd5 Lots of downstream changes. 2025-09-26 15:55:02 -05:00
Yaro Kasear
d34654834b Downstream changes and decoupled engine instance problem. Fwee. 2025-09-24 16:20:28 -05:00
Yaro Kasear
8be6f917c7 Lots and lots and *lots* of downstream updates. 2025-09-23 16:03:32 -05:00
Yaro Kasear
6b56251d33 More field and form rendering logic. 2025-09-22 14:13:31 -05:00
2d837210c1 Merge pull request 'Redesign1' (#1) from Redesign1 into main
Reviewed-on: #1
2025-09-22 14:12:39 -05:00
15 changed files with 2654 additions and 674 deletions

View file

@ -1,21 +1,135 @@
import base64, json # crudkit/api/_cursor.py
from typing import Any
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: from __future__ import annotations
import base64
import dataclasses
import datetime as _dt
import decimal as _dec
import hmac
import json
import typing as _t
from hashlib import sha256
Any = _t.Any
@dataclasses.dataclass(frozen=True)
class Cursor:
values: list[Any]
desc_flags: list[bool]
backward: bool
version: int = 1
def _json_default(o: Any) -> Any:
# Keep it boring and predictable.
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
return o.isoformat()
if isinstance(o, _dec.Decimal):
return str(o)
# UUIDs, Enums, etc.
if hasattr(o, "__str__"):
return str(o)
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
def _b64url_nopad_encode(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
def _b64url_nopad_decode(s: str) -> bytes:
# Restore padding for strict decode
pad = "=" * (-len(s) % 4)
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
def encode_cursor(
values: list[Any] | None,
desc_flags: list[bool],
backward: bool,
*,
secret: bytes | None = None,
) -> str | None:
"""
Create an opaque, optionally signed cursor token.
- values: keyset values for the last/first row
- desc_flags: per-key descending flags (same length/order as sort keys)
- backward: whether the window direction is backward
- secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig')
"""
if not values: if not values:
return None return None
payload = {"v": values, "d": desc_flags, "b": backward}
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: payload = {
"ver": 1,
"v": values,
"d": desc_flags,
"b": bool(backward),
}
body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8")
token = _b64url_nopad_encode(body)
if secret:
sig = hmac.new(secret, body, sha256).digest()
token = f"{token}.{_b64url_nopad_encode(sig)}"
return token
def decode_cursor(
token: str | None,
*,
secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
"""
Parse a cursor token. Returns (values, desc_flags, backward).
- Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case.
- If secret is provided, verifies HMAC and rejects tampered tokens.
- On any parse failure, returns (None, None, False).
"""
if not token: if not token:
return None, False return None, None, False
try: try:
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) # Split payload.sig if signed
if "." in token:
body_b64, sig_b64 = token.split(".", 1)
body = _b64url_nopad_decode(body_b64)
if secret is None:
# Caller didnt ask for verification; still parse but dont trust.
pass
else:
expected = hmac.new(secret, body, sha256).digest()
actual = _b64url_nopad_decode(sig_b64)
if not hmac.compare_digest(expected, actual):
return None, None, False
else:
body = _b64url_nopad_decode(token)
obj = json.loads(body.decode("utf-8"))
# Versioning. If we ever change fields, branch here.
ver = int(obj.get("ver", 0))
if ver not in (0, 1):
return None, None, False
vals = obj.get("v") vals = obj.get("v")
backward = bool(obj.get("b", False)) backward = bool(obj.get("b", False))
if isinstance(vals, list):
return vals, backward # desc_flags may be absent in legacy payloads (ver 0)
desc = obj.get("d")
if not isinstance(vals, list):
return None, None, False
if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
# Ignore weird 'd' types rather than crashing
desc = None
return vals, desc, backward
except Exception: except Exception:
pass # Be tolerant on decode: treat as no-cursor.
return None, False return None, None, False

View file

@ -1,90 +1,195 @@
from flask import Blueprint, jsonify, request # crudkit/api/flask_api.py
from __future__ import annotations
from flask import Blueprint, jsonify, request, abort, current_app, url_for
from hashlib import md5
from urllib.parse import urlencode
from werkzeug.exceptions import HTTPException
from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy from crudkit.core.service import _is_truthy
def generate_crud_blueprint(model, service): MAX_JSON = 1_000_000
bp = Blueprint(model.__name__.lower(), __name__)
@bp.get('/') def _etag_for(obj) -> str:
def list_items(): v = getattr(obj, "updated_at", None) or obj.id
args = request.args.to_dict(flat=True) return md5(str(v).encode()).hexdigest()
# legacy detection def _json_payload() -> dict:
legacy_offset = "offset" in args or "page" in args if request.content_length and request.content_length > MAX_JSON:
abort(413)
if not request.is_json:
abort(415)
payload = request.get_json(silent=False)
if not isinstance(payload, dict):
abort(400)
return payload
# sane limit default def _args_flat() -> dict[str, str]:
try: return request.args.to_dict(flat=True) # type: ignore[arg-type]
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit
if legacy_offset: def _json_error(e: Exception, status: int = 400):
# Old behavior: honor limit/offset, same CRUDSpec goodies if isinstance(e, HTTPException):
items = service.list(args) status = e.code or status
return jsonify([obj.as_dict() for obj in items]) msg = e.description
else:
msg = str(e)
if current_app.debug:
return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status
return jsonify({"status": "error", "error": msg}), status
# New behavior: keyset seek with cursors def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
key, backward = decode_cursor(args.get("cursor")) return _is_truthy(d.get(key, "1" if default else "0"))
window = service.seek_window( def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True):
args, """
key=key, REST:
backward=backward, GET /api/<models>/ -> list (filters via ?q=..., sort=..., limit=..., cursor=...)
include_total=_is_truthy(args.get("include_total", "1")), GET /api/<models>/<id> -> get
) POST /api/<models>/ -> create
PATCH /api/<models>/<id> -> update (partial)
DELETE /api/<models>/<id>[?hard=1] -> delete
desc_flags = list(window.order.desc) RPC (legacy):
body = { GET /api/<model>/get?id=123
"items": [obj.as_dict() for obj in window.items], GET /api/<model>/list
"limit": window.limit, GET /api/<model>/seek_window
"next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), GET /api/<model>/page
"prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), POST /api/<model>/create
"total": window.total, PATCH /api/<model>/update?id=123
} DELETE /api/<model>/delete?id=123[&hard=1]
"""
model_name = model.__name__.lower()
# bikeshed if you want pluralization; this is the least-annoying default
collection = (base_prefix or model_name).lower()
plural = collection if collection.endswith('s') else f"{collection}s"
resp = jsonify(body) bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}")
# Optional Link header
links = []
if body["next_cursor"]:
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
if body["prev_cursor"]:
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
if links:
resp.headers["Link"] = ", ".join(links)
return resp
@bp.get('/<int:id>') @bp.errorhandler(Exception)
def get_item(id): def _handle_any(e: Exception):
item = service.get(id, request.args) return _json_error(e)
try:
return jsonify(item.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.post('/') @bp.errorhandler(404)
def create_item(): def _not_found(_e):
obj = service.create(request.json) return jsonify({"status": "error", "error": "not found"}), 404
try:
return jsonify(obj.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.patch('/<int:id>') # ---------- REST ----------
def update_item(id): if rest:
obj = service.update(id, request.json) @bp.get("/")
try: def rest_list():
return jsonify(obj.as_dict()) args = _args_flat()
except Exception as e: # support cursor pagination transparently; fall back to limit/offset
return jsonify({"status": "error", "error": str(e)}) try:
items = service.list(args)
return jsonify([o.as_dict() for o in items])
except Exception as e:
return _json_error(e)
@bp.delete('/<int:id>') @bp.get("/<int:obj_id>")
def delete_item(id): def rest_get(obj_id: int):
service.delete(id) item = service.get(obj_id, request.args)
try: if item is None:
return jsonify({"status": "success"}), 204 abort(404)
except Exception as e: etag = _etag_for(item)
return jsonify({"status": "error", "error": str(e)}) if request.if_none_match and (etag in request.if_none_match):
return "", 304
resp = jsonify(item.as_dict())
resp.set_etag(etag)
return resp
@bp.post("/")
def rest_create():
payload = _json_payload()
try:
obj = service.create(payload)
resp = jsonify(obj.as_dict())
resp.status_code = 201
resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False)
return resp
except Exception as e:
return _json_error(e)
@bp.patch("/<int:obj_id>")
def rest_update(obj_id: int):
payload = _json_payload()
try:
obj = service.update(obj_id, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
@bp.delete("/<int:obj_id>")
def rest_delete(obj_id: int):
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
try:
obj = service.delete(obj_id, hard=hard)
if obj is None:
abort(404)
return ("", 204)
except Exception as e:
return _json_error(e)
# ---------- RPC (your existing routes) ----------
if rpc:
@bp.get("/get")
def rpc_get():
print("⚠️ WARNING: Deprecated RPC call used: /get")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
try:
item = service.get(id_, request.args)
if item is None:
abort(404)
return jsonify(item.as_dict())
except Exception as e:
return _json_error(e)
@bp.get("/list")
def rpc_list():
print("⚠️ WARNING: Deprecated RPC call used: /list")
args = _args_flat()
try:
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
except Exception as e:
return _json_error(e)
@bp.post("/create")
def rpc_create():
print("⚠️ WARNING: Deprecated RPC call used: /create")
payload = _json_payload()
try:
obj = service.create(payload)
return jsonify(obj.as_dict()), 201
except Exception as e:
return _json_error(e)
@bp.patch("/update")
def rpc_update():
print("⚠️ WARNING: Deprecated RPC call used: /update")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
payload = _json_payload()
try:
obj = service.update(id_, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
@bp.delete("/delete")
def rpc_delete():
print("⚠️ WARNING: Deprecated RPC call used: /delete")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
try:
obj = service.delete(id_, hard=hard)
return ("", 204) if obj is not None else abort(404)
except Exception as e:
return _json_error(e)
return bp return bp

View file

@ -74,18 +74,32 @@ def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page:
per_page = max(1, int(per_page)) per_page = max(1, int(per_page))
offset = (page - 1) * per_page offset = (page - 1) * per_page
if backend.requires_order_by_for_offset and not sel._order_by_clauses: if backend.requires_order_by_for_offset:
if default_order_by is None: # Avoid private attribute if possible:
sel = sel.order_by(text("1")) has_order = bool(getattr(sel, "_order_by_clauses", ())) # fallback for SA < 2.0.30
else: try:
sel = sel.order_by(default_order_by) has_order = has_order or bool(sel.get_order_by())
except Exception:
pass
if not has_order:
if default_order_by is not None:
sel = sel.order_by(default_order_by)
else:
# Try to find a primary key from the FROMs; fall back to a harmless literal.
try:
first_from = sel.get_final_froms()[0]
pk = next(iter(first_from.primary_key.columns))
sel = sel.order_by(pk)
except Exception:
sel = sel.order_by(text("1"))
return sel.limit(per_page).offset(offset) return sel.limit(per_page).offset(offset)
@contextmanager @contextmanager
def maybe_identify_insert(session: Session, table, backend: BackendInfo): def maybe_identify_insert(session: Session, table, backend: BackendInfo):
""" """
For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs. For MSSQL tables with IDENTITY PK when you need to insert explicit IDs.
No-op elsewhere. No-op elsewhere.
""" """
if not backend.is_mssql: if not backend.is_mssql:
@ -93,7 +107,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo):
return return
full_name = f"{table.schema}.{table.name}" if table.schema else table.name full_name = f"{table.schema}.{table.name}" if table.schema else table.name
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON")) session.execute(text(f"SET IDENTITY_INSERT {full_name} ON"))
try: try:
yield yield
finally: finally:
@ -101,7 +115,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo):
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement: def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
""" """
Build a safe large IN() filter respecting bund param limits. Build a safe large IN() filter respecting bind param limits.
Returns a disjunction of chunked IN clauses if needed. Returns a disjunction of chunked IN clauses if needed.
""" """
vals = list(values) vals = list(values)
@ -120,3 +134,12 @@ def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optio
for p in parts[1:]: for p in parts[1:]:
expr = expr | p expr = expr | p
return expr return expr
def sql_trim(expr, backend: BackendInfo):
"""
Portable TRIM. SQL Server before compat level 140 lacks TRIM().
Emit LTRIM(RTRIM(...)) there; use TRIM elsewhere
"""
if backend.is_mssql:
return func.ltrim(func.rtrim(expr))
return func.trim(expr)

View file

@ -187,6 +187,8 @@ class Config:
"synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"), "synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"),
} }
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
@classmethod @classmethod
def engine_kwargs(cls) -> Dict[str, Any]: def engine_kwargs(cls) -> Dict[str, Any]:
url = cls.DATABASE_URL url = cls.DATABASE_URL
@ -221,15 +223,18 @@ class Config:
class DevConfig(Config): class DevConfig(Config):
DEBUG = True DEBUG = True
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1"))) SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class TestConfig(Config): class TestConfig(Config):
TESTING = True TESTING = True
DATABASE_URL = build_database_url(backend="sqlite", database=":memory:") DATABASE_URL = build_database_url(backend="sqlite", database=":memory:")
SQLALCHEMY_ECHO = False SQLALCHEMY_ECHO = False
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class ProdConfig(Config): class ProdConfig(Config):
DEBUG = False DEBUG = False
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0"))) SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0")))
def get_config(name: str | None) -> Type[Config]: def get_config(name: str | None) -> Type[Config]:
""" """

View file

@ -0,0 +1,9 @@
# crudkit/core/__init__.py
from .utils import (
ISO_DT_FORMATS,
normalize_payload,
deep_diff,
diff_to_patch,
filter_to_columns,
to_jsonable,
)

View file

@ -1,47 +1,358 @@
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func from functools import lru_cache
from sqlalchemy.orm import declarative_mixin, declarative_base from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState
Base = declarative_base() Base = declarative_base()
@lru_cache(maxsize=512)
def _column_names_for_model(cls: type) -> tuple[str, ...]:
try:
mapper = inspect(cls)
return tuple(prop.key for prop in mapper.column_attrs)
except Exception:
names: list[str] = []
for c in cls.__mro__:
if hasattr(c, "__table__"):
names.extend(col.name for col in c.__table__.columns)
return tuple(dict.fromkeys(names))
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
"""Safely get SQLAlchemy InstanceState (or None)."""
try:
st = inspect(obj)
return cast(Optional[InstanceState[Any]], st)
except Exception:
return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
"""Safely get Mapper for a maooed instance (or None)."""
try:
st = inspect(obj)
mapper = getattr(st, "mapper", None)
return cast(Optional[Mapper], mapper)
except Exception:
return None
def _safe_get_loaded_attr(obj, name):
st = _sa_state(obj)
if st is None:
return None
try:
st_dict = getattr(st, "dict", {})
if name in st_dict:
return st_dict[name]
attrs = getattr(st, "attrs", None)
attr = None
if attrs is not None:
try:
attr = attrs[name]
except Exception:
try:
get = getattr(attrs, "get", None)
if callable(get):
attr = get(name)
except Exception:
attr = None
if attr is not None:
val = attr.loaded_value
return None if val is NO_VALUE else val
return None
except Exception:
return None
def _identity_key(obj) -> Tuple[type, Any]:
try:
st = inspect(obj)
return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
except Exception:
return (type(obj), id(obj))
def _is_collection_rel(prop: RelationshipProperty) -> bool:
try:
return prop.uselist is True
except Exception:
return False
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for name in _column_names_for_model(type(obj)):
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
"""
Serialize relationship 'name' already loaded on obj.
- If in 'embed' (or depth > 0 for depth-based walk), recurse.
- Else, return None (dont lazy-load).
"""
val = _safe_get_loaded_attr(obj, name)
if val is None:
return None
# Decide whether to recurse into this relationship
should_recurse = (depth > 0) or (name in embed)
if isinstance(val, list):
if not should_recurse:
# Emit a light list of child primary data (id + a couple columns) without recursion.
return [_serialize_simple_obj(child) for child in val]
out = []
for child in val:
ik = _identity_key(child)
if ik in seen: # cycle guard
out.append({"id": getattr(child, "id", None)})
continue
seen.add(ik)
out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
return out
# Scalar relationship
child = val
if not should_recurse:
return _serialize_simple_obj(child)
ik = _identity_key(child)
if ik in seen:
return {"id": getattr(child, "id", None)}
seen.add(ik)
return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
"""
Split requested fields into:
- scalars: ["label", "name"]
- collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]}
Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
Bare tokens ("foo") land in scalars.
"""
scalars: List[str] = []
groups: Dict[str, List[str]] = {}
for raw in fields:
f = str(raw).strip()
if not f:
continue
# bare token -> scalar
if "." not in f:
scalars.append(f)
continue
# dotted token -> group under root
root, tail = f.split(".", 1)
if not root or not tail:
continue
groups.setdefault(root, []).append(tail)
return scalars, groups
def _deep_get_loaded(obj: Any, dotted: str) -> Any:
"""
Deep get with no lazy loads:
- For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr).
- For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr()
to allow computed properties/hybrids that rely on already-loaded columns.
"""
parts = dotted.split(".")
if not parts:
return None
cur = obj
# Traverse up to the parent of the last token safely
for part in parts[:-1]:
if cur is None:
return None
cur = _safe_get_loaded_attr(cur, part)
if cur is None:
return None
last = parts[-1]
# Try safe fetch on the last hop first
val = _safe_get_loaded_attr(cur, last)
if val is not None:
return val
# Fall back to getattr for computed/hybrid attributes on an already-loaded object
try:
return getattr(cur, last, None)
except Exception:
return None
def _serialize_leaf(obj: Any) -> Any:
"""
Lead serialization for values we put into as_dict():
- If object has as_dict(), call as_dict() with no args (caller controls field shapes).
- Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config).
"""
if obj is None:
return None
ad = getattr(obj, "as_dict", None)
if callable(ad):
try:
return ad(None)
except Exception:
return str(obj)
return obj
def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
"""
Turn a collection of ORM objects into list[dict] with exactly requested_tails,
where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load.
"""
out: List[Dict[str, Any]] = []
# Deduplicate while preserving order
uniq_tails = list(dict.fromkeys(requested_tails))
for child in (items or []):
row: Dict[str, Any] = {}
for tail in uniq_tails:
row[tail] = _deep_get_loaded(child, tail)
# ensure id present if exists and not already requested
try:
if "id" not in row and hasattr(child, "id"):
row["id"] = getattr(child, "id")
except Exception:
pass
out.append(row)
return out
@declarative_mixin @declarative_mixin
class CRUDMixin: class CRUDMixin:
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
created_at = Column(DateTime, default=func.now(), nullable=False) created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
def as_tree(
self,
*,
embed_depth: int = 0,
embed: Iterable[str] | None = None,
_seen: Set[Tuple[type, Any]] | None = None,
) -> Dict[str, Any]:
"""
Recursive, NON-LAZY serializer.
- Always includes mapped columns.
- For relationships: only serializes those ALREADY LOADED.
- Recurses either up to embed_depth or for specific names in 'embed'.
- Keeps *_id columns alongside embedded objects.
- Cycle-safe via _seen.
"""
seen = _seen or set()
ik = _identity_key(self)
if ik in seen:
return {"id": getattr(self, "id", None)}
seen.add(ik)
data = _serialize_simple_obj(self)
# Determine which relationships to consider
try:
mapper = _sa_mapper(self)
embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
if mapper is None:
return data
st = _sa_state(self)
if st is None:
return data
for name, prop in mapper.relationships.items():
# Only touch relationships that are already loaded; never lazy-load here.
rel_loaded = getattr(st, "attrs", {}).get(name)
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
continue
data[name] = _serialize_loaded_rel(
self, name, depth=embed_depth, seen=seen, embed=embed_set
)
except Exception:
# If inspection fails, we just return columns.
pass
return data
def as_dict(self, fields: list[str] | None = None): def as_dict(self, fields: list[str] | None = None):
""" """
Serialize the instance. Serialize the instance.
- If 'fields' (possibly dotted) is provided, emit exactly those keys.
- Else, if '__crudkit_projection__' is set on the instance, emit those keys.
- Else, fall back to all mapped columns on this class hierarchy.
Always includes 'id' when present unless explicitly excluded.
"""
if fields is None:
fields = getattr(self, "__crudkit_projection__", None)
if fields: Behavior:
out = {} - If 'fields' (possibly dotted) is provided, emit exactly those keys.
if "id" not in fields and hasattr(self, "id"): * Bare tokens (e.g., "label", "owner") return the current loaded value.
out["id"] = getattr(self, "id") * Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp")
for f in fields: produce a single "updates" key containing a list of dicts with the requested child keys.
cur = self * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label".
for part in f.split("."): - Else, if '__crudkit_projection__' is set on the instance, use that.
if cur is None: - Else, fall back to all mapped columns on this class hierarchy.
break
cur = getattr(cur, part, None) Always includes 'id' when present unless explicitly excluded (i.e., fields explicitly provided without id).
out[f] = cur """
req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)
if req:
# Normalize and split into (scalars, groups of dotted by root)
req_list = [p for p in (str(x).strip() for x in req) if p]
scalars, groups = _split_field_tokens(req_list)
out: Dict[str, Any] = {}
# Always include id unless the caller explicitly listed fields containing id
if "id" not in req_list and hasattr(self, "id"):
try:
out["id"] = getattr(self, "id")
except Exception:
pass
# Handle scalar tokens (may be columns, hybrids/properties, or relationships)
for name in scalars:
# Try loaded value first (never lazy-load)
val = _safe_get_loaded_attr(self, name)
# Final-hop getattr for root scalars (hybrids/@property) so they can compute.
if val is None:
try:
val = getattr(self, name)
except Exception:
val = None
# If it's a scalar ORM object (relationship), serialize its columns
mapper = _sa_mapper(val)
if mapper is not None:
out[name] = _serialize_simple_obj(val)
continue
# If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)):
out[name] = [_serialize_leaf(v) for v in val]
else:
out[name] = val
# Handle dotted groups: root -> [tails]
for root, tails in groups.items():
root_val = _safe_get_loaded_attr(self, root)
if isinstance(root_val, (list, tuple)):
# one-to-many collection → list of dicts with the requested tails
out[root] = _serialize_collection(root_val, tails)
else:
# many-to-one or scalar dotted; place each full dotted path as key
for tail in tails:
dotted = f"{root}.{tail}"
out[dotted] = _deep_get_loaded(self, dotted)
# ← This was the placeholder before. We return the dict we just built.
return out return out
result = {} # Fallback: all mapped columns on this class hierarchy
result: Dict[str, Any] = {}
for cls in self.__class__.__mro__: for cls in self.__class__.__mro__:
if hasattr(cls, "__table__"): if hasattr(cls, "__table__"):
for column in cls.__table__.columns: for column in cls.__table__.columns:
name = column.name name = column.name
result[name] = getattr(self, name) try:
result[name] = getattr(self, name)
except Exception:
result[name] = None
return result return result
class Version(Base): class Version(Base):
__tablename__ = "versions" __tablename__ = "versions"

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,17 @@
from typing import List, Tuple, Set, Dict, Optional from dataclasses import dataclass
from sqlalchemy import asc, desc from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
@dataclass(frozen=True)
class CollPred:
table: Any
col_key: str
op: str
value: Any
OPERATORS = { OPERATORS = {
'eq': lambda col, val: col == val, 'eq': lambda col, val: col == val,
'lt': lambda col, val: col < val, 'lt': lambda col, val: col < val,
@ -12,6 +20,8 @@ OPERATORS = {
'gte': lambda col, val: col >= val, 'gte': lambda col, val: col >= val,
'ne': lambda col, val: col != val, 'ne': lambda col, val: col != val,
'icontains': lambda col, val: col.ilike(f"%{val}%"), 'icontains': lambda col, val: col.ilike(f"%{val}%"),
'in': lambda col, val: col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
'nin': lambda col, val: ~col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
} }
class CRUDSpec: class CRUDSpec:
@ -20,12 +30,138 @@ class CRUDSpec:
self.params = params self.params = params
self.root_alias = root_alias self.root_alias = root_alias
self.eager_paths: Set[Tuple[str, ...]] = set() self.eager_paths: Set[Tuple[str, ...]] = set()
# (parent_alias. relationship_attr, alias_for_target)
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {} self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = [] self._root_fields: List[InstrumentedAttribute] = []
self._rel_field_names: Dict[Tuple[str, ...], object] = {} # dotted non-collection fields (MANYTOONE etc)
self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
# dotted collection fields (ONETOMANY)
self._collection_field_names: Dict[str, List[str]] = {}
self.include_paths: Set[Tuple[str, ...]] = set() self.include_paths: Set[Tuple[str, ...]] = set()
def _split_path_and_op(self, key: str) -> tuple[str, str]:
if '__' in key:
path, op = key.rsplit('__', 1)
else:
path, op = key, 'eq'
return path, op
def _resolve_many_columns(self, path: str) -> list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]]:
"""
Accepts pipe-delimited paths like 'label|owner.label'
Returns a list of (column, join_path) pairs for every resolvable subpath.
"""
cols: list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]] = []
for sub in path.split('|'):
sub = sub.strip()
if not sub:
continue
col, join_path = self._resolve_column(sub)
if col is not None:
cols.append((col, join_path))
return cols
def _build_predicate_for(self, path: str, op: str, value: Any):
"""
Builds a SQLA BooleanClauseList or BinaryExpression for a single key.
If multiple subpaths are provided via pipe, returns an OR of them.
"""
if op not in OPERATORS:
return None
pairs = self._resolve_many_columns(path)
if not pairs:
return None
exprs = []
for col, join_path in pairs:
if join_path:
self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
is_collection = False
for nm in names:
rel_attr = getattr(cur_cls, nm)
prop = rel_attr.property
cur_cls = prop.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None)
and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
is_collection = False
if is_collection:
target_cls = cur_cls
key = getattr(col, "key", None) or getattr(col, "name", None)
if key and hasattr(target_cls, key):
target_tbl = getattr(target_cls, "__table__", None)
if target_tbl is not None:
exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
continue
exprs.append(OPERATORS[op](col, value))
if not exprs:
return None
# If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
if any(isinstance(x, CollPred) for x in exprs):
# If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds,
# keep the first or raise for now (your choice).
if len(exprs) > 1:
# raise NotImplementedError("OR across collection paths not supported yet")
exprs = [next(x for x in exprs if isinstance(x, CollPred))]
return exprs[0]
# Otherwise, standard SQLA clause(s)
return exprs[0] if len(exprs) == 1 else or_(*exprs)
def _collect_filters(self, params: dict) -> list:
"""
Recursively parse filters from 'param' into a flat list of SQLA expressions.
Supports $or / $and groups. Any other keys are parsed as normal filters.
"""
filters: list = []
for key, value in (params or {}).items():
if key in ('sort', 'limit', 'offset', 'fields', 'include'):
continue
if key == '$or':
# value should be a list of dicts
groups = []
for group in value if isinstance(value, (list, tuple)) else []:
sub = self._collect_filters(group)
if not sub:
continue
groups.append(and_(*sub) if len(sub) > 1 else sub[0])
if groups:
filters.append(or_(*groups))
continue
if key == '$and':
# value should be a list of dicts
parts = []
for group in value if isinstance(value, (list, tuple)) else []:
sub = self._collect_filters(group)
if not sub:
continue
parts.append(and_(*sub) if len(sub) > 1 else sub[0])
if parts:
filters.append(and_(*parts))
continue
# Normal key
path, op = self._split_path_and_op(key)
pred = self._build_predicate_for(path, op, value)
if pred is not None:
filters.append(pred)
return filters
def _resolve_column(self, path: str): def _resolve_column(self, path: str):
current_alias = self.root_alias current_alias = self.root_alias
parts = path.split('.') parts = path.split('.')
@ -68,24 +204,12 @@ class CRUDSpec:
if maybe: if maybe:
self.eager_paths.add(maybe) self.eager_paths.add(maybe)
def parse_filters(self): def parse_filters(self, params: dict | None = None):
filters = [] """
for key, value in self.params.items(): Public entry: parse filters from given params or self.params.
if key in ('sort', 'limit', 'offset'): Returns a list of SQLAlchemy filter expressions
continue """
if '__' in key: return self._collect_filters(params if params is not None else self.params)
path_op = key.rsplit('__', 1)
if len(path_op) != 2:
continue
path, op = path_op
else:
path, op = key, 'eq'
col, join_path = self._resolve_column(path)
if col and op in OPERATORS:
filters.append(OPERATORS[op](col, value))
if join_path:
self.eager_paths.add(join_path)
return filters
def parse_sort(self): def parse_sort(self):
sort_args = self.params.get('sort', '') sort_args = self.params.get('sort', '')
@ -117,11 +241,12 @@ class CRUDSpec:
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias. - Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
Returns (root_fields, rel_field_names). - Collection (uselist=True) relationships record child names by relationship key.
Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel).
""" """
raw = self.params.get('fields') raw = self.params.get('fields')
if not raw: if not raw:
return [], {}, {} return [], {}, {}, {}
if isinstance(raw, list): if isinstance(raw, list):
tokens = [] tokens = []
@ -133,14 +258,36 @@ class CRUDSpec:
root_fields: List[InstrumentedAttribute] = [] root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = [] root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {} rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
collection_field_names: Dict[str, List[str]] = {}
for token in tokens: for token in tokens:
col, join_path = self._resolve_column(token) col, join_path = self._resolve_column(token)
if not col: if not col:
continue continue
if join_path: if join_path:
rel_field_names.setdefault(join_path, []).append(col.key) # rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path) # self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
for nm in names:
rel_attr = getattr(cur_cls, nm)
cur_cls = rel_attr.property.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
# Fallback: inspect the InstrumentedAttribute we recorded on join_paths
is_collection = False
for _pa, rel_attr, _al in self.join_paths:
if rel_attr.key == (join_path[-1] if join_path else ""):
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
break
if is_collection:
collection_field_names.setdefault(join_path[-1], []).append(col.key)
else:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else: else:
root_fields.append(col) root_fields.append(col)
root_field_names.append(getattr(col, "key", token)) root_field_names.append(getattr(col, "key", token))
@ -153,7 +300,11 @@ class CRUDSpec:
self._root_fields = root_fields self._root_fields = root_fields
self._rel_field_names = rel_field_names self._rel_field_names = rel_field_names
return root_fields, rel_field_names, root_field_names # return root_fields, rel_field_names, root_field_names
for r, names in collection_field_names.items():
seen3 = set()
collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
return root_field_names, rel_field_names, root_field_names, collection_field_names
def get_eager_loads(self, root_alias, *, fields_map=None): def get_eager_loads(self, root_alias, *, fields_map=None):
loads = [] loads = []

176
crudkit/core/utils.py Normal file
View file

@ -0,0 +1,176 @@
from __future__ import annotations
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, Optional, Callable
from sqlalchemy import inspect
ISO_DT_FORMATS = ("%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%d")
def to_jsonable(obj: Any):
"""Recursively convert values into JSON-serializable forms."""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (datetime, date)):
return obj.isoformat()
if isinstance(obj, Decimal):
return float(obj)
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, dict):
return {str(k): to_jsonable(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple, set)):
return [to_jsonable(v) for v in obj]
# fallback: strin-ify weird objects (UUID, ORM instances, etc.)
try:
return str(obj)
except Exception:
return None
def filter_to_columns(data: dict, model_cls):
cols = {c.key for c in inspect(model_cls).mapper.columns}
return {k: v for k, v in data.items() if k in cols}
def _parse_dt_maybe(x: Any) -> Any:
if isinstance(x, (datetime, date)):
return x
if isinstance(x, str):
s = x.strip().replace("Z", "+00:00") # tolerate Zulu
for fmt in ISO_DT_FORMATS:
try:
return datetime.strptime(s, fmt)
except ValueError:
pass
try:
return datetime.fromisoformat(s)
except Exception:
return x
return x
def _normalize_for_compare(x: Any) -> Any:
if isinstance(x, (str, datetime, date)):
return _parse_dt_maybe(x)
return x
def deep_diff(
old: Any,
new: Any,
*,
path: str = "",
ignore_keys: Optional[set] = None,
list_mode: str = "index", # "index" or "set"
custom_equal: Optional[Callable[[str, Any, Any], bool]] = None,
) -> Dict[str, Dict[str, Any]]:
if ignore_keys is None:
ignore_keys = set()
out: Dict[str, Dict[str, Any]] = {"added": {}, "removed": {}, "changed": {}}
def mark_changed(p, a, b):
out["changed"][p] = {"from": a, "to": b}
def rec(o, n, pfx):
if custom_equal and custom_equal(pfx.rstrip("."), o, n):
return
if isinstance(o, dict) and isinstance(n, dict):
o_keys = set(o.keys())
n_keys = set(n.keys())
for k in sorted(o_keys - n_keys):
if k not in ignore_keys:
out["removed"][f"{pfx}{k}"] = o[k]
for k in sorted(n_keys - o_keys):
if k not in ignore_keys:
out["added"][f"{pfx}{k}"] = n[k]
for k in sorted(o_keys & n_keys):
if k not in ignore_keys:
rec(o[k], n[k], f"{pfx}{k}.")
return
if isinstance(o, list) and isinstance(n, list):
if list_mode == "set":
if set(o) != set(n):
mark_changed(pfx.rstrip("."), o, n)
else:
max_len = max(len(o), len(n))
for i in range(max_len):
key = f"{pfx}[{i}]"
if i >= len(o):
out["added"][key] = n[i]
elif i >= len(n):
out["removed"][key] = o[i]
else:
rec(o[i], n[i], f"{key}.")
return
a = _normalize_for_compare(o)
b = _normalize_for_compare(n)
if a != b:
mark_changed(pfx.rstrip("."), o, n)
rec(old, new, path)
return out
def diff_to_patch(diff: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
"""
Produce a shallow patch of changed/added top-level fields.
Only includes leaf paths without dots/brackets; useful for simple UPDATEs.
"""
patch: Dict[str, Any] = {}
for k, v in diff["added"].items():
if "." not in k and "[" not in k:
patch[k] = v
for k, v in diff["changed"].items():
if "." not in k and "[" not in k:
patch[k] = v["to"]
return patch
def normalize_payload(payload: dict, model):
"""
Coerce incoming JSON into SQLAlchemy column types for the given model.
- "" or None -> None
- Integer/Boolean/Date/DateTime handled by column type
"""
from sqlalchemy import Integer, Boolean, DateTime, Date
out: Dict[str, Any] = {}
mapper = inspect(model).mapper
cols = {c.key: c.type for c in mapper.columns}
for field, value in payload.items():
if value == "" or value is None:
out[field] = None
continue
coltype = cols.get(field)
if coltype is None:
out[field] = value
continue
tname = coltype.__class__.__name__.lower()
if "integer" in tname:
out[field] = int(value)
elif "boolean" in tname:
out[field] = value if isinstance(value, bool) else str(value).lower() in ("1", "true", "yes", "on")
elif "datetime" in tname:
out[field] = value if isinstance(value, datetime) else _parse_dt_maybe(value)
elif "date" in tname:
v = _parse_dt_maybe(value)
out[field] = v.date() if isinstance(v, datetime) else v
else:
out[field] = value
return out

View file

@ -1,7 +1,8 @@
# engines.py
from __future__ import annotations from __future__ import annotations
from typing import Type, Optional from typing import Type, Optional
from sqlalchemy import create_engine from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, raiseload, Mapper, RelationshipProperty
from .backend import make_backend_info, BackendInfo from .backend import make_backend_info, BackendInfo
from .config import Config, get_config from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas from ._sqlite import apply_sqlite_pragmas
@ -12,15 +13,31 @@ def build_engine(config_cls: Type[Config] | None = None):
apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS) apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS)
return engine return engine
def _install_nplus1_guards(SessionMaker, *, strict: bool):
if not strict:
return
@event.listens_for(SessionMaker, "do_orm_execute")
def _add_global_raiseload(execute_state):
stmt = execute_state.statement
# Only touch ORM statements (have column_descriptions)
if getattr(stmt, "column_descriptions", None):
execute_state.statement = stmt.options(raiseload("*"))
def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None): def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None):
config_cls = config_cls or get_config(None) config_cls = config_cls or get_config(None)
engine = engine or build_engine(config_cls) engine = engine or build_engine(config_cls)
return sessionmaker(bind=engine, **config_cls.session_kwargs()) SessionMaker = sessionmaker(bind=engine, **config_cls.session_kwargs())
# Toggle with a config flag; default off so you can turn it on when ready
strict = bool(getattr(config_cls, "STRICT_NPLUS1", False))
_install_nplus1_guards(SessionMaker, strict=strict)
return SessionMaker
class CRUDKitRuntime: class CRUDKitRuntime:
""" """
Lightweight container so CRUDKit can be given either: Lightweight container so CRUDKit can be given either:
- prebuild engine/sessionmaker, or - prebuilt engine/sessionmaker, or
- a Config to build them lazily - a Config to build them lazily
""" """
def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None): def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None):

View file

@ -1,20 +1,32 @@
# crudkit/integrations/flask.py
from __future__ import annotations from __future__ import annotations
from flask import Flask from flask import Flask
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session, sessionmaker
from ..engines import CRUDKitRuntime from ..engines import CRUDKitRuntime
from ..config import Config from ..config import Config
def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None == None): def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None = None):
""" """
Initializes CRUDKit for a Flask app. Provies `app.extensions['crudkit']` Initializes CRUDKit for a Flask app. Provides `app.extensions['crudkit']`
with a runtime (engine + session_factory). Caller manages session lifecycle. with a runtime (engine + session_factory). Caller manages session lifecycle.
""" """
runtime = runtime or CRUDKitRuntime(config=config) runtime = runtime or CRUDKitRuntime(config=config)
app.extensions.setdefault("crudkit", {}) app.extensions.setdefault("crudkit", {})
app.extensions["crudkit"]["runtime"] = runtime app.extensions["crudkit"]["runtime"] = runtime
Session = runtime.session_factory # Build ONE sessionmaker bound to the ONE true engine object
if Session is not None: # so engine id == sessionmaker.bind id, always.
app.extensions["crudkit"]["Session"] = scoped_session(Session) engine = runtime.engine
SessionFactory = runtime.session_factory or sessionmaker(bind=engine, **runtime._config.session_kwargs())
app.extensions["crudkit"]["SessionFactory"] = SessionFactory
app.extensions["crudkit"]["Session"] = scoped_session(SessionFactory)
# Attach pool listeners to the *same* engine the SessionFactory is bound to.
# Dont guess. Dont hope. Inspect.
try:
bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine
pool = bound_engine.pool
except Exception as e:
print(f"[crudkit.init_app] Failed to attach pool listeners: {e}")
return runtime return runtime

236
crudkit/projection.py Normal file
View file

@ -0,0 +1,236 @@
# crudkit/projection.py
from __future__ import annotations
from typing import Iterable, List, Tuple, Dict, Set
from sqlalchemy.orm import selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
from sqlalchemy import inspect
# ----------------------
# small utilities
# ----------------------
def _is_column_attr(a) -> bool:
try:
return isinstance(a, InstrumentedAttribute) and isinstance(a.property, ColumnProperty)
except Exception:
return False
def _is_relationship_attr(a) -> bool:
try:
return isinstance(a, InstrumentedAttribute) and isinstance(a.property, RelationshipProperty)
except Exception:
return False
def _split_path(field: str) -> List[str]:
return [p for p in str(field).split(".") if p]
def _model_requires_map(model_cls) -> Dict[str, List[str]]:
# apps declare per-model deps, e.g. {"label": ["first_name","last_name","title"]}
return getattr(model_cls, "__crudkit_field_requires__", {}) or {}
def _relationships_of(model_cls) -> Dict[str, RelationshipProperty]:
try:
return dict(model_cls.__mapper__.relationships)
except Exception:
return {}
def _attr_on(model_cls, name: str):
return getattr(model_cls, name, None)
# ----------------------
# EXPAND: add required deps for leaf attributes at the correct class
# ----------------------
def _expand_requires_for_field(model_cls, pieces: List[str]) -> List[str]:
"""
Given a dotted path like ["owner","label"], walk relationships to the leaf *container* class,
pull its __crudkit_field_requires__ for that leaf attr ("label"), and yield prefixed deps:
owner.label -> ["owner.first_name", "owner.last_name", ...] if User requires so.
If leaf is a column (or has no requires), returns [].
"""
if not pieces:
return []
# walk relationships to the leaf container (class that owns the leaf attr)
container_cls = model_cls
prefix_parts: List[str] = []
for part in pieces[:-1]:
a = _attr_on(container_cls, part)
if not _is_relationship_attr(a):
return [] # can't descend; invalid or scalar in the middle
container_cls = a.property.mapper.class_
prefix_parts.append(part)
leaf = pieces[-1]
requires = _model_requires_map(container_cls).get(leaf) or []
if not requires:
return []
prefix = ".".join(prefix_parts)
out: List[str] = []
for dep in requires:
# dep may itself be dotted relative to container (e.g. "room_function.description")
if prefix:
out.append(f"{prefix}.{dep}")
else:
out.append(dep)
return out
def _expand_requires(model_cls, fields: Iterable[str]) -> List[str]:
"""
Dedup + stable expansion of requires for all fields.
"""
seen: Set[str] = set()
out: List[str] = []
def add(f: str):
if f not in seen:
seen.add(f)
out.append(f)
# first pass: add original
queue: List[str] = []
for f in fields:
f = str(f)
if f not in seen:
seen.add(f)
out.append(f)
queue.append(f)
# BFS-ish: when we add deps, they may trigger further deps downstream
while queue:
f = queue.pop(0)
deps = _expand_requires_for_field(model_cls, _split_path(f))
for d in deps:
if d not in seen:
seen.add(d)
out.append(d)
queue.append(d)
return out
# ----------------------
# BUILD loader options tree with selectinload + load_only on real columns
# ----------------------
def _insert_leaf(loader_tree: dict, path: List[str]):
"""
Build nested dict structure keyed by relationship names.
Each node holds:
{
"__cols__": set(column_names_to_load_only),
"<child_rel>": { ... }
}
"""
node = loader_tree
for rel in path[:-1]: # only relationship hops
node = node.setdefault(rel, {"__cols__": set()})
# leaf may be a column or a virtual/hybrid; only columns go to __cols__
node.setdefault("__cols__", set())
def _attach_column(loader_tree: dict, path: List[str], model_cls):
"""
If the leaf is a real column on the target class, record its name into __cols__ at that level.
"""
# descend to target class to test column-ness
container_cls = model_cls
node = loader_tree
for rel in path[:-1]:
a = _attr_on(container_cls, rel)
if not _is_relationship_attr(a):
return # invalid path, ignore
container_cls = a.property.mapper.class_
node = node.setdefault(rel, {"__cols__": set()})
leaf = path[-1]
a_leaf = _attr_on(container_cls, leaf)
node.setdefault("__cols__", set())
if _is_column_attr(a_leaf):
node["__cols__"].add(leaf)
def _build_loader_tree(model_cls, fields: Iterable[str]) -> dict:
"""
For each dotted field:
- walk relationships -> create nodes
- if leaf is a column: record it for load_only
- if leaf is not a column (hybrid/descriptor): no load_only; still ensure rel hops exist
"""
tree: Dict[str, dict] = {"__cols__": set()}
for f in fields:
parts = _split_path(f)
if not parts:
continue
# ensure relationship nodes exist
_insert_leaf(tree, parts)
# attach column if applicable
_attach_column(tree, parts, model_cls)
return tree
def _loader_options_from_tree(model_cls, tree: dict):
"""
Convert the loader tree into SQLAlchemy loader options:
selectinload(<rel>)[.load_only(cols)] recursively
"""
opts = []
rels = _relationships_of(model_cls)
for rel_name, child in tree.items():
if rel_name == "__cols__":
continue
rel_prop = rels.get(rel_name)
if not rel_prop:
continue
rel_attr = getattr(model_cls, rel_name)
opt = selectinload(rel_attr)
# apply load_only on the related class (only real columns recorded at child["__cols__"])
cols = list(child.get("__cols__", []))
if cols:
rel_model = rel_prop.mapper.class_
# map column names to attributes
col_attrs = []
for c in cols:
a = getattr(rel_model, c, None)
if _is_column_attr(a):
col_attrs.append(a)
if col_attrs:
opt = opt.load_only(*col_attrs)
# recurse to grandchildren
sub_opts = _loader_options_from_tree(rel_prop.mapper.class_, child)
for so in sub_opts:
opt = opt.options(so)
opts.append(opt)
# root-level columns (rare in our compile; kept for completeness)
root_cols = list(tree.get("__cols__", []))
if root_cols:
# NOTE: call-site can add a root load_only(...) if desired;
# we purposely return only relationship options here to keep
# the API simple and avoid mixing Load(model_cls) contexts.
pass
return opts
# ----------------------
# PUBLIC API
# ----------------------
def compile_projection(model_cls, fields: Iterable[str]) -> Tuple[List[str], List]:
"""
Returns:
expanded_fields: List[str] # original + declared dependencies
loader_options: List[Load] # apply via query = query.options(*loader_options)
Behavior:
- Expands __crudkit_field_requires__ at the leaf container class for every field.
- Builds a selectinload tree; load_only only includes real columns (no hybrids).
- Safe for nested paths: e.g. "owner.label" pulls owner deps from User.__crudkit_field_requires__.
"""
fields = list(fields or [])
expanded = _expand_requires(model_cls, fields)
tree = _build_loader_tree(model_cls, expanded)
options = _loader_options_from_tree(model_cls, tree)
return expanded, options

View file

@ -6,25 +6,279 @@ from flask import current_app, url_for
from jinja2 import Environment, FileSystemLoader, ChoiceLoader from jinja2 import Environment, FileSystemLoader, ChoiceLoader
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.orm import Load, RelationshipProperty, class_mapper, load_only, selectinload from sqlalchemy.orm import Load, RelationshipProperty, class_mapper, load_only, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import crudkit
_ALLOWED_ATTRS = { _ALLOWED_ATTRS = {
"class", "placeholder", "autocomplete", "inputmode", "pattern", "class", "placeholder", "autocomplete", "inputmode", "pattern",
"min", "max", "step", "maxlength", "minlength", "min", "max", "step", "maxlength", "minlength",
"required", "readonly", "disabled", "required", "readonly", "disabled",
"multiple", "size", "multiple", "size", "rows",
"id", "name", "value", "id", "name", "value",
} }
_SAFE_CSS_PROPS = {
# spacing / sizing
"margin","margin-top","margin-right","margin-bottom","margin-left",
"padding","padding-top","padding-right","padding-bottom","padding-left",
"width","height","min-width","min-height","max-width","max-height", "resize",
# layout
"display","flex","flex-direction","flex-wrap","justify-content","align-items","gap",
# text
"font-size","font-weight","line-height","text-align","white-space",
# colors / background
"color","background-color",
# borders / radius
"border","border-top","border-right","border-bottom","border-left",
"border-width","border-style","border-color","border-radius",
# misc (safe-ish)
"opacity","overflow","overflow-x","overflow-y",
}
_num_unit = r"-?\d+(?:\.\d+)?"
_len_unit = r"(?:px|em|rem|%)"
P_LEN = re.compile(rf"^{_num_unit}(?:{_len_unit})?$") # 12, 12px, 1.2rem, 50%
P_GAP = P_LEN
P_INT = re.compile(r"^\d+$")
P_COLOR = re.compile(
r"^(#[0-9a-fA-F]{3,8}|"
r"rgb\(\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*\d{1,3}\s*\)|"
r"rgba\(\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*\d{1,3}\s*,\s*(?:0|1|0?\.\d+)\s*\)|"
r"[a-zA-Z]+)$"
)
_ENUMS = {
"display": {"block","inline","inline-block","flex","grid","none"},
"flex-direction": {"row","row-reverse","column","column-reverse"},
"flex-wrap": {"nowrap","wrap","wrap-reverse"},
"justify-content": {"flex-start","flex-end","center","space-between","space-around","space-evenly"},
"align-items": {"stretch","flex-start","flex-end","center","baseline"},
"text-align": {"left","right","center","justify","start","end"},
"white-space": {"normal","nowrap","pre","pre-wrap","pre-line","break-spaces"},
"border-style": {"none","solid","dashed","dotted","double","groove","ridge","inset","outset"},
"overflow": {"visible","hidden","scroll","auto","clip"},
"overflow-x": {"visible","hidden","scroll","auto","clip"},
"overflow-y": {"visible","hidden","scroll","auto","clip"},
"font-weight": {"normal","bold","bolder","lighter","100","200","300","400","500","600","700","800","900"},
"resize": {"none", "both", "horizontal", "vertical"},
}
def get_env(): def get_env():
"""
Return an overlay Jinja Environment that knows how to load crudkit templates
and has our helper functions available as globals.
"""
app = current_app app = current_app
default_path = os.path.join(os.path.dirname(__file__), 'templates') default_path = os.path.join(os.path.dirname(__file__), 'templates')
fallback_loader = FileSystemLoader(default_path) fallback_loader = FileSystemLoader(default_path)
return app.jinja_env.overlay( env = app.jinja_env.overlay(loader=ChoiceLoader([app.jinja_loader, fallback_loader]))
loader=ChoiceLoader([app.jinja_loader, fallback_loader]) # Ensure helpers are available even when we render via this overlay env.
) # These names are resolved at *call time* (not at def time), so it's safe.
try:
env.globals.setdefault("render_table", render_table)
env.globals.setdefault("render_form", render_form)
env.globals.setdefault("render_field", render_field)
except NameError:
# Functions may not be defined yet at import time; later calls will set them.
pass
return env
def register_template_globals(app=None):
"""
Register crudkit helpers as app-wide Jinja globals so they can be used
directly in any template via {{ render_table(...) }}, {{ render_form(...) }},
and {{ render_field(...) }}.
"""
if app is None:
app = current_app
# Idempotent install using an extension flag
installed = app.extensions.setdefault("crudkit_ui_helpers", set())
to_register = {
"render_table": render_table,
"render_form": render_form,
"render_field": render_field,
}
for name, fn in to_register.items():
if name not in installed:
app.add_template_global(fn, name)
installed.add(name)
def _fields_for_label_params(label_spec, related_model):
"""
Build a 'fields' list suitable for CRUDService.list() so labels render
without triggering lazy loads. Always includes 'id'.
"""
simple_cols, rel_paths = _extract_label_requirements(label_spec, related_model)
fields = set(["id"])
for c in simple_cols:
fields.add(c)
for rel_name, col_name in rel_paths:
if col_name == "__all__":
# just ensure relationship object is present; ask for rel.id
fields.add(f"{rel_name}.id")
else:
fields.add(f"{rel_name}.{col_name}")
return list(fields)
def _fk_options_via_service(related_model, label_spec, *, options_params: dict | None = None):
svc = crudkit.crud.get_service(related_model)
# default to unlimited results for dropdowns
params = {"limit": 0}
if options_params:
params.update(options_params) # caller can override limit if needed
# ensure fields needed to render the label are present (avoid lazy loads)
fields = _fields_for_label_params(label_spec, related_model)
if fields:
existing = params.get("fields")
if isinstance(existing, str):
existing = [s.strip() for s in existing.split(",") if s.strip()]
if isinstance(existing, (list, tuple)):
params["fields"] = list(dict.fromkeys(list(existing) + fields))
else:
params["fields"] = fields
# only set a default sort if caller didnt supply one
if "sort" not in params:
simple_cols, _ = _extract_label_requirements(label_spec, related_model)
params["sort"] = (simple_cols[0] if simple_cols else "id")
rows = svc.list(params)
return [
{"value": str(r.id), "label": _label_from_obj(r, label_spec)}
for r in rows
]
def expand_projection(model_cls, fields):
req = getattr(model_cls, "__crudkit_field_requires__", {}) or {}
out = set(fields)
for f in list(fields):
for dep in req.get(f, ()):
out.add(dep)
return list(out)
def _clean_css_value(prop: str, raw: str) -> str | None:
v = raw.strip()
v = v.replace("!important", "")
low = v.lower()
if any(bad in low for bad in ("url(", "expression(", "javascript:", "var(")):
return None
if prop in {"width","height","min-width","min-height","max-width","max-height",
"margin","margin-top","margin-right","margin-bottom","margin-left",
"padding","padding-top","padding-right","padding-bottom","padding-left",
"border-width","border-top","border-right","border-bottom","border-left","border-radius",
"line-height","font-size"}:
return v if P_LEN.match(v) else None
if prop in {"gap"}:
parts = [p.strip() for p in v.split()]
if 1 <= len(parts) <= 2 and all(P_GAP.match(p) for p in parts):
return " ".join(parts)
return None
if prop in {"color", "background-color", "border-color"}:
return v if P_COLOR.match(v) else None
if prop in _ENUMS:
return v if v.lower() in _ENUMS[prop] else None
if prop == "flex":
toks = v.split()
if len(toks) == 1 and (toks[0].isdigit() or toks[0] in {"auto", "none"}):
return v
if len(toks) == 2 and toks[0].isdigit() and (toks[1].isdigit() or toks[1] == "auto"):
return v
if len(toks) == 3 and toks[0].isdigit() and toks[1].isdigit() and (P_LEN.match(toks[2]) or toks[2] == "auto"):
return " ".join(toks)
return None
if prop == "border":
parts = v.split()
bw = next((p for p in parts if P_LEN.match(p)), None)
bs = next((p for p in parts if p in _ENUMS["border-style"]), None)
bc = next((p for p in parts if P_COLOR.match(p)), None)
chosen = [x for x in (bw, bs, bc) if x]
return " ".join(chosen) if chosen else None
return None
def _sanitize_style(style: str | None) -> str | None:
if not style or not isinstance(style, str):
return None
safe_decls = []
for chunk in style.split(";"):
if not chunk.strip():
continue
if ":" not in chunk:
continue
prop, val = chunk.split(":", 1)
prop = prop.strip().lower()
if prop not in _SAFE_CSS_PROPS:
continue
clean = _clean_css_value(prop, val)
if clean is not None and clean != "":
safe_decls.append(f"{prop}: {clean}")
return "; ".join(safe_decls) if safe_decls else None
def _is_column_attr(attr) -> bool:
try:
return isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, ColumnProperty)
except Exception:
return False
def _is_relationship_attr(attr) -> bool:
try:
return isinstance(attr, InstrumentedAttribute) and isinstance(attr.property, RelationshipProperty)
except Exception:
return False
def _get_attr_deps(model_cls, attr_name: str, extra_deps: Optional[dict] = None) -> list[str]:
"""Merge model-level and per-field declared deps for a computed attr."""
model_deps = getattr(model_cls, "__crudkit_field_requires__", {}) or {}
field_deps = (extra_deps or {})
return list(model_deps.get(attr_name, [])) + list(field_deps.get(attr_name, []))
def _get_loaded_attr(obj: Any, name: str) -> Any:
"""
Return obj.<name> only if it is already loaded.
Never triggers a lazy load. Returns None if missing/unloaded.
Works for both column and relationship attributes.
"""
try:
st = inspect(obj)
# 1) Mapped attribute?
attr = st.attrs.get(name)
if attr is not None:
val = attr.loaded_value
return None if val is NO_VALUE else val
# 2) Already present value (e.g., eager-loaded or set on the dict)?
if name in st.dict:
return st.dict.get(name)
# 3) If object is detached or attr is not mapped, DO NOT eval hybrids
# or descriptors that could lazy-load. That would explode.
if st.session is None:
return None
# 4) As a last resort on attached instances only, try simple getattr,
# but guard against DetachedInstanceError anyway.
try:
return getattr(obj, name, None)
except Exception:
return None
except Exception:
# If we can't even inspect it, be conservative
try:
return getattr(obj, name, None)
except Exception:
return None
def _normalize_rows_layout(layout: Optional[List[dict]]) -> Dict[str, dict]: def _normalize_rows_layout(layout: Optional[List[dict]]) -> Dict[str, dict]:
""" """
@ -164,6 +418,11 @@ def _sanitize_attrs(attrs: Any) -> dict[str, Any]:
elif isinstance(v, str): elif isinstance(v, str):
if len(v) > 512: if len(v) > 512:
v = v[:512] v = v[:512]
if k == "style":
sv = _sanitize_style(v)
if sv:
out["style"] = sv
continue
if k.startswith("data-") or k.startswith("aria-") or k in _ALLOWED_ATTRS: if k.startswith("data-") or k.startswith("aria-") or k in _ALLOWED_ATTRS:
if isinstance(v, bool): if isinstance(v, bool):
if v: if v:
@ -173,6 +432,105 @@ def _sanitize_attrs(attrs: Any) -> dict[str, Any]:
return out return out
def _resolve_rel_obj(values: dict, instance, base: str):
rel = None
if isinstance(values, dict) and base in values:
rel = values[base]
if isinstance(rel, dict):
class _DictObj:
def __init__(self, d): self._d = d
def __getattr__(self, k): return self._d.get(k)
rel = _DictObj(rel)
if rel is None and instance is not None:
try:
st = inspect(instance)
ra = st.attrs.get(base)
if ra is not None and ra.loaded_value is not NO_VALUE:
rel = ra.loaded_value
except Exception:
pass
return rel
def _value_label_for_field(field: dict, mapper, values_map: dict, instance, session):
"""
If field targets a MANYTOONE (foo or foo_id), compute a human-readable label.
No lazy loads. Optional single-row lean fetch if we only have the id.
"""
base, rel_prop = _rel_for_id_name(mapper, field["name"])
if not rel_prop:
return None
rid = _coerce_fk_value(values_map, instance, base, rel_prop)
rel_obj = _resolve_rel_obj(values_map, instance, base)
label_spec = (
field.get("label_spec")
or getattr(rel_prop.mapper.class_, "__crud_label__", None)
or "id"
)
if rel_obj is not None and session is not None and rid is not None:
mdl = rel_prop.mapper.class_
# Work out exactly what the label needs (columns + rel paths),
# expanding model-level and per-field deps (for hybrids etc.)
simple_cols, rel_paths = _extract_label_requirements(
label_spec,
model_cls=mdl,
extra_deps=field.get("label_deps")
)
# If the currently-attached object doesn't have what we need, do one lean requery
if not _has_label_bits_loaded(rel_obj, label_spec):
q = session.query(mdl)
# only real columns in load_only
cols = []
id_attr = getattr(mdl, "id", None)
if _is_column_attr(id_attr):
cols.append(id_attr)
for c in simple_cols:
a = getattr(mdl, c, None)
if _is_column_attr(a):
cols.append(a)
if cols:
q = q.options(load_only(*cols))
# selectinload relationships; "__all__" means just eager the relationship object
for rel_name, col_name in rel_paths:
rel_ia = getattr(mdl, rel_name, None)
if rel_ia is None:
continue
opt = selectinload(rel_ia)
if col_name == "__all__":
q = q.options(opt)
else:
t_cls = mdl.__mapper__.relationships[rel_name].mapper.class_
t_attr = getattr(t_cls, col_name, None)
q = q.options(opt.load_only(t_attr) if _is_column_attr(t_attr) else opt)
rel_obj = q.get(rid)
if rel_obj is not None:
try:
s = _label_from_obj(rel_obj, label_spec)
except Exception:
s = None
# If we couldn't safely render and we have a session+id, do one lean retry.
if (s is None or s == "") and session is not None and rid is not None:
mdl = rel_prop.mapper.class_
try:
rel_obj2 = session.get(mdl, rid) # attached instance
s2 = _label_from_obj(rel_obj2, label_spec)
if s2:
return s2
except Exception:
pass
return s
return str(rid) if rid is not None else None
class _SafeObj: class _SafeObj:
"""Attribute access that returns '' for missing/None instead of exploding.""" """Attribute access that returns '' for missing/None instead of exploding."""
__slots__ = ("_obj",) __slots__ = ("_obj",)
@ -181,12 +539,10 @@ class _SafeObj:
def __getattr__(self, name): def __getattr__(self, name):
if self._obj is None: if self._obj is None:
return "" return ""
val = getattr(self._obj, name, None) val = _get_loaded_attr(self._obj, name)
if val is None: return "" if val is None else _SafeObj(val)
return ""
return _SafeObj(val)
def _coerce_fk_value(values: dict | None, instance: Any, base: str): def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Optional[RelationshipProperty] = None):
""" """
Resolve current selection for relationship 'base': Resolve current selection for relationship 'base':
1) values['<base>_id'] 1) values['<base>_id']
@ -233,6 +589,25 @@ def _coerce_fk_value(values: dict | None, instance: Any, base: str):
except Exception: except Exception:
pass pass
# Fallback: if we know the relationship, try its local FK column names
if rel_prop is not None:
try:
st = inspect(instance) if instance is not None else None
except Exception:
st = None
# Try values[...] first
for col in getattr(rel_prop, "local_columns", []) or []:
key = getattr(col, "key", None) or getattr(col, "name", None)
if not key:
continue
if isinstance(values, dict) and key in values and values[key] not in (None, ""):
return values[key]
if set is not None:
attr = st.attrs.get(key) if hasattr(st, "attrs") else None
if attr is not None and attr.loaded_value is not NO_VALUE:
return attr.loaded_value
return None return None
def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]: def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]:
@ -254,43 +629,42 @@ def _rel_for_id_name(mapper, name: str) -> tuple[Optional[str], Optional[Relatio
return (name, prop) if prop else (None, None) return (name, prop) if prop else (None, None)
def _fk_options(session, related_model, label_spec): def _fk_options(session, related_model, label_spec):
simple_cols, rel_paths = _extract_label_requirements(label_spec) simple_cols, rel_paths = _extract_label_requirements(label_spec, related_model)
q = session.query(related_model) q = session.query(related_model)
col_attrs = [] col_attrs = []
if hasattr(related_model, "id"): if hasattr(related_model, "id"):
col_attrs.append(getattr(related_model, "id")) id_attr = getattr(related_model, "id")
if _is_column_attr(id_attr):
col_attrs.append(id_attr)
for name in simple_cols: for name in simple_cols:
if hasattr(related_model, name): attr = getattr(related_model, name, None)
col_attrs.append(getattr(related_model, name)) if _is_column_attr(attr):
col_attrs.append(attr)
if col_attrs: if col_attrs:
q = q.options(load_only(*col_attrs)) q = q.options(load_only(*col_attrs))
for rel_name, col_name in rel_paths: for rel_name, col_name in rel_paths:
rel_prop = getattr(related_model, rel_name, None) rel_attr = getattr(related_model, rel_name, None)
if rel_prop is None: if rel_attr is None:
continue continue
try: opt = selectinload(rel_attr)
if col_name == "__all__":
q = q.options(opt)
else:
target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_ target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_
col_attr = getattr(target_cls, col_name, None) col_attr = getattr(target_cls, col_name, None)
if col_attr is None: q = q.options(opt.load_only(col_attr) if _is_column_attr(col_attr) else opt)
q = q.options(selectinload(rel_prop))
else:
q = q.options(selectinload(rel_prop).load_only(col_attr))
except Exception:
q = q.options(selectinload(rel_prop))
if simple_cols: if simple_cols:
first = simple_cols[0] first = simple_cols[0]
if hasattr(related_model, first): if hasattr(related_model, first):
q = q.order_by(getattr(related_model, first)) q = q.order_by(None).order_by(getattr(related_model, first))
rows = q.all() rows = q.all()
return [ return [
{ {'value': getattr(opt, 'id'), 'label': _label_from_obj(opt, label_spec)}
'value': getattr(opt, 'id'),
'label': _label_from_obj(opt, label_spec),
}
for opt in rows for opt in rows
] ]
@ -314,12 +688,20 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
"template": spec.get("template"), "template": spec.get("template"),
"template_name": spec.get("template_name"), "template_name": spec.get("template_name"),
"template_ctx": spec.get("template_ctx"), "template_ctx": spec.get("template_ctx"),
"label_spec": spec.get("label_spec"),
} }
if "link" in spec:
field["link"] = spec["link"]
if "label_deps" in spec:
field["label_deps"] = spec["label_deps"]
opts_params = spec.get("options_params") or spec.get("options_filter") or spec.get("options_where")
if rel_prop: if rel_prop:
if field["type"] is None: if field["type"] is None:
field["type"] = "select" field["type"] = "select"
if field["type"] == "select" and field.get("options") is None and session is not None: if field["type"] == "select" and field.get("options") is None:
related_model = rel_prop.mapper.class_ related_model = rel_prop.mapper.class_
label_spec = ( label_spec = (
spec.get("label_spec") spec.get("label_spec")
@ -327,7 +709,11 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
or getattr(related_model, "__crud_label__", None) or getattr(related_model, "__crud_label__", None)
or "id" or "id"
) )
field["options"] = _fk_options(session, related_model, label_spec) field["options"] = _fk_options_via_service(
related_model,
label_spec,
options_params=opts_params
)
return field return field
col = mapper.columns.get(name) col = mapper.columns.get(name)
@ -347,81 +733,86 @@ def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
return field return field
def _extract_label_requirements(spec: Any) -> tuple[list[str], list[tuple[str, str]]]: def _extract_label_requirements(
spec: Any,
model_cls: Any = None,
extra_deps: Optional[Dict[str, List[str]]] = None
) -> tuple[list[str], list[tuple[str, str]]]:
""" """
From a label spec, return: Returns:
- simple_cols: ["name", "code"] simple_cols: ["name", "code", "label", ...] (non-dotted names; may include non-columns)
- rel_paths: [("room_function", "description"), ("owner", "last_name")] rel_paths: [("room_function", "description"), ("brand", "__all__"), ...]
- ("rel", "__all__") means: just eager the relationship (no specific column)
Also expands dependencies declared by the model or the field (extra_deps).
""" """
simple_cols: list[str] = [] simple_cols: list[str] = []
rel_paths: list[tuple[str, str]] = [] rel_paths: list[tuple[str, str]] = []
seen: set[str] = set()
def ingest(token: str) -> None: def add_dep_token(token: str) -> None:
token = str(token).strip() """Add a concrete dependency token (column or 'rel' or 'rel.col')."""
if not token: if not token or token in seen:
return return
seen.add(token)
if "." in token: if "." in token:
rel, col = token.split(".", 1) rel, col = token.split(".", 1)
if rel and col: if rel and col:
rel_paths.append((rel, col)) rel_paths.append((rel, col))
return
# bare token: could be column, relationship, or computed
simple_cols.append(token)
# If this is not obviously a column, try pulling declared deps.
if model_cls is not None:
attr = getattr(model_cls, token, None)
if _is_column_attr(attr):
return
# If it's a relationship, we want to eager the relationship itself.
if _is_relationship_attr(attr):
rel_paths.append((token, "__all__"))
return
# Not a column/relationship => computed (hybrid/descriptor/etc.)
for dep in _get_attr_deps(model_cls, token, extra_deps):
add_dep_token(dep)
def add_from_spec(piece: Any) -> None:
if piece is None or callable(piece):
return
if isinstance(piece, (list, tuple)):
for a in piece:
add_from_spec(a)
return
s = str(piece)
if "{" in s and "}" in s:
for n in re.findall(r"{\s*([^}:\s]+)", s):
add_dep_token(n)
else: else:
simple_cols.append(token) add_dep_token(s)
if spec is None or callable(spec):
return simple_cols, rel_paths
if isinstance(spec, (list, tuple)):
for a in spec:
ingest(a)
return simple_cols, rel_paths
if isinstance(spec, str):
# format string like "{first} {last}" or "{room_function.description} · {name}"
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
for n in names:
ingest(n)
else:
ingest(spec)
return simple_cols, rel_paths
add_from_spec(spec)
return simple_cols, rel_paths return simple_cols, rel_paths
def _attrs_from_label_spec(spec: Any) -> list[str]:
"""
Return a list of attribute names needed from the related model to compute the label.
Only simple attribute names are returned; dotted paths return just the first segment.
"""
if spec is None:
return []
if callable(spec):
return []
if isinstance(spec, (list, tuple)):
return [str(a).split(".", 1)[0] for a in spec]
if isinstance(spec, str):
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
return [n.split(".", 1)[0] for n in names]
return [spec.split(".", 1)[0]]
return []
def _label_from_obj(obj: Any, spec: Any) -> str: def _label_from_obj(obj: Any, spec: Any) -> str:
if obj is None:
return ""
if spec is None: if spec is None:
for attr in ("label", "name", "title", "description"): for attr in ("label", "name", "title", "description"):
if hasattr(obj, attr): val = _get_loaded_attr(obj, attr)
val = getattr(obj, attr) if val is not None:
if not callable(val) and val is not None: return str(val)
return str(val) vid = _get_loaded_attr(obj, "id")
if hasattr(obj, "id"): return str(vid) if vid is not None else object.__repr__(obj)
return str(getattr(obj, "id"))
return object.__repr__(obj)
if isinstance(spec, (list, tuple)): if isinstance(spec, (list, tuple)):
parts = [] parts = []
for a in spec: for a in spec:
cur = obj cur = obj
for part in str(a).split("."): for part in str(a).split("."):
cur = getattr(cur, part, None) cur = _get_loaded_attr(cur, part) if cur is not None else None
if cur is None: if cur is None:
break break
parts.append("" if cur is None else str(cur)) parts.append("" if cur is None else str(cur))
@ -433,9 +824,10 @@ def _label_from_obj(obj: Any, spec: Any) -> str:
for f in fields: for f in fields:
root = f.split(".", 1)[0] root = f.split(".", 1)[0]
if root not in data: if root not in data:
val = getattr(obj, root, None) try:
data[root] = _SafeObj(val) data[root] = _SafeObj(_get_loaded_attr(obj, root))
except Exception:
data[root] = _SafeObj(None)
try: try:
return spec.format(**data) return spec.format(**data)
except Exception: except Exception:
@ -443,7 +835,7 @@ def _label_from_obj(obj: Any, spec: Any) -> str:
cur = obj cur = obj
for part in str(spec).split("."): for part in str(spec).split("."):
cur = getattr(cur, part, None) cur = _get_loaded_attr(cur, part) if cur is not None else None
if cur is None: if cur is None:
return "" return ""
return str(cur) return str(cur)
@ -560,6 +952,9 @@ def _format_value(val: Any, fmt: Optional[str]) -> Any:
if fmt is None: if fmt is None:
return val return val
try: try:
if callable(fmt):
return fmt(val)
if fmt == "yesno": if fmt == "yesno":
return "Yes" if bool(val) else "No" return "Yes" if bool(val) else "No"
if fmt == "date": if fmt == "date":
@ -572,12 +967,68 @@ def _format_value(val: Any, fmt: Optional[str]) -> Any:
return val return val
return val return val
def _has_label_bits_loaded(obj, label_spec) -> bool:
try:
st = inspect(obj)
except Exception:
return True
simple_cols, rel_paths = _extract_label_requirements(label_spec, type(obj))
# concrete columns on the object
for name in simple_cols:
a = getattr(type(obj), name, None)
if _is_column_attr(a) and name not in st.dict:
return False
# non-column tokens (hybrids/descriptors) are satisfied by their deps above
# relationships
for rel_name, col_name in rel_paths:
ra = st.attrs.get(rel_name)
if ra is None or ra.loaded_value in (NO_VALUE, None):
return False
if col_name == "__all__":
continue # relationship object present is enough
try:
t_st = inspect(ra.loaded_value)
except Exception:
return False
t_attr = getattr(type(ra.loaded_value), col_name, None)
if _is_column_attr(t_attr) and col_name not in t_st.dict:
return False
return True
def _class_for(val: Any, classes: Optional[Dict[str, str]]) -> Optional[str]: def _class_for(val: Any, classes: Optional[Dict[str, str]]) -> Optional[str]:
if not classes: if not classes:
return None return None
key = "none" if val is None else str(val).lower() key = "none" if val is None else str(val).lower()
return classes.get(key, classes.get("default")) return classes.get(key, classes.get("default"))
def _format_label_from_values(spec: Any, values: dict) -> Optional[str]:
if not spec:
return None
if isinstance(spec, (list, tuple)):
parts = []
for a in spec:
v = _deep_get(values, str(a))
parts.append("" if v is None else str(v))
return " ".join(p for p in parts if p)
s = str(spec)
if "{" in s and "}" in s:
names = re.findall(r"{\s*([^}:\s]+)", s)
data = {n: _deep_get(values, n) for n in names}
# wrap for safe .format()
data = {k: ("" if v is None else v) for k, v in data.items()}
try:
return s.format(**data)
except Exception:
return None
# simple field name
v = _deep_get(values, s)
return "" if v is None else str(v)
def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]: def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]:
if not spec: if not spec:
return None return None
@ -594,8 +1045,9 @@ def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]
if any(v is None for v in params.values()): if any(v is None for v in params.values()):
return None return None
try: try:
return url_for('crudkit.' + spec["endpoint"], **params) return url_for(spec["endpoint"], **params)
except Exception as e: except Exception as e:
print(f"Cannot create endpoint for {spec['endpoint']}: {str(e)}")
return None return None
def _humanize(field: str) -> str: def _humanize(field: str) -> str:
@ -668,6 +1120,8 @@ def render_field(field, value):
attrs=_sanitize_attrs(field.get('attrs') or {}), attrs=_sanitize_attrs(field.get('attrs') or {}),
label_attrs=_sanitize_attrs(field.get('label_attrs') or {}), label_attrs=_sanitize_attrs(field.get('label_attrs') or {}),
help=field.get('help'), help=field.get('help'),
value_label=field.get('value_label'),
link_href=field.get("link_href"),
) )
@ -762,7 +1216,7 @@ def render_form(
base = name[:-3] base = name[:-3]
rel_prop = mapper.relationships.get(base) rel_prop = mapper.relationships.get(base)
if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE": if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE":
values_map[name] = _coerce_fk_value(values, instance, base) values_map[name] = _coerce_fk_value(values, instance, base, rel_prop) # add rel_prop
else: else:
# Auto-generate path (your original behavior) # Auto-generate path (your original behavior)
@ -795,7 +1249,7 @@ def render_form(
fk_fields.add(f"{base}_id") fk_fields.add(f"{base}_id")
# NEW: set the current selection for this dropdown # NEW: set the current selection for this dropdown
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base) values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base, prop)
# Then plain columns # Then plain columns
for col in model_cls.__table__.columns: for col in model_cls.__table__.columns:
@ -815,15 +1269,37 @@ def render_form(
field["wrap"] = _sanitize_attrs(field["wrap"]) field["wrap"] = _sanitize_attrs(field["wrap"])
fields.append(field) fields.append(field)
if submit_attrs: if submit_attrs:
submit_attrs = _sanitize_attrs(submit_attrs) submit_attrs = _sanitize_attrs(submit_attrs)
common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session} common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session}
for f in fields: for f in fields:
if f.get("type") == "template": if f.get("type") == "template":
base = dict(common_ctx) base = dict(common_ctx)
base.update(f.get("template_ctx") or {}) base.update(f.get("template_ctx") or {})
f["template_ctx"] = base f["template_ctx"] = base
for f in fields:
# existing FK label resolution
vl = _value_label_for_field(f, mapper, values_map, instance, session)
if vl is not None:
f["value_label"] = vl
# NEW: if not a relationship but a label_spec is provided, format from values
elif f.get("label_spec"):
base, rel_prop = _rel_for_id_name(mapper, f["name"])
if not rel_prop: # scalar field
vl2 = _format_label_from_values(f["label_spec"], values_map)
if vl2 is not None:
f["value_label"] = vl2
link_spec = f.get("link")
if link_spec:
try:
href = _build_href(link_spec, values_map, instance)
except Exception:
href = None
if href:
f["link_href"] = href
# Build rows (supports nested layout with parents) # Build rows (supports nested layout with parents)
rows_map = _normalize_rows_layout(layout) rows_map = _normalize_rows_layout(layout)
@ -835,5 +1311,6 @@ def render_form(
values=values_map, values=values_map,
render_field=render_field, render_field=render_field,
submit_attrs=submit_attrs, submit_attrs=submit_attrs,
submit_label=submit_label submit_label=submit_label,
model_name=model_cls.__name__
) )

View file

@ -4,7 +4,13 @@
{% if label_attrs %}{% for k,v in label_attrs.items() %} {% if label_attrs %}{% for k,v in label_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}> {% endfor %}{% endif %}>
{{ field_label }} {% if link_href %}
<a href="{{ link_href }}">
{% endif %}
{{ field_label }}
{% if link_href %}
</a>
{% endif %}
</label> </label>
{% endif %} {% endif %}
@ -30,7 +36,7 @@
<textarea name="{{ field_name }}" id="{{ field_name }}" <textarea name="{{ field_name }}" id="{{ field_name }}"
{% if attrs %}{% for k,v in attrs.items() %} {% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value }}</textarea> {% endfor %}{% endif %}>{{ value if value else "" }}</textarea>
{% elif field_type == 'checkbox' %} {% elif field_type == 'checkbox' %}
<input type="checkbox" name="{{ field_name }}" id="{{ field_name }}" value="1" <input type="checkbox" name="{{ field_name }}" id="{{ field_name }}" value="1"
@ -40,15 +46,33 @@
{% endfor %}{% endif %}> {% endfor %}{% endif %}>
{% elif field_type == 'hidden' %} {% elif field_type == 'hidden' %}
<input type="hidden" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}"> <input type="hidden" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}">
{% elif field_type == 'display' %} {% elif field_type == 'display' %}
<div {% if attrs %}{% for k,v in attrs.items() %} <div {% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value }}</div> {% endfor %}{% endif %}>{{ value_label if value_label else (value if value else "") }}</div>
{% elif field_type == "date" %}
<input type="date" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% elif field_type == "time" %}
<input type="time" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% elif field_type == "datetime" %}
<input type="datetime-local" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% else %} {% else %}
<input type="text" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}" <input type="text" name="{{ field_name }}" id="{{ field_name }}" value="{{ value if value else "" }}"
{% if attrs %}{% for k,v in attrs.items() %} {% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}> {% endfor %}{% endif %}>

View file

@ -1,6 +1,5 @@
<form method="POST"> <form method="POST" id="{{ model_name|lower }}_form">
{% macro render_row(row) %} {% macro render_row(row) %}
<!-- {{ row.name }} -->
{% if row.fields or row.children or row.legend %} {% if row.fields or row.children or row.legend %}
{% if row.legend %}<legend>{{ row.legend }}</legend>{% endif %} {% if row.legend %}<legend>{{ row.legend }}</legend>{% endif %}
<fieldset <fieldset
@ -36,5 +35,5 @@
{% if submit_attrs %}{% for k,v in submit_attrs.items() %} {% if submit_attrs %}{% for k,v in submit_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %} {% endfor %}{% endif %}
>{{ submit_label if label else 'Save' }}</button> >{{ submit_label if submit_label else 'Save' }}</button>
</form> </form>