Optimizations and refactoring.

This commit is contained in:
Yaro Kasear 2025-09-24 09:53:25 -05:00
parent 94837e1b6f
commit a0ee1caeb7
4 changed files with 273 additions and 85 deletions

View file

@@ -1,21 +1,135 @@
import base64, json
from typing import Any
# crudkit/api/_cursor.py
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
from __future__ import annotations
import base64
import dataclasses
import datetime as _dt
import decimal as _dec
import hmac
import json
import typing as _t
from hashlib import sha256
Any = _t.Any
@dataclasses.dataclass(frozen=True)
class Cursor:
    """Immutable record describing a pagination cursor.

    Instances are frozen: attributes cannot be reassigned after construction
    (the contained lists themselves are still ordinary mutable lists).
    """

    # Values carried by the cursor (presumably the keyset values of the
    # boundary row -- confirm against callers).
    values: list[Any]
    # One flag per value; True presumably marks a descending sort key.
    desc_flags: list[bool]
    # Direction marker for the page window.
    backward: bool
    # Payload format version; defaults to 1.
    version: int = 1
def _json_default(o: Any) -> Any:
# Keep it boring and predictable.
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
return o.isoformat()
if isinstance(o, _dec.Decimal):
return str(o)
# UUIDs, Enums, etc.
if hasattr(o, "__str__"):
return str(o)
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
def _b64url_nopad_encode(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
def _b64url_nopad_decode(s: str) -> bytes:
# Restore padding for strict decode
pad = "=" * (-len(s) % 4)
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
def encode_cursor(
    values: list[Any] | None,
    desc_flags: list[bool],
    backward: bool,
    *,
    secret: bytes | None = None,
) -> str | None:
    """Create an opaque, optionally signed cursor token.

    Args:
        values: keyset values for the last/first row.
        desc_flags: per-key descending flags (same length/order as sort keys).
        backward: whether the window direction is backward.
        secret: if set (non-empty), an HMAC-SHA256 signature is appended,
            giving a JWT-ish ``payload.sig`` token.

    Returns:
        The token string, or ``None`` when *values* is ``None`` or empty.

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` if the payload
            cannot be serialized.
    """
    if not values:
        return None

    # Only the versioned payload is emitted. The pre-refactor early return
    # (padded base64 of {"v", "d", "b"}) is gone: it bypassed the custom
    # serializer and made the signing branch unreachable.
    payload = {
        "ver": 1,
        "v": values,
        "d": desc_flags,
        "b": bool(backward),
    }
    # Compact separators + sorted keys give a deterministic byte string,
    # so the same payload always yields the same token and signature.
    body = json.dumps(
        payload,
        default=_json_default,
        separators=(",", ":"),
        sort_keys=True,
    ).encode("utf-8")
    token = _b64url_nopad_encode(body)

    if secret:
        # Sign the raw JSON bytes (not the base64 text).
        sig = hmac.new(secret, body, sha256).digest()
        token = f"{token}.{_b64url_nopad_encode(sig)}"
    return token
def decode_cursor(
    token: str | None,
    *,
    secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
    """Parse a cursor token into ``(values, desc_flags, backward)``.

    - Accepts legacy payloads lacking ``'d'``; ``desc_flags`` is ``None``
      in that case (also when ``'d'`` is not a list of bools).
    - If *secret* is set (non-empty, mirroring ``encode_cursor``), the token
      MUST carry a valid HMAC-SHA256 signature. Unsigned or tampered tokens
      are rejected, so stripping the ``.sig`` suffix cannot bypass the check.
    - If *secret* is not set, a signature suffix is ignored and the payload
      is parsed unverified; callers must not trust it.
    - On any parse failure, or an unknown payload version, returns
      ``(None, None, False)``. This function never raises.
    """
    no_cursor = (None, None, False)
    if not token:
        return no_cursor

    try:
        # Signed tokens look like 'payload.sig'; unsigned ones have no dot.
        body_b64, dot, sig_b64 = token.partition(".")
        body = _b64url_nopad_decode(body_b64)

        if secret:
            if not dot:
                # A signature is required whenever a secret is configured.
                return no_cursor
            expected = hmac.new(secret, body, sha256).digest()
            actual = _b64url_nopad_decode(sig_b64)
            # Constant-time comparison to avoid leaking signature bytes.
            if not hmac.compare_digest(expected, actual):
                return no_cursor
        # else: caller didn't ask for verification; parse but don't trust.

        obj = json.loads(body.decode("utf-8"))

        # Versioning. If we ever change fields, branch here.
        # A missing 'ver' is treated as legacy version 0.
        ver = int(obj.get("ver", 0))
        if ver not in (0, 1):
            return no_cursor

        vals = obj.get("v")
        if not isinstance(vals, list):
            return no_cursor
        backward = bool(obj.get("b", False))

        # desc_flags may be absent in legacy payloads (ver 0).
        desc = obj.get("d")
        if desc is not None and not (
            isinstance(desc, list) and all(isinstance(x, bool) for x in desc)
        ):
            # Ignore weird 'd' types rather than crashing.
            desc = None
        return vals, desc, backward
    except Exception:
        # Deliberately broad: the token is untrusted input, and malformed
        # base64/UTF-8/JSON, non-dict payloads or pathological nesting must
        # all degrade to "no cursor" instead of raising.
        return no_cursor