From a0ee1caeb71fb321924aeb22c8271fe953e5dc94 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Wed, 24 Sep 2025 09:53:25 -0500 Subject: [PATCH] Optimizations and refactoring. --- crudkit/api/_cursor.py | 138 ++++++++++++++++++++++++++++++++---- crudkit/api/flask_api.py | 101 +++++++++++++++++--------- crudkit/core/service.py | 117 +++++++++++++++++++----------- inventory/routes/listing.py | 2 +- 4 files changed, 273 insertions(+), 85 deletions(-) diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py index 3b63cfd..a2a80c2 100644 --- a/crudkit/api/_cursor.py +++ b/crudkit/api/_cursor.py @@ -1,21 +1,135 @@ -import base64, json -from typing import Any +# crudkit/api/_cursor.py -def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: +from __future__ import annotations + +import base64 +import dataclasses +import datetime as _dt +import decimal as _dec +import hmac +import json +import typing as _t +from hashlib import sha256 + +Any = _t.Any + +@dataclasses.dataclass(frozen=True) +class Cursor: + values: list[Any] + desc_flags: list[bool] + backward: bool + version: int = 1 + + +def _json_default(o: Any) -> Any: + # Keep it boring and predictable. + if isinstance(o, (_dt.datetime, _dt.date, _dt.time)): + # ISO is good enough; assume UTC-aware datetimes are already normalized upstream. + return o.isoformat() + if isinstance(o, _dec.Decimal): + return str(o) + # UUIDs, Enums, etc. + if hasattr(o, "__str__"): + return str(o) + raise TypeError(f"Unsupported cursor value type: {type(o)!r}") + + +def _b64url_nopad_encode(b: bytes) -> str: + return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") + + +def _b64url_nopad_decode(s: str) -> bytes: + # Restore padding for strict decode + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode((s + pad).encode("ascii")) + + +def encode_cursor( + values: list[Any] | None, + desc_flags: list[bool], + backward: bool, + *, + secret: bytes | None = None, +) -> str | None: + """ + Create an opaque, optionally signed cursor token. + + - values: keyset values for the last/first row + - desc_flags: per-key descending flags (same length/order as sort keys) + - backward: whether the window direction is backward + - secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig') + """ if not values: return None - payload = {"v": values, "d": desc_flags, "b": backward} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() -def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: + payload = { + "ver": 1, + "v": values, + "d": desc_flags, + "b": bool(backward), + } + + body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8") + token = _b64url_nopad_encode(body) + + if secret: + sig = hmac.new(secret, body, sha256).digest() + token = f"{token}.{_b64url_nopad_encode(sig)}" + + return token + + +def decode_cursor( + token: str | None, + *, + secret: bytes | None = None, +) -> tuple[list[Any] | None, list[bool] | None, bool]: + """ + Parse a cursor token. Returns (values, desc_flags, backward). + + - Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case. + - If secret is provided, verifies HMAC and rejects tampered tokens. + - On any parse failure, returns (None, None, False). + """ if not token: - return None, False + return None, None, False + try: - obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) + # Split payload.sig if signed + if "." in token: + body_b64, sig_b64 = token.split(".", 1) + body = _b64url_nopad_decode(body_b64) + if secret is None: + # Caller didn’t ask for verification; still parse but don’t trust. + pass + else: + expected = hmac.new(secret, body, sha256).digest() + actual = _b64url_nopad_decode(sig_b64) + if not hmac.compare_digest(expected, actual): + return None, None, False + else: + body = _b64url_nopad_decode(token) + + obj = json.loads(body.decode("utf-8")) + + # Versioning. If we ever change fields, branch here. + ver = int(obj.get("ver", 0)) + if ver not in (0, 1): + return None, None, False + vals = obj.get("v") backward = bool(obj.get("b", False)) - if isinstance(vals, list): - return vals, backward + + # desc_flags may be absent in legacy payloads (ver 0) + desc = obj.get("d") + if not isinstance(vals, list): + return None, None, False + if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)): + # Ignore weird 'd' types rather than crashing + desc = None + + return vals, desc, backward + except Exception: - pass - return None, False + # Be tolerant on decode: treat as no-cursor. + return None, None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index bf505b2..3b061e2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,23 +1,40 @@ -from flask import Blueprint, jsonify, request +from __future__ import annotations + +from flask import Blueprint, jsonify, request, abort +from urllib.parse import urlencode from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy + +def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: + return _is_truthy(d.get(key, "1" if default else "0")) + + +def _safe_int(value: str | None, default: int) -> int: + try: + return int(value) if value is not None else default + except Exception: + return default + + +def _link_with_params(base_url: str, **params) -> str: + # Filter out None, encode safely + q = {k: v for k, v in params.items() if v is not None} + return f"{base_url}?{urlencode(q)}" + + def generate_crud_blueprint(model, service): bp = Blueprint(model.__name__.lower(), __name__) - @bp.get('/') + @bp.get("/") def list_items(): + # Work from a copy so we don't mutate request.args args = request.args.to_dict(flat=True) - # legacy detection legacy_offset = "offset" in args or "page" in args - # sane limit default - try: - limit = int(args.get("limit", 50)) - except Exception: - limit = 50 + limit = _safe_int(args.get("limit"), 50) args["limit"] = limit if legacy_offset: @@ -25,17 +42,23 @@ def generate_crud_blueprint(model, service): items = service.list(args) return jsonify([obj.as_dict() for obj in items]) - # New behavior: keyset seek with cursors - key, backward = decode_cursor(args.get("cursor")) + # New behavior: keyset pagination with cursors + cursor_token = args.get("cursor") + key, desc_from_cursor, backward = decode_cursor(cursor_token) window = service.seek_window( args, key=key, backward=backward, - include_total=_is_truthy(args.get("include_total", "1")), + include_total=_bool_param(args, "include_total", True), ) - desc_flags = list(window.order.desc) + # Prefer the order actually used by the window; fall back to desc_from_cursor if needed. + try: + desc_flags = list(window.order.desc) + except Exception: + desc_flags = desc_from_cursor or [] + body = { "items": [obj.as_dict() for obj in window.items], "limit": window.limit, @@ -45,46 +68,60 @@ def generate_crud_blueprint(model, service): } resp = jsonify(body) - # Optional Link header - links = [] + + # Preserve user’s other query params like include_total, filters, sorts, etc. + base_url = request.base_url + base_params = {k: v for k, v in args.items() if k not in {"cursor"}} + link_parts = [] if body["next_cursor"]: - links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') + link_parts.append( + f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"' + ) if body["prev_cursor"]: - links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') - if links: - resp.headers["Link"] = ", ".join(links) + link_parts.append( + f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"' + ) + if link_parts: + resp.headers["Link"] = ", ".join(link_parts) return resp - @bp.get('/') + @bp.get("/") def get_item(id): - item = service.get(id, request.args) try: + item = service.get(id, request.args) + if item is None: + abort(404) return jsonify(item.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + # Could be validation, auth, or just you forgetting an index again + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.post('/') + @bp.post("/") def create_item(): - obj = service.create(request.json) + payload = request.get_json(silent=True) or {} try: - return jsonify(obj.as_dict()) + obj = service.create(payload) + return jsonify(obj.as_dict()), 201 except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.patch('/') + @bp.patch("/") def update_item(id): - obj = service.update(id, request.json) + payload = request.get_json(silent=True) or {} try: + obj = service.update(id, payload) return jsonify(obj.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + # 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional. + return jsonify({"status": "error", "error": str(e)}), 400 - @bp.delete('/') + @bp.delete("/") def delete_item(id): - service.delete(id) try: - return jsonify({"status": "success"}), 204 + service.delete(id) + # 204 means "no content" so don't send any. + return ("", 204) except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + return jsonify({"status": "error", "error": str(e)}), 400 return bp diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 7c40320..86a4542 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection @@ -10,6 +12,9 @@ from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info from crudkit.projection import compile_projection +import logging +log = logging.getLogger("crudkit.service") + def _is_rel(model_cls, name: str) -> bool: try: prop = model_cls.__mapper__.relationships.get(name) @@ -56,7 +61,7 @@ class CRUDService(Generic[T]): self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') # Cache backend info once. If not provided, derive from session bind. - bind = self.session.get_bind() + bind = session_factory().get_bind() eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) self.backend = backend or make_backend_info(eng) @@ -70,6 +75,11 @@ class CRUDService(Generic[T]): return self.session.query(poly), poly return self.session.query(self.model), self.model + def _apply_not_deleted(self, query, root_alias, params) -> Any: + if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): + return query.filter(getattr(root_alias, "is_deleted") == False) + return query + def _extract_order_spec(self, root_alias, given_order_by): """ SQLAlchemy 2.x only: @@ -85,7 +95,7 @@ class CRUDService(Generic[T]): for ob in given: # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() elem = getattr(ob, "element", None) - col = elem if elem is not None else ob # don't use "or" with SA expressions + col = elem if elem is not None else ob # Detect direction in SA 2.x is_desc = False @@ -103,27 +113,30 @@ class CRUDService(Generic[T]): return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): - """ - Build lexicographic predicate for keyset seek. - For backward traversal, import comparisons. - """ if not key_vals: return None + conds = [] for i, col in enumerate(spec.cols): + # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel: + # sent_col = func.coalesce(col, literal("-∞")) + sent_col = col ties = [spec.cols[j] == key_vals[j] for j in range(i)] is_desc = spec.desc[i] if not backward: - op = col < key_vals[i] if is_desc else col > key_vals[i] + op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) else: - op = col > key_vals[i] if is_desc else col < key_vals[i] + op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) conds.append(and_(*ties, op)) return or_(*conds) def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: out = [] for c in spec.cols: + # Only simple mapped columns supported for key pluck key = getattr(c, "key", None) or getattr(c, "name", None) + if key is None or not hasattr(obj, key): + raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.") out.append(getattr(obj, key)) return out @@ -142,6 +155,7 @@ class CRUDService(Generic[T]): - forward/backward seek via `key` and `backward` Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. """ + session = self.session fields = list((params or {}).get("fields", [])) expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) query, root_alias = self.get_query() @@ -156,8 +170,9 @@ class CRUDService(Generic[T]): root_fields, rel_field_names, root_field_names = spec.parse_fields() # Soft delete filter - if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): - query = query.filter(getattr(root_alias, "is_deleted") == False) + # if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): + # query = query.filter(getattr(root_alias, "is_deleted") == False) + query = self._apply_not_deleted(query, root_alias, params) # Parse filters first if filters: @@ -165,6 +180,8 @@ class CRUDService(Generic[T]): # Includes + joins (so relationship fields like brand.name, location.label work) spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) + for _, relationship_attr, target_alias in spec.get_join_paths(): rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) @@ -178,8 +195,12 @@ class CRUDService(Generic[T]): # Order + limit order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper limit, _ = spec.parse_pagination() - if not limit or limit <= 0: - limit = 50 # sensible default + if limit is None: + effective_limit = 50 + elif limit == 0: + effective_limit = None + else: + effective_limit = limit # Keyset predicate if key: @@ -189,18 +210,19 @@ class CRUDService(Generic[T]): # Apply ordering. For backward, invert SQL order then reverse in-memory for display. if not backward: - clauses = [] - for col, is_desc in zip(order_spec.cols, order_spec.desc): - clauses.append(col.desc() if is_desc else col.asc()) - query = query.order_by(*clauses).limit(limit) + clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] + query = query.order_by(*clauses) + if effective_limit is not None: + query = query.limit(effective_limit) items = query.all() else: - inv_clauses = [] - for col, is_desc in zip(order_spec.cols, order_spec.desc): - inv_clauses.append(col.asc() if is_desc else col.desc()) - query = query.order_by(*inv_clauses).limit(limit) + inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] + query = query.order_by(*inv_clauses) + if effective_limit is not None: + query = query.limit(effective_limit) items = list(reversed(query.all())) + # Tag projection so your renderer knows what fields were requested if expanded_fields: proj = list(expanded_fields) @@ -231,23 +253,27 @@ class CRUDService(Generic[T]): # Optional total that’s safe under JOINs (COUNT DISTINCT ids) total = None if include_total: - base = self.session.query(getattr(root_alias, "id")) - if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): - base = base.filter(getattr(root_alias, "is_deleted") == False) + base = session.query(getattr(root_alias, "id")) + base = self._apply_not_deleted(base, root_alias, params) if filters: base = base.filter(*filters) - # replicate the same joins used above - for _, relationship_attr, target_alias in spec.get_join_paths(): + for _, relationship_attr, target_alias in join_paths: # reuse rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias) base = base.join(target, rel_attr.of_type(target), isouter=True) - total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0 - print(f"!!! QUERY !!! -> {str(query)}") + total = session.query(func.count()).select_from( + base.order_by(None).distinct().subquery() + ).scalar() or 0 + + window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50) + + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) from crudkit.core.types import SeekWindow # avoid circulars at module top return SeekWindow( items=items, - limit=limit, + limit=window_limit_for_body, first_key=first_key, last_key=last_key, order=order_spec, @@ -342,7 +368,9 @@ class CRUDService(Generic[T]): except Exception: pass - print(f"!!! QUERY !!! -> {str(query)}") + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) + return obj or None def list(self, params=None) -> list[T]: @@ -422,42 +450,51 @@ class CRUDService(Generic[T]): except Exception: pass - print(f"!!! QUERY !!! -> {str(query)}") + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) + return rows def create(self, data: dict, actor=None) -> T: + session = self.session obj = self.model(**data) - self.session.add(obj) - self.session.commit() + session.add(obj) + session.commit() self._log_version("create", obj, actor) return obj def update(self, id: int, data: dict, actor=None) -> T: + session = self.session obj = self.get(id) if not obj: raise ValueError(f"{self.model.__name__} with ID {id} not found.") valid_fields = {c.name for c in self.model.__table__.columns} + unknown = set(data) - valid_fields + if unknown: + raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}") for k, v in data.items(): if k in valid_fields: setattr(obj, k, v) - self.session.commit() + session.commit() self._log_version("update", obj, actor) return obj - def delete(self, id: int, hard: bool = False, actor = False): - obj = self.session.get(self.model, id) + def delete(self, id: int, hard: bool = False, actor = None): + session = self.session + obj = session.get(self.model, id) if not obj: return None if hard or not self.supports_soft_delete: - self.session.delete(obj) + session.delete(obj) else: soft = cast(_SoftDeletable, obj) soft.is_deleted = True - self.session.commit() + session.commit() self._log_version("delete", obj, actor) return obj - def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}): + def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None): + session = self.session try: data = obj.as_dict() except Exception: @@ -470,5 +507,5 @@ class CRUDService(Generic[T]): actor=str(actor) if actor else None, meta=metadata ) - self.session.add(version) - self.session.commit() + session.add(version) + session.commit() diff --git a/inventory/routes/listing.py b/inventory/routes/listing.py index 56b1601..f5b474e 100644 --- a/inventory/routes/listing.py +++ b/inventory/routes/listing.py @@ -91,7 +91,7 @@ def init_listing_routes(app): ] limit = int(request.args.get("limit", 15)) cursor = request.args.get("cursor") - key, backward = decode_cursor(cursor) + key, _desc, backward = decode_cursor(cursor) service = crudkit.crud.get_service(cls) window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)