diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py index 3b63cfd..a2a80c2 100644 --- a/crudkit/api/_cursor.py +++ b/crudkit/api/_cursor.py @@ -1,21 +1,135 @@ -import base64, json -from typing import Any +# crudkit/api/_cursor.py -def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: +from __future__ import annotations + +import base64 +import dataclasses +import datetime as _dt +import decimal as _dec +import hmac +import json +import typing as _t +from hashlib import sha256 + +Any = _t.Any + +@dataclasses.dataclass(frozen=True) +class Cursor: + values: list[Any] + desc_flags: list[bool] + backward: bool + version: int = 1 + + +def _json_default(o: Any) -> Any: + # Keep it boring and predictable. + if isinstance(o, (_dt.datetime, _dt.date, _dt.time)): + # ISO is good enough; assume UTC-aware datetimes are already normalized upstream. + return o.isoformat() + if isinstance(o, _dec.Decimal): + return str(o) + # UUIDs, Enums, etc. + if hasattr(o, "__str__"): + return str(o) + raise TypeError(f"Unsupported cursor value type: {type(o)!r}") + + +def _b64url_nopad_encode(b: bytes) -> str: + return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") + + +def _b64url_nopad_decode(s: str) -> bytes: + # Restore padding for strict decode + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode((s + pad).encode("ascii")) + + +def encode_cursor( + values: list[Any] | None, + desc_flags: list[bool], + backward: bool, + *, + secret: bytes | None = None, +) -> str | None: + """ + Create an opaque, optionally signed cursor token. + + - values: keyset values for the last/first row + - desc_flags: per-key descending flags (same length/order as sort keys) + - backward: whether the window direction is backward + - secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig') + """ if not values: return None - payload = {"v": values, "d": desc_flags, "b": backward} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() -def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: + payload = { + "ver": 1, + "v": values, + "d": desc_flags, + "b": bool(backward), + } + + body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8") + token = _b64url_nopad_encode(body) + + if secret: + sig = hmac.new(secret, body, sha256).digest() + token = f"{token}.{_b64url_nopad_encode(sig)}" + + return token + + +def decode_cursor( + token: str | None, + *, + secret: bytes | None = None, +) -> tuple[list[Any] | None, list[bool] | None, bool]: + """ + Parse a cursor token. Returns (values, desc_flags, backward). + + - Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case. + - If secret is provided, verifies HMAC and rejects tampered tokens. + - On any parse failure, returns (None, None, False). + """ if not token: - return None, False + return None, None, False + try: - obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) + # Split payload.sig if signed + if "." in token: + body_b64, sig_b64 = token.split(".", 1) + body = _b64url_nopad_decode(body_b64) + if secret is None: + # Caller didn’t ask for verification; still parse but don’t trust. + pass + else: + expected = hmac.new(secret, body, sha256).digest() + actual = _b64url_nopad_decode(sig_b64) + if not hmac.compare_digest(expected, actual): + return None, None, False + else: + body = _b64url_nopad_decode(token) + + obj = json.loads(body.decode("utf-8")) + + # Versioning. If we ever change fields, branch here. + ver = int(obj.get("ver", 0)) + if ver not in (0, 1): + return None, None, False + vals = obj.get("v") backward = bool(obj.get("b", False)) - if isinstance(vals, list): - return vals, backward + + # desc_flags may be absent in legacy payloads (ver 0) + desc = obj.get("d") + if not isinstance(vals, list): + return None, None, False + if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)): + # Ignore weird 'd' types rather than crashing + desc = None + + return vals, desc, backward + except Exception: - pass - return None, False + # Be tolerant on decode: treat as no-cursor. + return None, None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index bf505b2..39e7c49 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,90 +1,195 @@ -from flask import Blueprint, jsonify, request +# crudkit/api/flask_api.py + +from __future__ import annotations + +from flask import Blueprint, jsonify, request, abort, current_app, url_for +from hashlib import md5 +from urllib.parse import urlencode +from werkzeug.exceptions import HTTPException -from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy -def generate_crud_blueprint(model, service): - bp = Blueprint(model.__name__.lower(), __name__) +MAX_JSON = 1_000_000 - @bp.get('/') - def list_items(): - args = request.args.to_dict(flat=True) +def _etag_for(obj) -> str: + v = getattr(obj, "updated_at", None) or obj.id + return md5(str(v).encode()).hexdigest() - # legacy detection - legacy_offset = "offset" in args or "page" in args +def _json_payload() -> dict: + if request.content_length and request.content_length > MAX_JSON: + abort(413) + if not request.is_json: + abort(415) + payload = request.get_json(silent=False) + if not isinstance(payload, dict): + abort(400) + return payload - # sane limit default - try: - limit = int(args.get("limit", 50)) - except Exception: - limit = 50 - args["limit"] = limit +def _args_flat() -> dict[str, str]: + return request.args.to_dict(flat=True) # type: ignore[arg-type] - if legacy_offset: - # Old behavior: honor limit/offset, same CRUDSpec goodies - items = service.list(args) - return jsonify([obj.as_dict() for obj in items]) +def _json_error(e: Exception, status: int = 400): + if isinstance(e, HTTPException): + status = e.code or status + msg = e.description + else: + msg = str(e) + if current_app.debug: + return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status + return jsonify({"status": "error", "error": msg}), status - # New behavior: keyset seek with cursors - key, backward = decode_cursor(args.get("cursor")) +def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: + return _is_truthy(d.get(key, "1" if default else "0")) - window = service.seek_window( - args, - key=key, - backward=backward, - include_total=_is_truthy(args.get("include_total", "1")), - ) +def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True): + """ + REST: + GET /api// -> list (filters via ?q=..., sort=..., limit=..., cursor=...) + GET /api// -> get + POST /api// -> create + PATCH /api// -> update (partial) + DELETE /api//[?hard=1] -> delete - desc_flags = list(window.order.desc) - body = { - "items": [obj.as_dict() for obj in window.items], - "limit": window.limit, - "next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), - "prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), - "total": window.total, - } + RPC (legacy): + GET /api//get?id=123 + GET /api//list + GET /api//seek_window + GET /api//page + POST /api//create + PATCH /api//update?id=123 + DELETE /api//delete?id=123[&hard=1] + """ + model_name = model.__name__.lower() + # bikeshed if you want pluralization; this is the least-annoying default + collection = (base_prefix or model_name).lower() + plural = collection if collection.endswith('s') else f"{collection}s" - resp = jsonify(body) - # Optional Link header - links = [] - if body["next_cursor"]: - links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') - if body["prev_cursor"]: - links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') - if links: - resp.headers["Link"] = ", ".join(links) - return resp + bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}") - @bp.get('/') - def get_item(id): - item = service.get(id, request.args) - try: - return jsonify(item.as_dict()) - except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + @bp.errorhandler(Exception) + def _handle_any(e: Exception): + return _json_error(e) - @bp.post('/') - def create_item(): - obj = service.create(request.json) - try: - return jsonify(obj.as_dict()) - except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + @bp.errorhandler(404) + def _not_found(_e): + return jsonify({"status": "error", "error": "not found"}), 404 - @bp.patch('/') - def update_item(id): - obj = service.update(id, request.json) - try: - return jsonify(obj.as_dict()) - except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + # ---------- REST ---------- + if rest: + @bp.get("/") + def rest_list(): + args = _args_flat() + # support cursor pagination transparently; fall back to limit/offset + try: + items = service.list(args) + return jsonify([o.as_dict() for o in items]) + except Exception as e: + return _json_error(e) - @bp.delete('/') - def delete_item(id): - service.delete(id) - try: - return jsonify({"status": "success"}), 204 - except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + @bp.get("/") + def rest_get(obj_id: int): + item = service.get(obj_id, request.args) + if item is None: + abort(404) + etag = _etag_for(item) + if request.if_none_match and (etag in request.if_none_match): + return "", 304 + resp = jsonify(item.as_dict()) + resp.set_etag(etag) + return resp + + @bp.post("/") + def rest_create(): + payload = _json_payload() + try: + obj = service.create(payload) + resp = jsonify(obj.as_dict()) + resp.status_code = 201 + resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False) + return resp + except Exception as e: + return _json_error(e) + + @bp.patch("/") + def rest_update(obj_id: int): + payload = _json_payload() + try: + obj = service.update(obj_id, payload) + return jsonify(obj.as_dict()) + except Exception as e: + return _json_error(e) + + @bp.delete("/") + def rest_delete(obj_id: int): + hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] + try: + obj = service.delete(obj_id, hard=hard) + if obj is None: + abort(404) + return ("", 204) + except Exception as e: + return _json_error(e) + + # ---------- RPC (your existing routes) ---------- + if rpc: + @bp.get("/get") + def rpc_get(): + print("⚠️ WARNING: Deprecated RPC call used: /get") + id_ = int(request.args.get("id", 0)) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 + try: + item = service.get(id_, request.args) + if item is None: + abort(404) + return jsonify(item.as_dict()) + except Exception as e: + return _json_error(e) + + @bp.get("/list") + def rpc_list(): + print("⚠️ WARNING: Deprecated RPC call used: /list") + args = _args_flat() + try: + items = service.list(args) + return jsonify([obj.as_dict() for obj in items]) + except Exception as e: + return _json_error(e) + + @bp.post("/create") + def rpc_create(): + print("⚠️ WARNING: Deprecated RPC call used: /create") + payload = _json_payload() + try: + obj = service.create(payload) + return jsonify(obj.as_dict()), 201 + except Exception as e: + return _json_error(e) + + @bp.patch("/update") + def rpc_update(): + print("⚠️ WARNING: Deprecated RPC call used: /update") + id_ = int(request.args.get("id", 0)) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 + payload = _json_payload() + try: + obj = service.update(id_, payload) + return jsonify(obj.as_dict()) + except Exception as e: + return _json_error(e) + + @bp.delete("/delete") + def rpc_delete(): + print("⚠️ WARNING: Deprecated RPC call used: /delete") + id_ = int(request.args.get("id", 0)) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 + hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] + try: + obj = service.delete(id_, hard=hard) + return ("", 204) if obj is not None else abort(404) + except Exception as e: + return _json_error(e) return bp diff --git a/crudkit/backend.py b/crudkit/backend.py index 3232b71..2da68f8 100644 --- a/crudkit/backend.py +++ b/crudkit/backend.py @@ -74,18 +74,32 @@ def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: per_page = max(1, int(per_page)) offset = (page - 1) * per_page - if backend.requires_order_by_for_offset and not sel._order_by_clauses: - if default_order_by is None: - sel = sel.order_by(text("1")) - else: - sel = sel.order_by(default_order_by) + if backend.requires_order_by_for_offset: + # Avoid private attribute if possible: + has_order = bool(getattr(sel, "_order_by_clauses", ())) # fallback for SA < 2.0.30 + try: + has_order = has_order or bool(sel.get_order_by()) + except Exception: + pass + + if not has_order: + if default_order_by is not None: + sel = sel.order_by(default_order_by) + else: + # Try to find a primary key from the FROMs; fall back to a harmless literal. + try: + first_from = sel.get_final_froms()[0] + pk = next(iter(first_from.primary_key.columns)) + sel = sel.order_by(pk) + except Exception: + sel = sel.order_by(text("1")) return sel.limit(per_page).offset(offset) @contextmanager def maybe_identify_insert(session: Session, table, backend: BackendInfo): """ - For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs. + For MSSQL tables with IDENTITY PK when you need to insert explicit IDs. No-op elsewhere. """ if not backend.is_mssql: @@ -93,7 +107,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo): return full_name = f"{table.schema}.{table.name}" if table.schema else table.name - session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON")) + session.execute(text(f"SET IDENTITY_INSERT {full_name} ON")) try: yield finally: @@ -101,7 +115,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo): def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement: """ - Build a safe large IN() filter respecting bund param limits. + Build a safe large IN() filter respecting bind param limits. Returns a disjunction of chunked IN clauses if needed. """ vals = list(values) @@ -120,3 +134,12 @@ def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optio for p in parts[1:]: expr = expr | p return expr + +def sql_trim(expr, backend: BackendInfo): + """ + Portable TRIM. SQL Server before compat level 140 lacks TRIM(). + Emit LTRIM(RTRIM(...)) there; use TRIM elsewhere + """ + if backend.is_mssql: + return func.ltrim(func.rtrim(expr)) + return func.trim(expr) diff --git a/crudkit/config.py b/crudkit/config.py index 0439a3e..fb87b51 100644 --- a/crudkit/config.py +++ b/crudkit/config.py @@ -187,6 +187,8 @@ class Config: "synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"), } + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) + @classmethod def engine_kwargs(cls) -> Dict[str, Any]: url = cls.DATABASE_URL @@ -221,15 +223,18 @@ class Config: class DevConfig(Config): DEBUG = True SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1"))) + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class TestConfig(Config): TESTING = True DATABASE_URL = build_database_url(backend="sqlite", database=":memory:") SQLALCHEMY_ECHO = False + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class ProdConfig(Config): DEBUG = False SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0"))) + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0"))) def get_config(name: str | None) -> Type[Config]: """ diff --git a/crudkit/core/__init__.py b/crudkit/core/__init__.py index e69de29..86d90b7 100644 --- a/crudkit/core/__init__.py +++ b/crudkit/core/__init__.py @@ -0,0 +1,9 @@ +# crudkit/core/__init__.py +from .utils import ( + ISO_DT_FORMATS, + normalize_payload, + deep_diff, + diff_to_patch, + filter_to_columns, + to_jsonable, +) diff --git a/crudkit/core/base.py b/crudkit/core/base.py index 46874fe..d73a04f 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -1,47 +1,358 @@ -from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func -from sqlalchemy.orm import declarative_mixin, declarative_base +from functools import lru_cache +from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast +from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect +from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper +from sqlalchemy.orm.state import InstanceState Base = declarative_base() +@lru_cache(maxsize=512) +def _column_names_for_model(cls: type) -> tuple[str, ...]: + try: + mapper = inspect(cls) + return tuple(prop.key for prop in mapper.column_attrs) + except Exception: + names: list[str] = [] + for c in cls.__mro__: + if hasattr(c, "__table__"): + names.extend(col.name for col in c.__table__.columns) + return tuple(dict.fromkeys(names)) + +def _sa_state(obj: Any) -> Optional[InstanceState[Any]]: + """Safely get SQLAlchemy InstanceState (or None).""" + try: + st = inspect(obj) + return cast(Optional[InstanceState[Any]], st) + except Exception: + return None + +def _sa_mapper(obj: Any) -> Optional[Mapper]: + """Safely get Mapper for a maooed instance (or None).""" + try: + st = inspect(obj) + mapper = getattr(st, "mapper", None) + return cast(Optional[Mapper], mapper) + except Exception: + return None + +def _safe_get_loaded_attr(obj, name): + st = _sa_state(obj) + if st is None: + return None + try: + st_dict = getattr(st, "dict", {}) + if name in st_dict: + return st_dict[name] + + attrs = getattr(st, "attrs", None) + attr = None + if attrs is not None: + try: + attr = attrs[name] + except Exception: + try: + get = getattr(attrs, "get", None) + if callable(get): + attr = get(name) + except Exception: + attr = None + + if attr is not None: + val = attr.loaded_value + return None if val is NO_VALUE else val + + return None + except Exception: + return None + +def _identity_key(obj) -> Tuple[type, Any]: + try: + st = inspect(obj) + return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj)) + except Exception: + return (type(obj), id(obj)) + +def _is_collection_rel(prop: RelationshipProperty) -> bool: + try: + return prop.uselist is True + except Exception: + return False + +def _serialize_simple_obj(obj) -> Dict[str, Any]: + """Columns only (no relationships).""" + out: Dict[str, Any] = {} + for name in _column_names_for_model(type(obj)): + try: + out[name] = getattr(obj, name) + except Exception: + out[name] = None + return out + +def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any: + """ + Serialize relationship 'name' already loaded on obj. + - If in 'embed' (or depth > 0 for depth-based walk), recurse. + - Else, return None (don’t lazy-load). + """ + val = _safe_get_loaded_attr(obj, name) + if val is None: + return None + + # Decide whether to recurse into this relationship + should_recurse = (depth > 0) or (name in embed) + + if isinstance(val, list): + if not should_recurse: + # Emit a light list of child primary data (id + a couple columns) without recursion. + return [_serialize_simple_obj(child) for child in val] + out = [] + for child in val: + ik = _identity_key(child) + if ik in seen: # cycle guard + out.append({"id": getattr(child, "id", None)}) + continue + seen.add(ik) + out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)) + return out + + # Scalar relationship + child = val + if not should_recurse: + return _serialize_simple_obj(child) + ik = _identity_key(child) + if ik in seen: + return {"id": getattr(child, "id", None)} + seen.add(ik) + return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen) + +def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]: + """ + Split requested fields into: + - scalars: ["label", "name"] + - collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]} + Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path"). + Bare tokens ("foo") land in scalars. + """ + scalars: List[str] = [] + groups: Dict[str, List[str]] = {} + for raw in fields: + f = str(raw).strip() + if not f: + continue + # bare token -> scalar + if "." not in f: + scalars.append(f) + continue + # dotted token -> group under root + root, tail = f.split(".", 1) + if not root or not tail: + continue + groups.setdefault(root, []).append(tail) + return scalars, groups + +def _deep_get_loaded(obj: Any, dotted: str) -> Any: + """ + Deep get with no lazy loads: + - For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr). + - For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr() + to allow computed properties/hybrids that rely on already-loaded columns. + """ + parts = dotted.split(".") + if not parts: + return None + + cur = obj + # Traverse up to the parent of the last token safely + for part in parts[:-1]: + if cur is None: + return None + cur = _safe_get_loaded_attr(cur, part) + if cur is None: + return None + + last = parts[-1] + # Try safe fetch on the last hop first + val = _safe_get_loaded_attr(cur, last) + if val is not None: + return val + # Fall back to getattr for computed/hybrid attributes on an already-loaded object + try: + return getattr(cur, last, None) + except Exception: + return None + +def _serialize_leaf(obj: Any) -> Any: + """ + Lead serialization for values we put into as_dict(): + - If object has as_dict(), call as_dict() with no args (caller controls field shapes). + - Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config). + """ + if obj is None: + return None + ad = getattr(obj, "as_dict", None) + if callable(ad): + try: + return ad(None) + except Exception: + return str(obj) + return obj + +def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]: + """ + Turn a collection of ORM objects into list[dict] with exactly requested_tails, + where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load. + """ + out: List[Dict[str, Any]] = [] + # Deduplicate while preserving order + uniq_tails = list(dict.fromkeys(requested_tails)) + for child in (items or []): + row: Dict[str, Any] = {} + for tail in uniq_tails: + row[tail] = _deep_get_loaded(child, tail) + # ensure id present if exists and not already requested + try: + if "id" not in row and hasattr(child, "id"): + row["id"] = getattr(child, "id") + except Exception: + pass + out.append(row) + return out + @declarative_mixin class CRUDMixin: id = Column(Integer, primary_key=True) created_at = Column(DateTime, default=func.now(), nullable=False) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) + def as_tree( + self, + *, + embed_depth: int = 0, + embed: Iterable[str] | None = None, + _seen: Set[Tuple[type, Any]] | None = None, + ) -> Dict[str, Any]: + """ + Recursive, NON-LAZY serializer. + - Always includes mapped columns. + - For relationships: only serializes those ALREADY LOADED. + - Recurses either up to embed_depth or for specific names in 'embed'. + - Keeps *_id columns alongside embedded objects. + - Cycle-safe via _seen. + """ + seen = _seen or set() + ik = _identity_key(self) + if ik in seen: + return {"id": getattr(self, "id", None)} + seen.add(ik) + + data = _serialize_simple_obj(self) + + # Determine which relationships to consider + try: + mapper = _sa_mapper(self) + embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) + if mapper is None: + return data + st = _sa_state(self) + if st is None: + return data + for name, prop in mapper.relationships.items(): + # Only touch relationships that are already loaded; never lazy-load here. + rel_loaded = getattr(st, "attrs", {}).get(name) + if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE: + continue + + data[name] = _serialize_loaded_rel( + self, name, depth=embed_depth, seen=seen, embed=embed_set + ) + except Exception: + # If inspection fails, we just return columns. + pass + + return data + def as_dict(self, fields: list[str] | None = None): """ Serialize the instance. - - If 'fields' (possibly dotted) is provided, emit exactly those keys. - - Else, if '__crudkit_projection__' is set on the instance, emit those keys. - - Else, fall back to all mapped columns on this class hierarchy. - Always includes 'id' when present unless explicitly excluded. - """ - if fields is None: - fields = getattr(self, "__crudkit_projection__", None) - if fields: - out = {} - if "id" not in fields and hasattr(self, "id"): - out["id"] = getattr(self, "id") - for f in fields: - cur = self - for part in f.split("."): - if cur is None: - break - cur = getattr(cur, part, None) - out[f] = cur + Behavior: + - If 'fields' (possibly dotted) is provided, emit exactly those keys. + * Bare tokens (e.g., "label", "owner") return the current loaded value. + * Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp") + produce a single "updates" key containing a list of dicts with the requested child keys. + * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label". + - Else, if '__crudkit_projection__' is set on the instance, use that. + - Else, fall back to all mapped columns on this class hierarchy. + + Always includes 'id' when present unless explicitly excluded (i.e., fields explicitly provided without id). + """ + req = fields if fields is not None else getattr(self, "__crudkit_projection__", None) + + if req: + # Normalize and split into (scalars, groups of dotted by root) + req_list = [p for p in (str(x).strip() for x in req) if p] + scalars, groups = _split_field_tokens(req_list) + + out: Dict[str, Any] = {} + + # Always include id unless the caller explicitly listed fields containing id + if "id" not in req_list and hasattr(self, "id"): + try: + out["id"] = getattr(self, "id") + except Exception: + pass + + # Handle scalar tokens (may be columns, hybrids/properties, or relationships) + for name in scalars: + # Try loaded value first (never lazy-load) + val = _safe_get_loaded_attr(self, name) + + # Final-hop getattr for root scalars (hybrids/@property) so they can compute. + if val is None: + try: + val = getattr(self, name) + except Exception: + val = None + + # If it's a scalar ORM object (relationship), serialize its columns + mapper = _sa_mapper(val) + if mapper is not None: + out[name] = _serialize_simple_obj(val) + continue + + # If it's a collection and no subfields were requested, emit a light list + if isinstance(val, (list, tuple)): + out[name] = [_serialize_leaf(v) for v in val] + else: + out[name] = val + + # Handle dotted groups: root -> [tails] + for root, tails in groups.items(): + root_val = _safe_get_loaded_attr(self, root) + if isinstance(root_val, (list, tuple)): + # one-to-many collection → list of dicts with the requested tails + out[root] = _serialize_collection(root_val, tails) + else: + # many-to-one or scalar dotted; place each full dotted path as key + for tail in tails: + dotted = f"{root}.{tail}" + out[dotted] = _deep_get_loaded(self, dotted) + + # ← This was the placeholder before. We return the dict we just built. return out - result = {} + # Fallback: all mapped columns on this class hierarchy + result: Dict[str, Any] = {} for cls in self.__class__.__mro__: if hasattr(cls, "__table__"): for column in cls.__table__.columns: name = column.name - result[name] = getattr(self, name) + try: + result[name] = getattr(self, name) + except Exception: + result[name] = None return result - class Version(Base): __tablename__ = "versions" diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 1220e04..d7fabc4 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,45 +1,37 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text -from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper -from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.orm.util import AliasedClass -from sqlalchemy.sql import operators -from sqlalchemy.sql.elements import UnaryExpression +from __future__ import annotations +from collections.abc import Iterable +from dataclasses import dataclass +from flask import current_app +from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast +from sqlalchemy import and_, func, inspect, or_, text, select, literal +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql.elements import UnaryExpression, ColumnElement + +from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version -from crudkit.core.spec import CRUDSpec +from crudkit.core.spec import CRUDSpec, CollPred from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info +from crudkit.projection import compile_projection -def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]: - """ - For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship - and only fetch the related PK. This is enough for preselecting + {% endfor %}{% endif %}>{{ value if value else "" }} {% elif field_type == 'checkbox' %} {% elif field_type == 'hidden' %} - + {% elif field_type == 'display' %}
{{ value }}
+ {% endfor %}{% endif %}>{{ value_label if value_label else (value if value else "") }} + +{% elif field_type == "date" %} + + +{% elif field_type == "time" %} + + +{% elif field_type == "datetime" %} + {% else %} - diff --git a/crudkit/ui/templates/form.html b/crudkit/ui/templates/form.html index b073fc3..f57074a 100644 --- a/crudkit/ui/templates/form.html +++ b/crudkit/ui/templates/form.html @@ -1,6 +1,5 @@ -
+ {% macro render_row(row) %} - {% if row.fields or row.children or row.legend %} {% if row.legend %}{{ row.legend }}{% endif %}
{{ submit_label if label else 'Save' }} + >{{ submit_label if submit_label else 'Save' }}