diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py index a2a80c2..3b63cfd 100644 --- a/crudkit/api/_cursor.py +++ b/crudkit/api/_cursor.py @@ -1,135 +1,21 @@ -# crudkit/api/_cursor.py +import base64, json +from typing import Any -from __future__ import annotations - -import base64 -import dataclasses -import datetime as _dt -import decimal as _dec -import hmac -import json -import typing as _t -from hashlib import sha256 - -Any = _t.Any - -@dataclasses.dataclass(frozen=True) -class Cursor: - values: list[Any] - desc_flags: list[bool] - backward: bool - version: int = 1 - - -def _json_default(o: Any) -> Any: - # Keep it boring and predictable. - if isinstance(o, (_dt.datetime, _dt.date, _dt.time)): - # ISO is good enough; assume UTC-aware datetimes are already normalized upstream. - return o.isoformat() - if isinstance(o, _dec.Decimal): - return str(o) - # UUIDs, Enums, etc. - if hasattr(o, "__str__"): - return str(o) - raise TypeError(f"Unsupported cursor value type: {type(o)!r}") - - -def _b64url_nopad_encode(b: bytes) -> str: - return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") - - -def _b64url_nopad_decode(s: str) -> bytes: - # Restore padding for strict decode - pad = "=" * (-len(s) % 4) - return base64.urlsafe_b64decode((s + pad).encode("ascii")) - - -def encode_cursor( - values: list[Any] | None, - desc_flags: list[bool], - backward: bool, - *, - secret: bytes | None = None, -) -> str | None: - """ - Create an opaque, optionally signed cursor token. 
- - - values: keyset values for the last/first row - - desc_flags: per-key descending flags (same length/order as sort keys) - - backward: whether the window direction is backward - - secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig') - """ +def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: if not values: return None + payload = {"v": values, "d": desc_flags, "b": backward} + return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() - payload = { - "ver": 1, - "v": values, - "d": desc_flags, - "b": bool(backward), - } - - body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8") - token = _b64url_nopad_encode(body) - - if secret: - sig = hmac.new(secret, body, sha256).digest() - token = f"{token}.{_b64url_nopad_encode(sig)}" - - return token - - -def decode_cursor( - token: str | None, - *, - secret: bytes | None = None, -) -> tuple[list[Any] | None, list[bool] | None, bool]: - """ - Parse a cursor token. Returns (values, desc_flags, backward). - - - Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case. - - If secret is provided, verifies HMAC and rejects tampered tokens. - - On any parse failure, returns (None, None, False). - """ +def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: if not token: - return None, None, False - + return None, False try: - # Split payload.sig if signed - if "." in token: - body_b64, sig_b64 = token.split(".", 1) - body = _b64url_nopad_decode(body_b64) - if secret is None: - # Caller didn’t ask for verification; still parse but don’t trust. - pass - else: - expected = hmac.new(secret, body, sha256).digest() - actual = _b64url_nopad_decode(sig_b64) - if not hmac.compare_digest(expected, actual): - return None, None, False - else: - body = _b64url_nopad_decode(token) - - obj = json.loads(body.decode("utf-8")) - - # Versioning. 
If we ever change fields, branch here. - ver = int(obj.get("ver", 0)) - if ver not in (0, 1): - return None, None, False - + obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) vals = obj.get("v") backward = bool(obj.get("b", False)) - - # desc_flags may be absent in legacy payloads (ver 0) - desc = obj.get("d") - if not isinstance(vals, list): - return None, None, False - if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)): - # Ignore weird 'd' types rather than crashing - desc = None - - return vals, desc, backward - + if isinstance(vals, list): + return vals, backward except Exception: - # Be tolerant on decode: treat as no-cursor. - return None, None, False + pass + return None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index 39e7c49..bf505b2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,195 +1,90 @@ -# crudkit/api/flask_api.py - -from __future__ import annotations - -from flask import Blueprint, jsonify, request, abort, current_app, url_for -from hashlib import md5 -from urllib.parse import urlencode -from werkzeug.exceptions import HTTPException +from flask import Blueprint, jsonify, request +from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy -MAX_JSON = 1_000_000 +def generate_crud_blueprint(model, service): + bp = Blueprint(model.__name__.lower(), __name__) -def _etag_for(obj) -> str: - v = getattr(obj, "updated_at", None) or obj.id - return md5(str(v).encode()).hexdigest() + @bp.get('/') + def list_items(): + args = request.args.to_dict(flat=True) -def _json_payload() -> dict: - if request.content_length and request.content_length > MAX_JSON: - abort(413) - if not request.is_json: - abort(415) - payload = request.get_json(silent=False) - if not isinstance(payload, dict): - abort(400) - return payload + # legacy detection + legacy_offset = "offset" in args or "page" in args -def 
_args_flat() -> dict[str, str]: - return request.args.to_dict(flat=True) # type: ignore[arg-type] + # sane limit default + try: + limit = int(args.get("limit", 50)) + except Exception: + limit = 50 + args["limit"] = limit -def _json_error(e: Exception, status: int = 400): - if isinstance(e, HTTPException): - status = e.code or status - msg = e.description - else: - msg = str(e) - if current_app.debug: - return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status - return jsonify({"status": "error", "error": msg}), status + if legacy_offset: + # Old behavior: honor limit/offset, same CRUDSpec goodies + items = service.list(args) + return jsonify([obj.as_dict() for obj in items]) -def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: - return _is_truthy(d.get(key, "1" if default else "0")) + # New behavior: keyset seek with cursors + key, backward = decode_cursor(args.get("cursor")) -def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True): - """ - REST: - GET /api// -> list (filters via ?q=..., sort=..., limit=..., cursor=...) 
- GET /api// -> get - POST /api// -> create - PATCH /api// -> update (partial) - DELETE /api//[?hard=1] -> delete + window = service.seek_window( + args, + key=key, + backward=backward, + include_total=_is_truthy(args.get("include_total", "1")), + ) - RPC (legacy): - GET /api//get?id=123 - GET /api//list - GET /api//seek_window - GET /api//page - POST /api//create - PATCH /api//update?id=123 - DELETE /api//delete?id=123[&hard=1] - """ - model_name = model.__name__.lower() - # bikeshed if you want pluralization; this is the least-annoying default - collection = (base_prefix or model_name).lower() - plural = collection if collection.endswith('s') else f"{collection}s" + desc_flags = list(window.order.desc) + body = { + "items": [obj.as_dict() for obj in window.items], + "limit": window.limit, + "next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), + "prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), + "total": window.total, + } - bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}") + resp = jsonify(body) + # Optional Link header + links = [] + if body["next_cursor"]: + links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') + if body["prev_cursor"]: + links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') + if links: + resp.headers["Link"] = ", ".join(links) + return resp - @bp.errorhandler(Exception) - def _handle_any(e: Exception): - return _json_error(e) + @bp.get('/') + def get_item(id): + item = service.get(id, request.args) + try: + return jsonify(item.as_dict()) + except Exception as e: + return jsonify({"status": "error", "error": str(e)}) - @bp.errorhandler(404) - def _not_found(_e): - return jsonify({"status": "error", "error": "not found"}), 404 + @bp.post('/') + def create_item(): + obj = service.create(request.json) + try: + return jsonify(obj.as_dict()) + except Exception as e: + return jsonify({"status": 
"error", "error": str(e)}) - # ---------- REST ---------- - if rest: - @bp.get("/") - def rest_list(): - args = _args_flat() - # support cursor pagination transparently; fall back to limit/offset - try: - items = service.list(args) - return jsonify([o.as_dict() for o in items]) - except Exception as e: - return _json_error(e) + @bp.patch('/') + def update_item(id): + obj = service.update(id, request.json) + try: + return jsonify(obj.as_dict()) + except Exception as e: + return jsonify({"status": "error", "error": str(e)}) - @bp.get("/") - def rest_get(obj_id: int): - item = service.get(obj_id, request.args) - if item is None: - abort(404) - etag = _etag_for(item) - if request.if_none_match and (etag in request.if_none_match): - return "", 304 - resp = jsonify(item.as_dict()) - resp.set_etag(etag) - return resp - - @bp.post("/") - def rest_create(): - payload = _json_payload() - try: - obj = service.create(payload) - resp = jsonify(obj.as_dict()) - resp.status_code = 201 - resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False) - return resp - except Exception as e: - return _json_error(e) - - @bp.patch("/") - def rest_update(obj_id: int): - payload = _json_payload() - try: - obj = service.update(obj_id, payload) - return jsonify(obj.as_dict()) - except Exception as e: - return _json_error(e) - - @bp.delete("/") - def rest_delete(obj_id: int): - hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] - try: - obj = service.delete(obj_id, hard=hard) - if obj is None: - abort(404) - return ("", 204) - except Exception as e: - return _json_error(e) - - # ---------- RPC (your existing routes) ---------- - if rpc: - @bp.get("/get") - def rpc_get(): - print("⚠️ WARNING: Deprecated RPC call used: /get") - id_ = int(request.args.get("id", 0)) - if not id_: - return jsonify({"status": "error", "error": "missing required param: id"}), 400 - try: - item = service.get(id_, request.args) - if item is None: - abort(404) - return 
jsonify(item.as_dict()) - except Exception as e: - return _json_error(e) - - @bp.get("/list") - def rpc_list(): - print("⚠️ WARNING: Deprecated RPC call used: /list") - args = _args_flat() - try: - items = service.list(args) - return jsonify([obj.as_dict() for obj in items]) - except Exception as e: - return _json_error(e) - - @bp.post("/create") - def rpc_create(): - print("⚠️ WARNING: Deprecated RPC call used: /create") - payload = _json_payload() - try: - obj = service.create(payload) - return jsonify(obj.as_dict()), 201 - except Exception as e: - return _json_error(e) - - @bp.patch("/update") - def rpc_update(): - print("⚠️ WARNING: Deprecated RPC call used: /update") - id_ = int(request.args.get("id", 0)) - if not id_: - return jsonify({"status": "error", "error": "missing required param: id"}), 400 - payload = _json_payload() - try: - obj = service.update(id_, payload) - return jsonify(obj.as_dict()) - except Exception as e: - return _json_error(e) - - @bp.delete("/delete") - def rpc_delete(): - print("⚠️ WARNING: Deprecated RPC call used: /delete") - id_ = int(request.args.get("id", 0)) - if not id_: - return jsonify({"status": "error", "error": "missing required param: id"}), 400 - hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] - try: - obj = service.delete(id_, hard=hard) - return ("", 204) if obj is not None else abort(404) - except Exception as e: - return _json_error(e) + @bp.delete('/') + def delete_item(id): + service.delete(id) + try: + return jsonify({"status": "success"}), 204 + except Exception as e: + return jsonify({"status": "error", "error": str(e)}) return bp diff --git a/crudkit/backend.py b/crudkit/backend.py index 2da68f8..3232b71 100644 --- a/crudkit/backend.py +++ b/crudkit/backend.py @@ -74,32 +74,18 @@ def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: per_page = max(1, int(per_page)) offset = (page - 1) * per_page - if backend.requires_order_by_for_offset: - # Avoid private 
attribute if possible: - has_order = bool(getattr(sel, "_order_by_clauses", ())) # fallback for SA < 2.0.30 - try: - has_order = has_order or bool(sel.get_order_by()) - except Exception: - pass - - if not has_order: - if default_order_by is not None: - sel = sel.order_by(default_order_by) - else: - # Try to find a primary key from the FROMs; fall back to a harmless literal. - try: - first_from = sel.get_final_froms()[0] - pk = next(iter(first_from.primary_key.columns)) - sel = sel.order_by(pk) - except Exception: - sel = sel.order_by(text("1")) + if backend.requires_order_by_for_offset and not sel._order_by_clauses: + if default_order_by is None: + sel = sel.order_by(text("1")) + else: + sel = sel.order_by(default_order_by) return sel.limit(per_page).offset(offset) @contextmanager def maybe_identify_insert(session: Session, table, backend: BackendInfo): """ - For MSSQL tables with IDENTITY PK when you need to insert explicit IDs. + For MSSQL tables with IDENTITY PK when you need to insert explicit IDs. No-op elsewhere. """ if not backend.is_mssql: @@ -107,7 +93,7 @@ def maybe_identify_insert(session: Session, table, backend: BackendInfo): return full_name = f"{table.schema}.{table.name}" if table.schema else table.name - session.execute(text(f"SET IDENTITY_INSERT {full_name} ON")) + session.execute(text(f"SET IDENTITY_INSERT {full_name} ON")) try: yield finally: @@ -115,7 +101,7 @@ def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement: """ - Build a safe large IN() filter respecting bind param limits. + Build a safe large IN() filter respecting bind param limits. Returns a disjunction of chunked IN clauses if needed. 
""" vals = list(values) @@ -134,12 +120,3 @@ def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optio for p in parts[1:]: expr = expr | p return expr - -def sql_trim(expr, backend: BackendInfo): - """ - Portable TRIM. SQL Server before compat level 140 lacks TRIM(). - Emit LTRIM(RTRIM(...)) there; use TRIM elsewhere - """ - if backend.is_mssql: - return func.ltrim(func.rtrim(expr)) - return func.trim(expr) diff --git a/crudkit/config.py b/crudkit/config.py index fb87b51..0439a3e 100644 --- a/crudkit/config.py +++ b/crudkit/config.py @@ -187,8 +187,6 @@ class Config: "synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"), } - STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) - @classmethod def engine_kwargs(cls) -> Dict[str, Any]: url = cls.DATABASE_URL @@ -223,18 +221,15 @@ class Config: class DevConfig(Config): DEBUG = True SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1"))) - STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class TestConfig(Config): TESTING = True DATABASE_URL = build_database_url(backend="sqlite", database=":memory:") SQLALCHEMY_ECHO = False - STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class ProdConfig(Config): DEBUG = False SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0"))) - STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0"))) def get_config(name: str | None) -> Type[Config]: """ diff --git a/crudkit/core/__init__.py b/crudkit/core/__init__.py index 86d90b7..e69de29 100644 --- a/crudkit/core/__init__.py +++ b/crudkit/core/__init__.py @@ -1,9 +0,0 @@ -# crudkit/core/__init__.py -from .utils import ( - ISO_DT_FORMATS, - normalize_payload, - deep_diff, - diff_to_patch, - filter_to_columns, - to_jsonable, -) diff --git a/crudkit/core/base.py b/crudkit/core/base.py index d73a04f..46874fe 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -1,358 +1,47 @@ -from functools import lru_cache -from typing import Any, Dict, Iterable, List, 
Optional, Tuple, Set, cast -from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect -from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper -from sqlalchemy.orm.state import InstanceState +from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func +from sqlalchemy.orm import declarative_mixin, declarative_base Base = declarative_base() -@lru_cache(maxsize=512) -def _column_names_for_model(cls: type) -> tuple[str, ...]: - try: - mapper = inspect(cls) - return tuple(prop.key for prop in mapper.column_attrs) - except Exception: - names: list[str] = [] - for c in cls.__mro__: - if hasattr(c, "__table__"): - names.extend(col.name for col in c.__table__.columns) - return tuple(dict.fromkeys(names)) - -def _sa_state(obj: Any) -> Optional[InstanceState[Any]]: - """Safely get SQLAlchemy InstanceState (or None).""" - try: - st = inspect(obj) - return cast(Optional[InstanceState[Any]], st) - except Exception: - return None - -def _sa_mapper(obj: Any) -> Optional[Mapper]: - """Safely get Mapper for a maooed instance (or None).""" - try: - st = inspect(obj) - mapper = getattr(st, "mapper", None) - return cast(Optional[Mapper], mapper) - except Exception: - return None - -def _safe_get_loaded_attr(obj, name): - st = _sa_state(obj) - if st is None: - return None - try: - st_dict = getattr(st, "dict", {}) - if name in st_dict: - return st_dict[name] - - attrs = getattr(st, "attrs", None) - attr = None - if attrs is not None: - try: - attr = attrs[name] - except Exception: - try: - get = getattr(attrs, "get", None) - if callable(get): - attr = get(name) - except Exception: - attr = None - - if attr is not None: - val = attr.loaded_value - return None if val is NO_VALUE else val - - return None - except Exception: - return None - -def _identity_key(obj) -> Tuple[type, Any]: - try: - st = inspect(obj) - return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj)) - except 
Exception: - return (type(obj), id(obj)) - -def _is_collection_rel(prop: RelationshipProperty) -> bool: - try: - return prop.uselist is True - except Exception: - return False - -def _serialize_simple_obj(obj) -> Dict[str, Any]: - """Columns only (no relationships).""" - out: Dict[str, Any] = {} - for name in _column_names_for_model(type(obj)): - try: - out[name] = getattr(obj, name) - except Exception: - out[name] = None - return out - -def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any: - """ - Serialize relationship 'name' already loaded on obj. - - If in 'embed' (or depth > 0 for depth-based walk), recurse. - - Else, return None (don’t lazy-load). - """ - val = _safe_get_loaded_attr(obj, name) - if val is None: - return None - - # Decide whether to recurse into this relationship - should_recurse = (depth > 0) or (name in embed) - - if isinstance(val, list): - if not should_recurse: - # Emit a light list of child primary data (id + a couple columns) without recursion. - return [_serialize_simple_obj(child) for child in val] - out = [] - for child in val: - ik = _identity_key(child) - if ik in seen: # cycle guard - out.append({"id": getattr(child, "id", None)}) - continue - seen.add(ik) - out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)) - return out - - # Scalar relationship - child = val - if not should_recurse: - return _serialize_simple_obj(child) - ik = _identity_key(child) - if ik in seen: - return {"id": getattr(child, "id", None)} - seen.add(ik) - return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen) - -def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]: - """ - Split requested fields into: - - scalars: ["label", "name"] - - collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]} - Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path"). 
- Bare tokens ("foo") land in scalars. - """ - scalars: List[str] = [] - groups: Dict[str, List[str]] = {} - for raw in fields: - f = str(raw).strip() - if not f: - continue - # bare token -> scalar - if "." not in f: - scalars.append(f) - continue - # dotted token -> group under root - root, tail = f.split(".", 1) - if not root or not tail: - continue - groups.setdefault(root, []).append(tail) - return scalars, groups - -def _deep_get_loaded(obj: Any, dotted: str) -> Any: - """ - Deep get with no lazy loads: - - For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr). - - For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr() - to allow computed properties/hybrids that rely on already-loaded columns. - """ - parts = dotted.split(".") - if not parts: - return None - - cur = obj - # Traverse up to the parent of the last token safely - for part in parts[:-1]: - if cur is None: - return None - cur = _safe_get_loaded_attr(cur, part) - if cur is None: - return None - - last = parts[-1] - # Try safe fetch on the last hop first - val = _safe_get_loaded_attr(cur, last) - if val is not None: - return val - # Fall back to getattr for computed/hybrid attributes on an already-loaded object - try: - return getattr(cur, last, None) - except Exception: - return None - -def _serialize_leaf(obj: Any) -> Any: - """ - Lead serialization for values we put into as_dict(): - - If object has as_dict(), call as_dict() with no args (caller controls field shapes). - - Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config). 
- """ - if obj is None: - return None - ad = getattr(obj, "as_dict", None) - if callable(ad): - try: - return ad(None) - except Exception: - return str(obj) - return obj - -def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]: - """ - Turn a collection of ORM objects into list[dict] with exactly requested_tails, - where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load. - """ - out: List[Dict[str, Any]] = [] - # Deduplicate while preserving order - uniq_tails = list(dict.fromkeys(requested_tails)) - for child in (items or []): - row: Dict[str, Any] = {} - for tail in uniq_tails: - row[tail] = _deep_get_loaded(child, tail) - # ensure id present if exists and not already requested - try: - if "id" not in row and hasattr(child, "id"): - row["id"] = getattr(child, "id") - except Exception: - pass - out.append(row) - return out - @declarative_mixin class CRUDMixin: id = Column(Integer, primary_key=True) created_at = Column(DateTime, default=func.now(), nullable=False) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) - def as_tree( - self, - *, - embed_depth: int = 0, - embed: Iterable[str] | None = None, - _seen: Set[Tuple[type, Any]] | None = None, - ) -> Dict[str, Any]: - """ - Recursive, NON-LAZY serializer. - - Always includes mapped columns. - - For relationships: only serializes those ALREADY LOADED. - - Recurses either up to embed_depth or for specific names in 'embed'. - - Keeps *_id columns alongside embedded objects. - - Cycle-safe via _seen. 
- """ - seen = _seen or set() - ik = _identity_key(self) - if ik in seen: - return {"id": getattr(self, "id", None)} - seen.add(ik) - - data = _serialize_simple_obj(self) - - # Determine which relationships to consider - try: - mapper = _sa_mapper(self) - embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) - if mapper is None: - return data - st = _sa_state(self) - if st is None: - return data - for name, prop in mapper.relationships.items(): - # Only touch relationships that are already loaded; never lazy-load here. - rel_loaded = getattr(st, "attrs", {}).get(name) - if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE: - continue - - data[name] = _serialize_loaded_rel( - self, name, depth=embed_depth, seen=seen, embed=embed_set - ) - except Exception: - # If inspection fails, we just return columns. - pass - - return data - def as_dict(self, fields: list[str] | None = None): """ Serialize the instance. - - Behavior: - - If 'fields' (possibly dotted) is provided, emit exactly those keys. - * Bare tokens (e.g., "label", "owner") return the current loaded value. - * Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp") - produce a single "updates" key containing a list of dicts with the requested child keys. - * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label". - - Else, if '__crudkit_projection__' is set on the instance, use that. - - Else, fall back to all mapped columns on this class hierarchy. - - Always includes 'id' when present unless explicitly excluded (i.e., fields explicitly provided without id). + - If 'fields' (possibly dotted) is provided, emit exactly those keys. + - Else, if '__crudkit_projection__' is set on the instance, emit those keys. + - Else, fall back to all mapped columns on this class hierarchy. + Always includes 'id' when present unless explicitly excluded. 
""" - req = fields if fields is not None else getattr(self, "__crudkit_projection__", None) + if fields is None: + fields = getattr(self, "__crudkit_projection__", None) - if req: - # Normalize and split into (scalars, groups of dotted by root) - req_list = [p for p in (str(x).strip() for x in req) if p] - scalars, groups = _split_field_tokens(req_list) - - out: Dict[str, Any] = {} - - # Always include id unless the caller explicitly listed fields containing id - if "id" not in req_list and hasattr(self, "id"): - try: - out["id"] = getattr(self, "id") - except Exception: - pass - - # Handle scalar tokens (may be columns, hybrids/properties, or relationships) - for name in scalars: - # Try loaded value first (never lazy-load) - val = _safe_get_loaded_attr(self, name) - - # Final-hop getattr for root scalars (hybrids/@property) so they can compute. - if val is None: - try: - val = getattr(self, name) - except Exception: - val = None - - # If it's a scalar ORM object (relationship), serialize its columns - mapper = _sa_mapper(val) - if mapper is not None: - out[name] = _serialize_simple_obj(val) - continue - - # If it's a collection and no subfields were requested, emit a light list - if isinstance(val, (list, tuple)): - out[name] = [_serialize_leaf(v) for v in val] - else: - out[name] = val - - # Handle dotted groups: root -> [tails] - for root, tails in groups.items(): - root_val = _safe_get_loaded_attr(self, root) - if isinstance(root_val, (list, tuple)): - # one-to-many collection → list of dicts with the requested tails - out[root] = _serialize_collection(root_val, tails) - else: - # many-to-one or scalar dotted; place each full dotted path as key - for tail in tails: - dotted = f"{root}.{tail}" - out[dotted] = _deep_get_loaded(self, dotted) - - # ← This was the placeholder before. We return the dict we just built. 
+ if fields: + out = {} + if "id" not in fields and hasattr(self, "id"): + out["id"] = getattr(self, "id") + for f in fields: + cur = self + for part in f.split("."): + if cur is None: + break + cur = getattr(cur, part, None) + out[f] = cur return out - # Fallback: all mapped columns on this class hierarchy - result: Dict[str, Any] = {} + result = {} for cls in self.__class__.__mro__: if hasattr(cls, "__table__"): for column in cls.__table__.columns: name = column.name - try: - result[name] = getattr(self, name) - except Exception: - result[name] = None + result[name] = getattr(self, name) return result + class Version(Base): __tablename__ = "versions" diff --git a/crudkit/core/service.py b/crudkit/core/service.py index d7fabc4..1220e04 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,37 +1,45 @@ -from __future__ import annotations - -from collections.abc import Iterable -from dataclasses import dataclass -from flask import current_app -from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text, select, literal +from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast +from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent +from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.sql import operators, visitors -from sqlalchemy.sql.elements import UnaryExpression, ColumnElement +from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import UnaryExpression -from crudkit.core import to_jsonable, 
deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version -from crudkit.core.spec import CRUDSpec, CollPred +from crudkit.core.spec import CRUDSpec from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info -from crudkit.projection import compile_projection -import logging -log = logging.getLogger("crudkit.service") -# logging.getLogger("crudkit.service").setLevel(logging.DEBUG) -# Ensure our debug actually prints even if the app/root logger is WARNING+ -# if not log.handlers: -# _h = logging.StreamHandler() -# _h.setLevel(logging.DEBUG) -# _h.setFormatter(logging.Formatter( -# "%(asctime)s %(levelname)s %(name)s: %(message)s" -# )) -# log.addHandler(_h) -# -# log.setLevel(logging.DEBUG) -# log.propagate = False +def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]: + """ + For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship + and only fetch the related PK. This is enough for preselecting + {% endfor %}{% endif %}>{{ value }} {% elif field_type == 'checkbox' %} {% elif field_type == 'hidden' %} - + {% elif field_type == 'display' %}
{{ value_label if value_label else (value if value else "") }}
- -{% elif field_type == "date" %} - - -{% elif field_type == "time" %} - - -{% elif field_type == "datetime" %} - + {% endfor %}{% endif %}>{{ value }} {% else %} - diff --git a/crudkit/ui/templates/form.html b/crudkit/ui/templates/form.html index f57074a..b073fc3 100644 --- a/crudkit/ui/templates/form.html +++ b/crudkit/ui/templates/form.html @@ -1,5 +1,6 @@ -
+ {% macro render_row(row) %} + {% if row.fields or row.children or row.legend %} {% if row.legend %}{{ row.legend }}{% endif %}
{{ submit_label if submit_label else 'Save' }} + >{{ submit_label if submit_label else 'Save' }} 