diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py index 3b63cfd..a2a80c2 100644 --- a/crudkit/api/_cursor.py +++ b/crudkit/api/_cursor.py @@ -1,21 +1,135 @@ -import base64, json -from typing import Any +# crudkit/api/_cursor.py -def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: +from __future__ import annotations + +import base64 +import dataclasses +import datetime as _dt +import decimal as _dec +import hmac +import json +import typing as _t +from hashlib import sha256 + +Any = _t.Any + +@dataclasses.dataclass(frozen=True) +class Cursor: + values: list[Any] + desc_flags: list[bool] + backward: bool + version: int = 1 + + +def _json_default(o: Any) -> Any: + # Keep it boring and predictable. + if isinstance(o, (_dt.datetime, _dt.date, _dt.time)): + # ISO is good enough; assume UTC-aware datetimes are already normalized upstream. + return o.isoformat() + if isinstance(o, _dec.Decimal): + return str(o) + # UUIDs, Enums, etc. + if hasattr(o, "__str__"): + return str(o) + raise TypeError(f"Unsupported cursor value type: {type(o)!r}") + + +def _b64url_nopad_encode(b: bytes) -> str: + return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") + + +def _b64url_nopad_decode(s: str) -> bytes: + # Restore padding for strict decode + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode((s + pad).encode("ascii")) + + +def encode_cursor( + values: list[Any] | None, + desc_flags: list[bool], + backward: bool, + *, + secret: bytes | None = None, +) -> str | None: + """ + Create an opaque, optionally signed cursor token. + + - values: keyset values for the last/first row + - desc_flags: per-key descending flags (same length/order as sort keys) + - backward: whether the window direction is backward + - secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig') + """ if not values: return None - payload = {"v": values, "d": desc_flags, "b": backward} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() -def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: + payload = { + "ver": 1, + "v": values, + "d": desc_flags, + "b": bool(backward), + } + + body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8") + token = _b64url_nopad_encode(body) + + if secret: + sig = hmac.new(secret, body, sha256).digest() + token = f"{token}.{_b64url_nopad_encode(sig)}" + + return token + + +def decode_cursor( + token: str | None, + *, + secret: bytes | None = None, +) -> tuple[list[Any] | None, list[bool] | None, bool]: + """ + Parse a cursor token. Returns (values, desc_flags, backward). + + - Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case. + - If secret is provided, verifies HMAC and rejects tampered tokens. + - On any parse failure, returns (None, None, False). + """ if not token: - return None, False + return None, None, False + try: - obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) + # Split payload.sig if signed + if "." in token: + body_b64, sig_b64 = token.split(".", 1) + body = _b64url_nopad_decode(body_b64) + if secret is None: + # Caller didn’t ask for verification; still parse but don’t trust. + pass + else: + expected = hmac.new(secret, body, sha256).digest() + actual = _b64url_nopad_decode(sig_b64) + if not hmac.compare_digest(expected, actual): + return None, None, False + else: + body = _b64url_nopad_decode(token) + + obj = json.loads(body.decode("utf-8")) + + # Versioning. If we ever change fields, branch here. + ver = int(obj.get("ver", 0)) + if ver not in (0, 1): + return None, None, False + vals = obj.get("v") backward = bool(obj.get("b", False)) - if isinstance(vals, list): - return vals, backward + + # desc_flags may be absent in legacy payloads (ver 0) + desc = obj.get("d") + if not isinstance(vals, list): + return None, None, False + if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)): + # Ignore weird 'd' types rather than crashing + desc = None + + return vals, desc, backward + except Exception: - pass - return None, False + # Be tolerant on decode: treat as no-cursor. + return None, None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index bf505b2..3b061e2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,23 +1,40 @@ -from flask import Blueprint, jsonify, request +from __future__ import annotations + +from flask import Blueprint, jsonify, request, abort +from urllib.parse import urlencode from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy + +def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: + return _is_truthy(d.get(key, "1" if default else "0")) + + +def _safe_int(value: str | None, default: int) -> int: + try: + return int(value) if value is not None else default + except Exception: + return default + + +def _link_with_params(base_url: str, **params) -> str: + # Filter out None, encode safely + q = {k: v for k, v in params.items() if v is not None} + return f"{base_url}?{urlencode(q)}" + + def generate_crud_blueprint(model, service): bp = Blueprint(model.__name__.lower(), __name__) - @bp.get('/') + @bp.get("/") def list_items(): + # Work from a copy so we don't mutate request.args args = request.args.to_dict(flat=True) - # legacy detection legacy_offset = "offset" in args or "page" in args - # sane limit default - try: - limit = int(args.get("limit", 50)) - except Exception: - limit = 50 + limit = _safe_int(args.get("limit"), 50) args["limit"] = limit if legacy_offset: @@ -25,17 +42,23 @@ def generate_crud_blueprint(model, service): items = service.list(args) return jsonify([obj.as_dict() for obj in items]) - # New behavior: keyset seek with cursors - key, backward = decode_cursor(args.get("cursor")) + # New behavior: keyset pagination with cursors + cursor_token = args.get("cursor") + key, desc_from_cursor, backward = decode_cursor(cursor_token) window = service.seek_window( args, key=key, backward=backward, - include_total=_is_truthy(args.get("include_total", "1")), + include_total=_bool_param(args, "include_total", True), ) - desc_flags = list(window.order.desc) + # Prefer the order actually used by the window; fall back to desc_from_cursor if needed. + try: + desc_flags = list(window.order.desc) + except Exception: + desc_flags = desc_from_cursor or [] + body = { "items": [obj.as_dict() for obj in window.items], "limit": window.limit, @@ -45,46 +68,60 @@ def generate_crud_blueprint(model, service): } resp = jsonify(body) - # Optional Link header - links = [] + + # Preserve user’s other query params like include_total, filters, sorts, etc. + base_url = request.base_url + base_params = {k: v for k, v in args.items() if k not in {"cursor"}} + link_parts = [] if body["next_cursor"]: - links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') + link_parts.append( + f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"' + ) if body["prev_cursor"]: - links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') - if links: - resp.headers["Link"] = ", ".join(links) + link_parts.append( + f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"' + ) + if link_parts: + resp.headers["Link"] = ", ".join(link_parts) return resp - @bp.get('/') + @bp.get("/") def get_item(id): - item = service.get(id, request.args) try: + item = service.get(id, request.args) + if item is None: + abort(404) return jsonify(item.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + # Could be validation, auth, or just you forgetting an index again + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.post('/') + @bp.post("/") def create_item(): - obj = service.create(request.json) + payload = request.get_json(silent=True) or {} try: - return jsonify(obj.as_dict()) + obj = service.create(payload) + return jsonify(obj.as_dict()), 201 except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.patch('/') + @bp.patch("/") def update_item(id): - obj = service.update(id, request.json) + payload = request.get_json(silent=True) or {} try: + obj = service.update(id, payload) return jsonify(obj.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + # 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional. + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.delete('/') + @bp.delete("/") def delete_item(id): - service.delete(id) try: - return jsonify({"status": "success"}), 204 + service.delete(id) + # 204 means "no content" so don't send any. + return ("", 204) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + return jsonify({"status": "error", "error": str(e)}), 400 return bp diff --git a/crudkit/core/service.py b/crudkit/core/service.py index b84ebc6..9eb6d8b 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,92 +1,22 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast +from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.orm.base import NO_VALUE -from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql import operators -from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.elements import ColumnElement from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info +from crudkit.projection import compile_projection -def _expand_requires(model_cls, fields): - out, seen = [], set() - def add(f): - if f not in seen: - seen.add(f); out.append(f) - - for f in fields: - add(f) - parts = f.split(".") - cur_cls = model_cls - prefix = [] - - for p in parts[:-1]: - rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p) - if not rel: - cur_cls = None - break - cur_cls = rel.mapper.class_ - prefix.append(p) - - if cur_cls is None: - continue - - leaf = parts[-1] - deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf) - if not deps: - continue - - pre = ".".join(prefix) - for dep in deps: - add(f"{pre + '.' if pre else ''}{dep}") - return out - -def _is_rel(model_cls, name: str) -> bool: - try: - prop = model_cls.__mapper__.relationships.get(name) - return isinstance(prop, RelationshipProperty) - except Exception: - return False - -def _is_instrumented_column(attr) -> bool: - try: - return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty) - except Exception: - return False - -def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]: - """ - For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship - and only fetch the related PK. This is enough for preselecting