Optimizations and refactoring.
This commit is contained in:
parent
94837e1b6f
commit
a0ee1caeb7
4 changed files with 273 additions and 85 deletions
|
|
@ -1,21 +1,135 @@
|
|||
import base64, json
|
||||
from typing import Any
|
||||
# crudkit/api/_cursor.py
|
||||
|
||||
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import dataclasses
|
||||
import datetime as _dt
|
||||
import decimal as _dec
|
||||
import hmac
|
||||
import json
|
||||
import typing as _t
|
||||
from hashlib import sha256
|
||||
|
||||
Any = _t.Any
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Cursor:
|
||||
values: list[Any]
|
||||
desc_flags: list[bool]
|
||||
backward: bool
|
||||
version: int = 1
|
||||
|
||||
|
||||
def _json_default(o: Any) -> Any:
|
||||
# Keep it boring and predictable.
|
||||
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
|
||||
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
|
||||
return o.isoformat()
|
||||
if isinstance(o, _dec.Decimal):
|
||||
return str(o)
|
||||
# UUIDs, Enums, etc.
|
||||
if hasattr(o, "__str__"):
|
||||
return str(o)
|
||||
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
|
||||
|
||||
|
||||
def _b64url_nopad_encode(b: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def _b64url_nopad_decode(s: str) -> bytes:
|
||||
# Restore padding for strict decode
|
||||
pad = "=" * (-len(s) % 4)
|
||||
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
|
||||
|
||||
|
||||
def encode_cursor(
|
||||
values: list[Any] | None,
|
||||
desc_flags: list[bool],
|
||||
backward: bool,
|
||||
*,
|
||||
secret: bytes | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Create an opaque, optionally signed cursor token.
|
||||
|
||||
- values: keyset values for the last/first row
|
||||
- desc_flags: per-key descending flags (same length/order as sort keys)
|
||||
- backward: whether the window direction is backward
|
||||
- secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig')
|
||||
"""
|
||||
if not values:
|
||||
return None
|
||||
payload = {"v": values, "d": desc_flags, "b": backward}
|
||||
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
|
||||
|
||||
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
|
||||
payload = {
|
||||
"ver": 1,
|
||||
"v": values,
|
||||
"d": desc_flags,
|
||||
"b": bool(backward),
|
||||
}
|
||||
|
||||
body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||
token = _b64url_nopad_encode(body)
|
||||
|
||||
if secret:
|
||||
sig = hmac.new(secret, body, sha256).digest()
|
||||
token = f"{token}.{_b64url_nopad_encode(sig)}"
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
token: str | None,
|
||||
*,
|
||||
secret: bytes | None = None,
|
||||
) -> tuple[list[Any] | None, list[bool] | None, bool]:
|
||||
"""
|
||||
Parse a cursor token. Returns (values, desc_flags, backward).
|
||||
|
||||
- Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case.
|
||||
- If secret is provided, verifies HMAC and rejects tampered tokens.
|
||||
- On any parse failure, returns (None, None, False).
|
||||
"""
|
||||
if not token:
|
||||
return None, False
|
||||
return None, None, False
|
||||
|
||||
try:
|
||||
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
|
||||
# Split payload.sig if signed
|
||||
if "." in token:
|
||||
body_b64, sig_b64 = token.split(".", 1)
|
||||
body = _b64url_nopad_decode(body_b64)
|
||||
if secret is None:
|
||||
# Caller didn’t ask for verification; still parse but don’t trust.
|
||||
pass
|
||||
else:
|
||||
expected = hmac.new(secret, body, sha256).digest()
|
||||
actual = _b64url_nopad_decode(sig_b64)
|
||||
if not hmac.compare_digest(expected, actual):
|
||||
return None, None, False
|
||||
else:
|
||||
body = _b64url_nopad_decode(token)
|
||||
|
||||
obj = json.loads(body.decode("utf-8"))
|
||||
|
||||
# Versioning. If we ever change fields, branch here.
|
||||
ver = int(obj.get("ver", 0))
|
||||
if ver not in (0, 1):
|
||||
return None, None, False
|
||||
|
||||
vals = obj.get("v")
|
||||
backward = bool(obj.get("b", False))
|
||||
if isinstance(vals, list):
|
||||
return vals, backward
|
||||
|
||||
# desc_flags may be absent in legacy payloads (ver 0)
|
||||
desc = obj.get("d")
|
||||
if not isinstance(vals, list):
|
||||
return None, None, False
|
||||
if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
|
||||
# Ignore weird 'd' types rather than crashing
|
||||
desc = None
|
||||
|
||||
return vals, desc, backward
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
return None, False
|
||||
# Be tolerant on decode: treat as no-cursor.
|
||||
return None, None, False
|
||||
|
|
|
|||
|
|
@ -1,23 +1,40 @@
|
|||
from flask import Blueprint, jsonify, request
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, jsonify, request, abort
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from crudkit.api._cursor import encode_cursor, decode_cursor
|
||||
from crudkit.core.service import _is_truthy
|
||||
|
||||
|
||||
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
|
||||
return _is_truthy(d.get(key, "1" if default else "0"))
|
||||
|
||||
|
||||
def _safe_int(value: str | None, default: int) -> int:
|
||||
try:
|
||||
return int(value) if value is not None else default
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _link_with_params(base_url: str, **params) -> str:
|
||||
# Filter out None, encode safely
|
||||
q = {k: v for k, v in params.items() if v is not None}
|
||||
return f"{base_url}?{urlencode(q)}"
|
||||
|
||||
|
||||
def generate_crud_blueprint(model, service):
|
||||
bp = Blueprint(model.__name__.lower(), __name__)
|
||||
|
||||
@bp.get('/')
|
||||
@bp.get("/")
|
||||
def list_items():
|
||||
# Work from a copy so we don't mutate request.args
|
||||
args = request.args.to_dict(flat=True)
|
||||
|
||||
# legacy detection
|
||||
legacy_offset = "offset" in args or "page" in args
|
||||
|
||||
# sane limit default
|
||||
try:
|
||||
limit = int(args.get("limit", 50))
|
||||
except Exception:
|
||||
limit = 50
|
||||
limit = _safe_int(args.get("limit"), 50)
|
||||
args["limit"] = limit
|
||||
|
||||
if legacy_offset:
|
||||
|
|
@ -25,17 +42,23 @@ def generate_crud_blueprint(model, service):
|
|||
items = service.list(args)
|
||||
return jsonify([obj.as_dict() for obj in items])
|
||||
|
||||
# New behavior: keyset seek with cursors
|
||||
key, backward = decode_cursor(args.get("cursor"))
|
||||
# New behavior: keyset pagination with cursors
|
||||
cursor_token = args.get("cursor")
|
||||
key, desc_from_cursor, backward = decode_cursor(cursor_token)
|
||||
|
||||
window = service.seek_window(
|
||||
args,
|
||||
key=key,
|
||||
backward=backward,
|
||||
include_total=_is_truthy(args.get("include_total", "1")),
|
||||
include_total=_bool_param(args, "include_total", True),
|
||||
)
|
||||
|
||||
desc_flags = list(window.order.desc)
|
||||
# Prefer the order actually used by the window; fall back to desc_from_cursor if needed.
|
||||
try:
|
||||
desc_flags = list(window.order.desc)
|
||||
except Exception:
|
||||
desc_flags = desc_from_cursor or []
|
||||
|
||||
body = {
|
||||
"items": [obj.as_dict() for obj in window.items],
|
||||
"limit": window.limit,
|
||||
|
|
@ -45,46 +68,60 @@ def generate_crud_blueprint(model, service):
|
|||
}
|
||||
|
||||
resp = jsonify(body)
|
||||
# Optional Link header
|
||||
links = []
|
||||
|
||||
# Preserve user’s other query params like include_total, filters, sorts, etc.
|
||||
base_url = request.base_url
|
||||
base_params = {k: v for k, v in args.items() if k not in {"cursor"}}
|
||||
link_parts = []
|
||||
if body["next_cursor"]:
|
||||
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
|
||||
link_parts.append(
|
||||
f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"'
|
||||
)
|
||||
if body["prev_cursor"]:
|
||||
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
|
||||
if links:
|
||||
resp.headers["Link"] = ", ".join(links)
|
||||
link_parts.append(
|
||||
f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"'
|
||||
)
|
||||
if link_parts:
|
||||
resp.headers["Link"] = ", ".join(link_parts)
|
||||
return resp
|
||||
|
||||
@bp.get('/<int:id>')
|
||||
@bp.get("/<int:id>")
|
||||
def get_item(id):
|
||||
item = service.get(id, request.args)
|
||||
try:
|
||||
item = service.get(id, request.args)
|
||||
if item is None:
|
||||
abort(404)
|
||||
return jsonify(item.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)})
|
||||
# Could be validation, auth, or just you forgetting an index again
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
|
||||
@bp.post('/')
|
||||
@bp.post("/")
|
||||
def create_item():
|
||||
obj = service.create(request.json)
|
||||
payload = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return jsonify(obj.as_dict())
|
||||
obj = service.create(payload)
|
||||
return jsonify(obj.as_dict()), 201
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)})
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
|
||||
@bp.patch('/<int:id>')
|
||||
@bp.patch("/<int:id>")
|
||||
def update_item(id):
|
||||
obj = service.update(id, request.json)
|
||||
payload = request.get_json(silent=True) or {}
|
||||
try:
|
||||
obj = service.update(id, payload)
|
||||
return jsonify(obj.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)})
|
||||
# 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional.
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
|
||||
@bp.delete('/<int:id>')
|
||||
@bp.delete("/<int:id>")
|
||||
def delete_item(id):
|
||||
service.delete(id)
|
||||
try:
|
||||
return jsonify({"status": "success"}), 204
|
||||
service.delete(id)
|
||||
# 204 means "no content" so don't send any.
|
||||
return ("", 204)
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)})
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
|
||||
return bp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue