# crudkit/api/_cursor.py
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import dataclasses
|
||
import datetime as _dt
|
||
import decimal as _dec
|
||
import hmac
|
||
import json
|
||
import typing as _t
|
||
from hashlib import sha256
|
||
|
||
# Short alias so annotations in this module can write ``Any`` instead of ``_t.Any``.
Any = _t.Any
|
||
|
||
@dataclasses.dataclass(frozen=True)
class Cursor:
    """Immutable description of a keyset-pagination cursor.

    Instances are frozen: attributes cannot be reassigned after construction.
    """

    # Keyset values taken from the row the cursor points at.
    values: list[Any]
    # One flag per sort key; True means that key is sorted descending.
    desc_flags: list[bool]
    # True when the window direction is backward.
    backward: bool
    # Payload format version; defaults to 1.
    version: int = 1
|
||
|
||
|
||
def _json_default(o: Any) -> Any:
    """Turn a value ``json.dumps`` cannot serialize natively into a string.

    Intended for use as the ``default=`` hook of ``json.dumps``.

    Args:
        o: The object the JSON encoder could not handle on its own.

    Returns:
        ``o.isoformat()`` for datetime/date/time objects, otherwise ``str(o)``
        for Decimals and for any type that defines its own ``__str__``.

    Raises:
        TypeError: If the type only inherits ``object.__str__``; its string
            form would just be a repr and is not a meaningful cursor value.
    """
    # Keep it boring and predictable.
    if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
        # ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
        return o.isoformat()
    if isinstance(o, _dec.Decimal):
        return str(o)
    # UUIDs, Enums, etc.
    # ``hasattr(o, "__str__")`` is true for every object, which made the
    # TypeError below unreachable. Checking that the type overrides
    # ``object.__str__`` keeps the stringification for types with a real
    # textual form while rejecting everything else.
    if type(o).__str__ is not object.__str__:
        return str(o)
    raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
|
||
|
||
|
||
def _b64url_nopad_encode(b: bytes) -> str:
    """Return the URL-safe base64 text for ``b`` with trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(b).decode("ascii")
    return encoded.rstrip("=")
|
||
|
||
|
||
def _b64url_nopad_decode(s: str) -> bytes:
    """Decode URL-safe base64 text whose trailing ``=`` padding was stripped.

    Raises whatever ``str.encode("ascii")`` or ``base64.urlsafe_b64decode``
    raise for input they cannot handle.
    """
    # Re-append the padding so the length is a multiple of four again.
    missing = -len(s) % 4
    padded = s + "=" * missing
    return base64.urlsafe_b64decode(padded.encode("ascii"))
|
||
|
||
|
||
def encode_cursor(
    values: list[Any] | None,
    desc_flags: list[bool],
    backward: bool,
    *,
    secret: bytes | None = None,
) -> str | None:
    """Build an opaque cursor token, signing it when a secret is supplied.

    Args:
        values: Keyset values for the last/first row of the current window.
        desc_flags: Per-key descending flags (same length/order as sort keys).
        backward: Whether the window direction is backward.
        secret: When truthy, an HMAC-SHA256 signature of the JSON payload is
            appended, giving a JWT-ish ``payload.sig`` token.

    Returns:
        The token string, or ``None`` when ``values`` is ``None`` or empty.
    """
    if not values:
        return None

    # Compact separators and sorted keys keep the serialized form stable.
    raw = json.dumps(
        {"ver": 1, "v": values, "d": desc_flags, "b": bool(backward)},
        default=_json_default,
        separators=(",", ":"),
        sort_keys=True,
    ).encode("utf-8")
    encoded = _b64url_nopad_encode(raw)

    if not secret:
        return encoded

    # The signature covers the raw JSON bytes, not their base64 form.
    signature = hmac.new(secret, raw, sha256).digest()
    return f"{encoded}.{_b64url_nopad_encode(signature)}"
|
||
|
||
|
||
def decode_cursor(
    token: str | None,
    *,
    secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
    """Parse a cursor token into ``(values, desc_flags, backward)``.

    Args:
        token: A ``payload`` or ``payload.sig`` token, or ``None``/empty.
        secret: HMAC key. When truthy, the token must carry a signature that
            matches HMAC-SHA256 of the payload; unsigned or tampered tokens
            are rejected. When ``None`` or empty (``encode_cursor`` does not
            sign with an empty secret either), the payload is parsed without
            verification and must not be trusted.

    Returns:
        ``(values, desc_flags, backward)``. ``desc_flags`` is ``None`` for
        legacy payloads lacking ``'d'`` or when ``'d'`` is not a list of
        bools. On an empty token, a failed signature check, an unknown
        version, or any parse failure the result is ``(None, None, False)``.
    """
    if not token:
        return None, None, False

    try:
        # Split payload.sig if signed; ``dot`` is empty for unsigned tokens.
        body_b64, dot, sig_b64 = token.partition(".")
        body = _b64url_nopad_decode(body_b64)

        if secret:
            # A configured secret makes the signature mandatory. Accepting a
            # token with no signature part would let anyone bypass the check
            # simply by stripping ".sig" from a tampered token.
            if not dot:
                return None, None, False
            expected = hmac.new(secret, body, sha256).digest()
            actual = _b64url_nopad_decode(sig_b64)
            if not hmac.compare_digest(expected, actual):
                return None, None, False
        # Otherwise the caller didn't ask for verification; still parse but
        # don't trust.

        obj = json.loads(body.decode("utf-8"))

        # Versioning. If we ever change fields, branch here.
        ver = int(obj.get("ver", 0))
        if ver not in (0, 1):
            return None, None, False

        vals = obj.get("v")
        backward = bool(obj.get("b", False))

        # desc_flags may be absent in legacy payloads (ver 0)
        desc = obj.get("d")
        if not isinstance(vals, list):
            return None, None, False
        if desc is not None and not (
            isinstance(desc, list) and all(isinstance(x, bool) for x in desc)
        ):
            # Ignore weird 'd' types rather than crashing
            desc = None

        return vals, desc, backward

    except Exception:
        # Be tolerant on decode: treat as no-cursor.
        return None, None, False
|