Optimizations and refactoring.

This commit is contained in:
Yaro Kasear 2025-09-24 09:53:25 -05:00
parent 94837e1b6f
commit a0ee1caeb7
4 changed files with 273 additions and 85 deletions

View file

@ -1,21 +1,135 @@
import base64, json # crudkit/api/_cursor.py
from typing import Any
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: from __future__ import annotations
import base64
import dataclasses
import datetime as _dt
import decimal as _dec
import hmac
import json
import typing as _t
from hashlib import sha256
Any = _t.Any
@dataclasses.dataclass(frozen=True)
class Cursor:
    """Frozen record of the fields a pagination cursor carries."""

    # Keyset values for the boundary (first/last) row of a window.
    values: list[Any]
    # Per-key descending flags, in the same order as the sort keys.
    desc_flags: list[bool]
    # Whether the window direction is backward.
    backward: bool
    # Payload format version; defaults to 1.
    version: int = 1
def _json_default(o: Any) -> Any:
    """
    ``default=`` hook for ``json.dumps`` when serializing cursor values.
    - datetime / date / time values become ISO-8601 strings
    - Decimal, UUID and Enum values become ``str(o)``
    - anything else raises TypeError instead of being silently stringified,
      because an arbitrary ``str()`` rendering cannot be compared against a
      column when the cursor comes back.
    """
    # Local imports: only this hook needs them.
    import enum
    import uuid

    if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
        # ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
        return o.isoformat()
    if isinstance(o, (_dec.Decimal, uuid.UUID, enum.Enum)):
        return str(o)
    # Every object has __str__, so a hasattr() guard here would make this
    # line unreachable; the type whitelist above is deliberate.
    raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
def _b64url_nopad_encode(b: bytes) -> str:
    """Return the URL-safe base64 text of ``b`` with trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(b)
    unpadded = encoded.rstrip(b"=")
    return unpadded.decode("ascii")
def _b64url_nopad_decode(s: str) -> bytes:
    """
    Decode URL-safe base64 text whose ``=`` padding may have been stripped.
    - Padding is restored from the length, so padded input also works.
    - Decoding is strict: characters outside the base64 alphabet raise
      ``binascii.Error`` (a ValueError) instead of being silently discarded.
    - Non-ASCII input raises UnicodeEncodeError (also a ValueError).
    """
    # Restore padding for strict decode
    pad = "=" * (-len(s) % 4)
    # urlsafe_b64decode() ignores junk characters; validate=True does not.
    return base64.b64decode((s + pad).encode("ascii"), altchars=b"-_", validate=True)
def encode_cursor(
    values: list[Any] | None,
    desc_flags: list[bool],
    backward: bool,
    *,
    secret: bytes | None = None,
) -> str | None:
    """
    Build an opaque cursor token, signed when a secret is supplied.
    - values: keyset values for the last/first row; empty or None yields None
    - desc_flags: per-key descending flags (same length/order as sort keys)
    - backward: whether the window direction is backward
    - secret: if set, an HMAC-SHA256 signature over the JSON body is appended
      as ``payload.sig``
    """
    if not values:
        return None

    # Compact separators and sorted keys keep the serialized body
    # deterministic, so the signature is stable for equal payloads.
    serialized = json.dumps(
        {"ver": 1, "v": values, "d": desc_flags, "b": bool(backward)},
        default=_json_default,
        separators=(",", ":"),
        sort_keys=True,
    ).encode("utf-8")
    encoded_body = _b64url_nopad_encode(serialized)

    if not secret:
        return encoded_body

    signature = hmac.new(secret, serialized, sha256).digest()
    return f"{encoded_body}.{_b64url_nopad_encode(signature)}"
def decode_cursor(
    token: str | None,
    *,
    secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
    """
    Parse a cursor token. Returns (values, desc_flags, backward).
    - Accepts legacy payloads lacking 'd' and returns desc_flags=None in that case.
    - If secret is provided, the token MUST carry a valid HMAC-SHA256
      signature: unsigned, tampered, or wrongly signed tokens are rejected.
    - If secret is None, a signature (if present) is ignored and the payload
      is parsed without being trusted.
    - On any parse or verification failure, returns (None, None, False).
    """
    if not token:
        return None, None, False

    try:
        # Split 'payload.sig'; sep is empty when the token is unsigned.
        body_b64, sep, sig_b64 = token.partition(".")
        body = _b64url_nopad_decode(body_b64)

        if secret is not None:
            if not sep:
                # Verification was requested but there is nothing to verify.
                # Accepting this would let a client bypass signing simply by
                # dropping the signature.
                return None, None, False
            expected = hmac.new(secret, body, sha256).digest()
            actual = _b64url_nopad_decode(sig_b64)
            if not hmac.compare_digest(expected, actual):
                return None, None, False

        obj = json.loads(body.decode("utf-8"))
        if not isinstance(obj, dict):
            return None, None, False

        # Versioning. If we ever change fields, branch here.
        ver = int(obj.get("ver", 0))
        if ver not in (0, 1):
            return None, None, False

        vals = obj.get("v")
        if not isinstance(vals, list):
            return None, None, False
        backward = bool(obj.get("b", False))

        # desc_flags may be absent in legacy payloads (ver 0)
        desc = obj.get("d")
        if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
            # Ignore weird 'd' types rather than crashing
            desc = None

        return vals, desc, backward
    except Exception:
        # Be tolerant on decode: the token is client-supplied, so any failure
        # is treated as "no cursor" rather than an error.
        return None, None, False

View file

@ -1,23 +1,40 @@
from flask import Blueprint, jsonify, request from __future__ import annotations
from flask import Blueprint, jsonify, request, abort
from urllib.parse import urlencode
from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy from crudkit.core.service import _is_truthy
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
    """Read ``d[key]`` through ``_is_truthy``, using "1"/"0" for ``default`` when the key is absent."""
    fallback = "1" if default else "0"
    raw = d.get(key, fallback)
    return _is_truthy(raw)
def _safe_int(value: str | None, default: int) -> int:
    """Parse ``value`` as an int, returning ``default`` when it is None or not a valid integer."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        # int() signals bad input with exactly these two; anything else is a
        # real bug and should surface.
        return default
def _link_with_params(base_url: str, /, **params) -> str:
    """
    Return ``base_url`` with ``params`` appended as a URL-encoded query string.
    - Parameters whose value is None are dropped.
    - If nothing remains, ``base_url`` is returned without a dangling '?'.
    - ``base_url`` is positional-only so that a query parameter which happens
      to be named 'base_url' can be passed through ``**params`` without
      colliding with it.
    """
    # Filter out None, encode safely
    q = {k: v for k, v in params.items() if v is not None}
    if not q:
        return base_url
    return f"{base_url}?{urlencode(q)}"
def generate_crud_blueprint(model, service): def generate_crud_blueprint(model, service):
bp = Blueprint(model.__name__.lower(), __name__) bp = Blueprint(model.__name__.lower(), __name__)
@bp.get('/') @bp.get("/")
def list_items(): def list_items():
# Work from a copy so we don't mutate request.args
args = request.args.to_dict(flat=True) args = request.args.to_dict(flat=True)
# legacy detection
legacy_offset = "offset" in args or "page" in args legacy_offset = "offset" in args or "page" in args
# sane limit default limit = _safe_int(args.get("limit"), 50)
try:
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit args["limit"] = limit
if legacy_offset: if legacy_offset:
@ -25,17 +42,23 @@ def generate_crud_blueprint(model, service):
items = service.list(args) items = service.list(args)
return jsonify([obj.as_dict() for obj in items]) return jsonify([obj.as_dict() for obj in items])
# New behavior: keyset seek with cursors # New behavior: keyset pagination with cursors
key, backward = decode_cursor(args.get("cursor")) cursor_token = args.get("cursor")
key, desc_from_cursor, backward = decode_cursor(cursor_token)
window = service.seek_window( window = service.seek_window(
args, args,
key=key, key=key,
backward=backward, backward=backward,
include_total=_is_truthy(args.get("include_total", "1")), include_total=_bool_param(args, "include_total", True),
) )
# Prefer the order actually used by the window; fall back to desc_from_cursor if needed.
try:
desc_flags = list(window.order.desc) desc_flags = list(window.order.desc)
except Exception:
desc_flags = desc_from_cursor or []
body = { body = {
"items": [obj.as_dict() for obj in window.items], "items": [obj.as_dict() for obj in window.items],
"limit": window.limit, "limit": window.limit,
@ -45,46 +68,60 @@ def generate_crud_blueprint(model, service):
} }
resp = jsonify(body) resp = jsonify(body)
# Optional Link header
links = [] # Preserve users other query params like include_total, filters, sorts, etc.
base_url = request.base_url
base_params = {k: v for k, v in args.items() if k not in {"cursor"}}
link_parts = []
if body["next_cursor"]: if body["next_cursor"]:
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') link_parts.append(
f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"'
)
if body["prev_cursor"]: if body["prev_cursor"]:
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') link_parts.append(
if links: f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"'
resp.headers["Link"] = ", ".join(links) )
if link_parts:
resp.headers["Link"] = ", ".join(link_parts)
return resp return resp
@bp.get("/<int:id>")
def get_item(id):
    """
    Return one item as JSON.
    - 404 when the service returns None for the id
    - 400 with the error text when the lookup or serialization raises
    """
    try:
        item = service.get(id, request.args)
        resp = jsonify(item.as_dict()) if item is not None else None
    except Exception as e:
        # Could be validation, auth, or a query failure.
        return jsonify({"status": "error", "error": str(e)}), 400
    # abort() works by raising an HTTPException, which is an Exception
    # subclass; calling it inside the try above would turn the 404 into a 400.
    if resp is None:
        abort(404)
    return resp
@bp.post("/")
def create_item():
    """Create an item from the JSON body; 201 with the new item, or 400 with the error text."""
    # silent=True yields None for a missing/invalid JSON body; fall back to {}.
    data = request.get_json(silent=True) or {}
    try:
        created = service.create(data)
        return jsonify(created.as_dict()), 201
    except Exception as exc:
        failure = {"status": "error", "error": str(exc)}
        return jsonify(failure), 400
@bp.patch("/<int:id>")
def update_item(id):
    """Apply the JSON body to item ``id``; returns the updated item, or 400 with the error text."""
    # silent=True yields None for a missing/invalid JSON body; fall back to {}.
    data = request.get_json(silent=True) or {}
    try:
        updated = service.update(id, data)
        return jsonify(updated.as_dict())
    except Exception as exc:
        # Every service failure is reported as 400 here; distinguishing
        # "not found" would need the service to raise a dedicated exception.
        failure = {"status": "error", "error": str(exc)}
        return jsonify(failure), 400
@bp.delete("/<int:id>")
def delete_item(id):
    """Delete item ``id``; empty 204 on success, or 400 with the error text."""
    try:
        service.delete(id)
    except Exception as exc:
        failure = {"status": "error", "error": str(exc)}
        return jsonify(failure), 400
    # 204 means "no content" so don't send any.
    return "", 204
return bp return bp

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
@ -10,6 +12,9 @@ from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection from crudkit.projection import compile_projection
import logging
log = logging.getLogger("crudkit.service")
def _is_rel(model_cls, name: str) -> bool: def _is_rel(model_cls, name: str) -> bool:
try: try:
prop = model_cls.__mapper__.relationships.get(name) prop = model_cls.__mapper__.relationships.get(name)
@ -56,7 +61,7 @@ class CRUDService(Generic[T]):
self.polymorphic = polymorphic self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted') self.supports_soft_delete = hasattr(model, 'is_deleted')
# Cache backend info once. If not provided, derive from session bind. # Cache backend info once. If not provided, derive from session bind.
bind = self.session.get_bind() bind = session_factory().get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self.backend = backend or make_backend_info(eng) self.backend = backend or make_backend_info(eng)
@ -70,6 +75,11 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly return self.session.query(poly), poly
return self.session.query(self.model), self.model return self.session.query(self.model), self.model
def _apply_not_deleted(self, query, root_alias, params) -> Any:
    """
    Add the soft-delete filter to ``query`` when it applies.

    The query is returned untouched if the model has no soft delete or the
    caller passed a truthy ``include_deleted`` param; otherwise rows whose
    ``is_deleted`` is not False are filtered out.
    """
    if not self.supports_soft_delete:
        return query
    if _is_truthy((params or {}).get("include_deleted")):
        return query
    # '== False' is intentional: it builds a SQL expression, not a Python bool.
    return query.filter(getattr(root_alias, "is_deleted") == False)  # noqa: E712
def _extract_order_spec(self, root_alias, given_order_by): def _extract_order_spec(self, root_alias, given_order_by):
""" """
SQLAlchemy 2.x only: SQLAlchemy 2.x only:
@ -85,7 +95,7 @@ class CRUDService(Generic[T]):
for ob in given: for ob in given:
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc() # Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
elem = getattr(ob, "element", None) elem = getattr(ob, "element", None)
col = elem if elem is not None else ob # don't use "or" with SA expressions col = elem if elem is not None else ob
# Detect direction in SA 2.x # Detect direction in SA 2.x
is_desc = False is_desc = False
@ -103,27 +113,30 @@ class CRUDService(Generic[T]):
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
    """
    Build lexicographic predicate for keyset seek.
    For backward traversal, invert comparisons.

    For columns (c1, c2, ...) and key (k1, k2, ...) the result is
        (c1 op k1) OR (c1 = k1 AND c2 op k2) OR ...
    where ``op`` is '<' when the column's descending flag differs from
    ``backward`` and '>' otherwise.

    Returns None when ``key_vals`` is empty.
    Raises ValueError when ``key_vals`` has fewer values than the order has
    columns (the key normally comes from a client-supplied cursor).
    Extra trailing key values are ignored.

    NOTE: NULLs are not normalized; columns are compared as-is.
    """
    if not key_vals:
        return None
    if len(key_vals) < len(spec.cols):
        raise ValueError(
            f"Cursor key has {len(key_vals)} value(s) but the order has {len(spec.cols)} column(s)."
        )

    conds = []
    ties = []  # equality clauses for all columns before the current one
    for col, val, is_desc in zip(spec.cols, key_vals, spec.desc):
        # Descending order and backward traversal each flip the comparison;
        # both together cancel out.
        want_less = bool(is_desc) != bool(backward)
        op = (col < val) if want_less else (col > val)
        conds.append(and_(*ties, op))
        ties.append(col == val)
    return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
    """
    Read the cursor key values for ``obj``: one value per column in ``spec.cols``.

    Each column is resolved to an attribute name via its ``key`` (falling back
    to ``name``) and read from ``obj``.
    Raises ValueError when a column has no usable name or ``obj`` lacks the
    attribute.
    """
    missing = object()  # sentinel: a legitimate attribute value may be None
    out = []
    for c in spec.cols:
        # Only simple mapped columns supported for key pluck
        key = getattr(c, "key", None) or getattr(c, "name", None)
        value = missing if key is None else getattr(obj, key, missing)
        if value is missing:
            raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
        out.append(value)
    return out
@ -142,6 +155,7 @@ class CRUDService(Generic[T]):
- forward/backward seek via `key` and `backward` - forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
""" """
session = self.session
fields = list((params or {}).get("fields", [])) fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
query, root_alias = self.get_query() query, root_alias = self.get_query()
@ -156,8 +170,9 @@ class CRUDService(Generic[T]):
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter # Soft delete filter
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): # if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False) # query = query.filter(getattr(root_alias, "is_deleted") == False)
query = self._apply_not_deleted(query, root_alias, params)
# Parse filters first # Parse filters first
if filters: if filters:
@ -165,6 +180,8 @@ class CRUDService(Generic[T]):
# Includes + joins (so relationship fields like brand.name, location.label work) # Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths(): for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
@ -178,8 +195,12 @@ class CRUDService(Generic[T]):
# Order + limit # Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination() limit, _ = spec.parse_pagination()
if not limit or limit <= 0: if limit is None:
limit = 50 # sensible default effective_limit = 50
elif limit == 0:
effective_limit = None
else:
effective_limit = limit
# Keyset predicate # Keyset predicate
if key: if key:
@ -189,18 +210,19 @@ class CRUDService(Generic[T]):
# Apply ordering. For backward, invert SQL order then reverse in-memory for display. # Apply ordering. For backward, invert SQL order then reverse in-memory for display.
if not backward: if not backward:
clauses = [] clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
for col, is_desc in zip(order_spec.cols, order_spec.desc): query = query.order_by(*clauses)
clauses.append(col.desc() if is_desc else col.asc()) if effective_limit is not None:
query = query.order_by(*clauses).limit(limit) query = query.limit(effective_limit)
items = query.all() items = query.all()
else: else:
inv_clauses = [] inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
for col, is_desc in zip(order_spec.cols, order_spec.desc): query = query.order_by(*inv_clauses)
inv_clauses.append(col.asc() if is_desc else col.desc()) if effective_limit is not None:
query = query.order_by(*inv_clauses).limit(limit) query = query.limit(effective_limit)
items = list(reversed(query.all())) items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested # Tag projection so your renderer knows what fields were requested
if expanded_fields: if expanded_fields:
proj = list(expanded_fields) proj = list(expanded_fields)
@ -231,23 +253,27 @@ class CRUDService(Generic[T]):
# Optional total thats safe under JOINs (COUNT DISTINCT ids) # Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None total = None
if include_total: if include_total:
base = self.session.query(getattr(root_alias, "id")) base = session.query(getattr(root_alias, "id"))
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): base = self._apply_not_deleted(base, root_alias, params)
base = base.filter(getattr(root_alias, "is_deleted") == False)
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
# replicate the same joins used above for _, relationship_attr, target_alias in join_paths: # reuse
for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True) base = base.join(target, rel_attr.of_type(target), isouter=True)
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0 total = session.query(func.count()).select_from(
print(f"!!! QUERY !!! -> {str(query)}") base.order_by(None).distinct().subquery()
).scalar() or 0
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50)
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
from crudkit.core.types import SeekWindow # avoid circulars at module top from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow( return SeekWindow(
items=items, items=items,
limit=limit, limit=window_limit_for_body,
first_key=first_key, first_key=first_key,
last_key=last_key, last_key=last_key,
order=order_spec, order=order_spec,
@ -342,7 +368,9 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
print(f"!!! QUERY !!! -> {str(query)}") if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
@ -422,42 +450,51 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
print(f"!!! QUERY !!! -> {str(query)}") if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return rows return rows
def create(self, data: dict, actor=None) -> T:
    """
    Create and commit a new model instance built from ``data``.

    ``data`` is passed to the model constructor as keyword arguments. The
    change is recorded through ``_log_version`` after the commit. If the
    commit fails, the session is rolled back and the error is re-raised.
    """
    session = self.session
    obj = self.model(**data)
    session.add(obj)
    try:
        session.commit()
    except Exception:
        # Leave the session usable for whoever handles the error.
        session.rollback()
        raise
    self._log_version("create", obj, actor)
    return obj
def update(self, id: int, data: dict, actor=None) -> T:
    """
    Apply ``data`` to the row with the given id and commit.

    Raises ValueError when the row is not found or when ``data`` contains
    keys that are not column names of the model's table; in both cases
    nothing is modified. The change is recorded through ``_log_version``
    after the commit. If the commit fails, the session is rolled back and
    the error is re-raised.
    """
    session = self.session
    obj = self.get(id)
    if not obj:
        raise ValueError(f"{self.model.__name__} with ID {id} not found.")

    valid_fields = {c.name for c in self.model.__table__.columns}
    unknown = set(data) - valid_fields
    if unknown:
        raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")

    # Every key is known to be valid at this point.
    for k, v in data.items():
        setattr(obj, k, v)

    try:
        session.commit()
    except Exception:
        # Leave the session usable for whoever handles the error.
        session.rollback()
        raise
    self._log_version("update", obj, actor)
    return obj
def delete(self, id: int, hard: bool = False, actor=None):
    """
    Delete the row with the given id and commit.

    Returns None when no such row exists, otherwise the affected object.
    The row is removed outright when ``hard`` is true or the model has no
    soft delete; otherwise ``is_deleted`` is set to True. The change is
    recorded through ``_log_version`` after the commit. If the commit fails,
    the session is rolled back and the error is re-raised.
    """
    session = self.session
    obj = session.get(self.model, id)
    if not obj:
        return None

    if hard or not self.supports_soft_delete:
        session.delete(obj)
    else:
        soft = cast(_SoftDeletable, obj)
        soft.is_deleted = True

    try:
        session.commit()
    except Exception:
        # Leave the session usable for whoever handles the error.
        session.rollback()
        raise
    self._log_version("delete", obj, actor)
    return obj
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}): def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
session = self.session
try: try:
data = obj.as_dict() data = obj.as_dict()
except Exception: except Exception:
@ -470,5 +507,5 @@ class CRUDService(Generic[T]):
actor=str(actor) if actor else None, actor=str(actor) if actor else None,
meta=metadata meta=metadata
) )
self.session.add(version) session.add(version)
self.session.commit() session.commit()

View file

@ -91,7 +91,7 @@ def init_listing_routes(app):
] ]
limit = int(request.args.get("limit", 15)) limit = int(request.args.get("limit", 15))
cursor = request.args.get("cursor") cursor = request.args.get("cursor")
key, backward = decode_cursor(cursor) key, _desc, backward = decode_cursor(cursor)
service = crudkit.crud.get_service(cls) service = crudkit.crud.get_service(cls)
window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True) window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)