diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py index a2a80c2..3b63cfd 100644 --- a/crudkit/api/_cursor.py +++ b/crudkit/api/_cursor.py @@ -1,135 +1,21 @@ -# crudkit/api/_cursor.py +import base64, json +from typing import Any -from __future__ import annotations - -import base64 -import dataclasses -import datetime as _dt -import decimal as _dec -import hmac -import json -import typing as _t -from hashlib import sha256 - -Any = _t.Any - -@dataclasses.dataclass(frozen=True) -class Cursor: - values: list[Any] - desc_flags: list[bool] - backward: bool - version: int = 1 - - -def _json_default(o: Any) -> Any: - # Keep it boring and predictable. - if isinstance(o, (_dt.datetime, _dt.date, _dt.time)): - # ISO is good enough; assume UTC-aware datetimes are already normalized upstream. - return o.isoformat() - if isinstance(o, _dec.Decimal): - return str(o) - # UUIDs, Enums, etc. - if hasattr(o, "__str__"): - return str(o) - raise TypeError(f"Unsupported cursor value type: {type(o)!r}") - - -def _b64url_nopad_encode(b: bytes) -> str: - return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") - - -def _b64url_nopad_decode(s: str) -> bytes: - # Restore padding for strict decode - pad = "=" * (-len(s) % 4) - return base64.urlsafe_b64decode((s + pad).encode("ascii")) - - -def encode_cursor( - values: list[Any] | None, - desc_flags: list[bool], - backward: bool, - *, - secret: bytes | None = None, -) -> str | None: - """ - Create an opaque, optionally signed cursor token. - - - values: keyset values for the last/first row - - desc_flags: per-key descending flags (same length/order as sort keys) - - backward: whether the window direction is backward - - secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig') - """ +def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: if not values: return None + payload = {"v": values, "d": desc_flags, "b": backward} + return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() - payload = { - "ver": 1, - "v": values, - "d": desc_flags, - "b": bool(backward), - } - - body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8") - token = _b64url_nopad_encode(body) - - if secret: - sig = hmac.new(secret, body, sha256).digest() - token = f"{token}.{_b64url_nopad_encode(sig)}" - - return token - - -def decode_cursor( - token: str | None, - *, - secret: bytes | None = None, -) -> tuple[list[Any] | None, list[bool] | None, bool]: - """ - Parse a cursor token. Returns (values, desc_flags, backward). - - - Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case. - - If secret is provided, verifies HMAC and rejects tampered tokens. - - On any parse failure, returns (None, None, False). - """ +def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: if not token: - return None, None, False - + return None, False try: - # Split payload.sig if signed - if "." in token: - body_b64, sig_b64 = token.split(".", 1) - body = _b64url_nopad_decode(body_b64) - if secret is None: - # Caller didn’t ask for verification; still parse but don’t trust. - pass - else: - expected = hmac.new(secret, body, sha256).digest() - actual = _b64url_nopad_decode(sig_b64) - if not hmac.compare_digest(expected, actual): - return None, None, False - else: - body = _b64url_nopad_decode(token) - - obj = json.loads(body.decode("utf-8")) - - # Versioning. If we ever change fields, branch here. - ver = int(obj.get("ver", 0)) - if ver not in (0, 1): - return None, None, False - + obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) vals = obj.get("v") backward = bool(obj.get("b", False)) - - # desc_flags may be absent in legacy payloads (ver 0) - desc = obj.get("d") - if not isinstance(vals, list): - return None, None, False - if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)): - # Ignore weird 'd' types rather than crashing - desc = None - - return vals, desc, backward - + if isinstance(vals, list): + return vals, backward except Exception: - # Be tolerant on decode: treat as no-cursor. - return None, None, False + pass + return None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index 3b061e2..bf505b2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,40 +1,23 @@ -from __future__ import annotations - -from flask import Blueprint, jsonify, request, abort -from urllib.parse import urlencode +from flask import Blueprint, jsonify, request from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy - -def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: - return _is_truthy(d.get(key, "1" if default else "0")) - - -def _safe_int(value: str | None, default: int) -> int: - try: - return int(value) if value is not None else default - except Exception: - return default - - -def _link_with_params(base_url: str, **params) -> str: - # Filter out None, encode safely - q = {k: v for k, v in params.items() if v is not None} - return f"{base_url}?{urlencode(q)}" - - def generate_crud_blueprint(model, service): bp = Blueprint(model.__name__.lower(), __name__) - @bp.get("/") + @bp.get('/') def list_items(): - # Work from a copy so we don't mutate request.args args = request.args.to_dict(flat=True) + # legacy detection legacy_offset = "offset" in args or "page" in args - limit = _safe_int(args.get("limit"), 50) + # sane limit default + try: + limit = int(args.get("limit", 50)) + except Exception: + limit = 50 args["limit"] = limit if legacy_offset: @@ -42,23 +25,17 @@ def generate_crud_blueprint(model, service): items = service.list(args) return jsonify([obj.as_dict() for obj in items]) - # New behavior: keyset pagination with cursors - cursor_token = args.get("cursor") - key, desc_from_cursor, backward = decode_cursor(cursor_token) + # New behavior: keyset seek with cursors + key, backward = decode_cursor(args.get("cursor")) window = service.seek_window( args, key=key, backward=backward, - include_total=_bool_param(args, "include_total", True), + include_total=_is_truthy(args.get("include_total", "1")), ) - # Prefer the order actually used by the window; fall back to desc_from_cursor if needed. - try: - desc_flags = list(window.order.desc) - except Exception: - desc_flags = desc_from_cursor or [] - + desc_flags = list(window.order.desc) body = { "items": [obj.as_dict() for obj in window.items], "limit": window.limit, @@ -68,60 +45,46 @@ def generate_crud_blueprint(model, service): } resp = jsonify(body) - - # Preserve user’s other query params like include_total, filters, sorts, etc. - base_url = request.base_url - base_params = {k: v for k, v in args.items() if k not in {"cursor"}} - link_parts = [] + # Optional Link header + links = [] if body["next_cursor"]: - link_parts.append( - f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"' - ) + links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') if body["prev_cursor"]: - link_parts.append( - f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"' - ) - if link_parts: - resp.headers["Link"] = ", ".join(link_parts) + links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') + if links: + resp.headers["Link"] = ", ".join(links) return resp - @bp.get("/") + @bp.get('/') def get_item(id): + item = service.get(id, request.args) try: - item = service.get(id, request.args) - if item is None: - abort(404) return jsonify(item.as_dict()) except Exception as e: - # Could be validation, auth, or just you forgetting an index again - return jsonify({"status": "error", "error": str(e)}), 400 + return jsonify({"status": "error", "error": str(e)}) - @bp.post("/") + @bp.post('/') def create_item(): - payload = request.get_json(silent=True) or {} + obj = service.create(request.json) try: - obj = service.create(payload) - return jsonify(obj.as_dict()), 201 - except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 - - @bp.patch("/") - def update_item(id): - payload = request.get_json(silent=True) or {} - try: - obj = service.update(id, payload) return jsonify(obj.as_dict()) except Exception as e: - # 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional. - return jsonify({"status": "error", "error": str(e)}), 400 + return jsonify({"status": "error", "error": str(e)}) - @bp.delete("/") - def delete_item(id): + @bp.patch('/') + def update_item(id): + obj = service.update(id, request.json) try: - service.delete(id) - # 204 means "no content" so don't send any. - return ("", 204) + return jsonify(obj.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return jsonify({"status": "error", "error": str(e)}) + + @bp.delete('/') + def delete_item(id): + service.delete(id) + try: + return jsonify({"status": "success"}), 204 + except Exception as e: + return jsonify({"status": "error", "error": str(e)}) return bp diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 9eb6d8b..b84ebc6 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,22 +1,92 @@ -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression +from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast +from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload +from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.orm.base import NO_VALUE +from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql import operators -from sqlalchemy.sql.elements import ColumnElement +from sqlalchemy.sql.elements import UnaryExpression from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info -from crudkit.projection import compile_projection -import logging -log = logging.getLogger("crudkit.service") +def _expand_requires(model_cls, fields): + out, seen = [], set() + def add(f): + if f not in seen: + seen.add(f); out.append(f) + + for f in fields: + add(f) + parts = f.split(".") + cur_cls = model_cls + prefix = [] + + for p in parts[:-1]: + rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p) + if not rel: + cur_cls = None + break + cur_cls = rel.mapper.class_ + prefix.append(p) + + if cur_cls is None: + continue + + leaf = parts[-1] + deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf) + if not deps: + continue + + pre = ".".join(prefix) + for dep in deps: + add(f"{pre + '.' if pre else ''}{dep}") + return out + +def _is_rel(model_cls, name: str) -> bool: + try: + prop = model_cls.__mapper__.relationships.get(name) + return isinstance(prop, RelationshipProperty) + except Exception: + return False + +def _is_instrumented_column(attr) -> bool: + try: + return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty) + except Exception: + return False + +def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]: + """ + For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship + and only fetch the related PK. This is enough for preselecting