From 7aefefdec6d4e59b24c543156c4544a9d418b1b1 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Tue, 16 Sep 2025 13:41:55 -0500 Subject: [PATCH] Pagination/cursor feature added. --- crudkit/api/_cursor.py | 21 +++ crudkit/api/flask_api.py | 52 +++++++- crudkit/core/service.py | 268 ++++++++++++++++++++++++++++++++++++++- crudkit/core/types.py | 16 +++ 4 files changed, 349 insertions(+), 8 deletions(-) create mode 100644 crudkit/api/_cursor.py create mode 100644 crudkit/core/types.py diff --git a/crudkit/api/_cursor.py b/crudkit/api/_cursor.py new file mode 100644 index 0000000..3b63cfd --- /dev/null +++ b/crudkit/api/_cursor.py @@ -0,0 +1,21 @@ +import base64, json +from typing import Any + +def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None: + if not values: + return None + payload = {"v": values, "d": desc_flags, "b": backward} + return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() + +def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]: + if not token: + return None, False + try: + obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode()) + vals = obj.get("v") + backward = bool(obj.get("b", False)) + if isinstance(vals, list): + return vals, backward + except Exception: + pass + return None, False diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index d238678..bf505b2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,15 +1,59 @@ from flask import Blueprint, jsonify, request +from crudkit.api._cursor import encode_cursor, decode_cursor +from crudkit.core.service import _is_truthy + def generate_crud_blueprint(model, service): bp = Blueprint(model.__name__.lower(), __name__) @bp.get('/') def list_items(): - items = service.list(request.args) + args = request.args.to_dict(flat=True) + + # legacy detection + legacy_offset = "offset" in args or "page" in args + + # sane limit default try: - return jsonify([item.as_dict() for item in items]) - except Exception as e: - return jsonify({"status": "error", "error": str(e)}) + limit = int(args.get("limit", 50)) + except Exception: + limit = 50 + args["limit"] = limit + + if legacy_offset: + # Old behavior: honor limit/offset, same CRUDSpec goodies + items = service.list(args) + return jsonify([obj.as_dict() for obj in items]) + + # New behavior: keyset seek with cursors + key, backward = decode_cursor(args.get("cursor")) + + window = service.seek_window( + args, + key=key, + backward=backward, + include_total=_is_truthy(args.get("include_total", "1")), + ) + + desc_flags = list(window.order.desc) + body = { + "items": [obj.as_dict() for obj in window.items], + "limit": window.limit, + "next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), + "prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), + "total": window.total, + } + + resp = jsonify(body) + # Optional Link header + links = [] + if body["next_cursor"]: + links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"') + if body["prev_cursor"]: + links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"') + if links: + resp.headers["Link"] = ", ".join(links) + return resp @bp.get('/') def get_item(id): diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 6c0afd6..e7acf1d 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,11 +1,15 @@ -from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper +from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast +from sqlalchemy import and_, func, inspect, or_, text +from sqlalchemy.engine import Engine, Connection +from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.util import AliasedClass -from sqlalchemy.engine import Engine, Connection -from sqlalchemy import inspect, text +from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import UnaryExpression + from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec +from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info @runtime_checkable @@ -61,6 +65,243 @@ class CRUDService(Generic[T]): return self.session.query(poly), poly return self.session.query(self.model), self.model + def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]: + """ + For each dotted path like ("location"), -> ["label"], look up the target + model's __crudkit_field_requires__ for the terminal field and produce + selectinload options prefixed with the relationship path, e.g.: + Room.__crudkit_field_requires__['label'] = ['room_function'] + => selectinload(root.location).selectinload(Room.room_function) + """ + opts: List[Any] = [] + + root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) + + for path, names in (rel_field_names or {}).items(): + if not path: + continue + + current_alias = root_alias + current_mapper = root_mapper + rel_props: List[RelationshipProperty] = [] + + valid = True + for step in path: + rel = current_mapper.relationships.get(step) + if rel is None: + valid = False + break + rel_props.append(rel) + current_mapper = cast(Mapper[Any], inspect(rel.entity.entity)) + if not valid: + continue + + target_cls = current_mapper.class_ + + requires = getattr(target_cls, "__crudkit_field_requires__", None) + if not isinstance(requires, dict): + continue + + for field_name in names: + needed: Iterable[str] = requires.get(field_name, []) + for rel_need in needed: + loader = selectinload(getattr(root_alias, rel_props[0].key)) + for rp in rel_props[1:]: + loader = loader.selectinload(getattr(getattr(root_alias, rp.parent.class_.__name__.lower(), None) or rp.parent.class_, rp.key)) + + loader = loader.selectinload(getattr(target_cls, rel_need)) + opts.append(loader) + + return opts + + def _extract_order_spec(self, root_alias, given_order_by): + """ + SQLAlchemy 2.x only: + Normalize order_by into (cols, desc_flags). Supports plain columns and + col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. + """ + from sqlalchemy.sql import operators + from sqlalchemy.sql.elements import UnaryExpression + + given = self._stable_order_by(root_alias, given_order_by) + + cols, desc_flags = [], [] + for ob in given: + # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob # don't use "or" with SA expressions + + # Detect direction in SA 2.x + is_desc = False + dir_attr = getattr(ob, "_direction", None) + if dir_attr is not None: + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elif isinstance(ob, UnaryExpression): + op = getattr(ob, "operator", None) + is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + + cols.append(col) + desc_flags.append(bool(is_desc)) + + from crudkit.core.types import OrderSpec + return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) + + def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): + """ + Build lexicographic predicate for keyset seek. + For backward traversal, import comparisons. + """ + if not key_vals: + return None + conds = [] + for i, col in enumerate(spec.cols): + ties = [spec.cols[j] == key_vals[j] for j in range(i)] + is_desc = spec.desc[i] + if not backward: + op = col < key_vals[i] if is_desc else col > key_vals[i] + else: + op = col > key_vals[i] if is_desc else col < key_vals[i] + conds.append(and_(*ties, op)) + return or_(*conds) + + def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: + out = [] + for c in spec.cols: + key = getattr(c, "key", None) or getattr(c, "name", None) + out.append(getattr(obj, key)) + return out + + def seek_window( + self, + params: dict | None = None, + *, + key: list[Any] | None = None, + backward: bool = False, + include_total: bool = True, + ) -> "SeekWindow[T]": + """ + Transport-agnostic keyset pagination that preserves all the goodies from `list()`: + - filters, includes, joins, field projection, eager loading, soft-delete + - deterministic ordering (user sort + PK tiebreakers) + - forward/backward seek via `key` and `backward` + Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. + """ + session = self.session + query, root_alias = self.get_query() + + spec = CRUDSpec(self.model, params or {}, root_alias) + + filters = spec.parse_filters() + order_by = spec.parse_sort() + + root_fields, rel_field_names, root_field_names = spec.parse_fields() + + for path, names in (rel_field_names or {}).items(): + if "label" in names: + rel_name = path[0] + rel_attr = getattr(root_alias, rel_name, None) + if rel_attr is not None: + query = query.options(selectinload(rel_attr)) + + # Soft delete filter + if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")): + query = query.filter(getattr(root_alias, "is_deleted") == False) + + # Parse filters first + if filters: + query = query.filter(*filters) + + # Includes + joins (so relationship fields like brand.name, location.label work) + spec.parse_includes() + for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): + rel_attr = cast(InstrumentedAttribute, relationship_attr) + target = cast(Any, target_alias) + query = query.join(target, rel_attr.of_type(target), isouter=True) + + # Fields/projection: load_only for root columns, eager loads for relationships + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) + for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names): + query = query.options(eager) + + for opt in self._resolve_required_includes(root_alias, rel_field_names): + query = query.options(opt) + + # Order + limit + order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper + limit, _ = spec.parse_pagination() + if not limit or limit <= 0: + limit = 50 # sensible default + + # Keyset predicate + if key: + pred = self._key_predicate(order_spec, key, backward) + if pred is not None: + query = query.filter(pred) + + # Apply ordering. For backward, invert SQL order then reverse in-memory for display. + if not backward: + clauses = [] + for col, is_desc in zip(order_spec.cols, order_spec.desc): + clauses.append(col.desc() if is_desc else col.asc()) + query = query.order_by(*clauses).limit(limit) + items = query.all() + else: + inv_clauses = [] + for col, is_desc in zip(order_spec.cols, order_spec.desc): + inv_clauses.append(col.asc() if is_desc else col.desc()) + query = query.order_by(*inv_clauses).limit(limit) + items = list(reversed(query.all())) + + # Tag projection so your renderer knows what fields were requested + proj = [] + if root_field_names: + proj.extend(root_field_names) + if root_fields: + proj.extend(c.key for c in root_fields) + for path, names in (rel_field_names or {}).items(): + prefix = ".".join(path) + for n in names: + proj.append(f"{prefix}.{n}") + if proj and "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") + if proj: + for obj in items: + try: + setattr(obj, "__crudkit_projection__", tuple(proj)) + except Exception: + pass + + # Boundary keys for cursor encoding in the API layer + first_key = self._pluck_key(items[0], order_spec) if items else None + last_key = self._pluck_key(items[-1], order_spec) if items else None + + # Optional total that’s safe under JOINs (COUNT DISTINCT ids) + total = None + if include_total: + base = self.session.query(getattr(root_alias, "id")) + if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")): + base = base.filter(getattr(root_alias, "is_deleted") == False) + if filters: + base = base.filter(*filters) + # replicate the same joins used above + for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): + rel_attr = cast(InstrumentedAttribute, relationship_attr) + target = cast(Any, target_alias) + base = base.join(target, rel_attr.of_type(target), isouter=True) + total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0 + + from crudkit.core.types import SeekWindow # avoid circulars at module top + return SeekWindow( + items=items, + limit=limit, + first_key=first_key, + last_key=last_key, + order=order_spec, + total=total, + ) + # Helper: default ORDER BY for MSSQL when paginating without explicit order def _default_order_by(self, root_alias): mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) @@ -72,6 +313,25 @@ class CRUDService(Generic[T]): cols.append(col) return cols or [text("1")] + def _stable_order_by(self, root_alias, given_order_by): + """ + Ensure deterministic ordering by appending PK columns as tiebreakers. + If no order is provided, fall back to default primary-key order. + """ + order_by = list(given_order_by or []) + if not order_by: + return self._default_order_by(root_alias) + + mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) + pk_cols = [] + for col in mapper.primary_key: + try: + pk_cols.append(getattr(root_alias, col.key)) + except ArithmeticError: + pk_cols.append(col) + + return [*order_by, *pk_cols] + def get(self, id: int, params=None) -> T | None: query, root_alias = self.get_query() diff --git a/crudkit/core/types.py b/crudkit/core/types.py new file mode 100644 index 0000000..aade874 --- /dev/null +++ b/crudkit/core/types.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import Any, Sequence + +@dataclass(frozen=True) +class OrderSpec: + cols: Sequence[Any] + desc: Sequence[bool] + +@dataclass +class SeekWindow: + items: list[Any] + limit: int + first_key: list[Any] | None + last_key: list[Any] | None + order: OrderSpec + total: int | None = None