Pagination support!!!

This commit is contained in:
Yaro Kasear 2025-09-16 09:57:40 -05:00
parent 3f677fceee
commit a64c64e828
5 changed files with 298 additions and 12 deletions

21
crudkit/api/_cursor.py Normal file
View file

@ -0,0 +1,21 @@
import base64, json
from typing import Any
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
if not values:
return None
payload = {"v": values, "d": desc_flags, "b": backward}
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
if not token:
return None, False
try:
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
vals = obj.get("v")
backward = bool(obj.get("b", False))
if isinstance(vals, list):
return vals, backward
except Exception:
pass
return None, False

View file

@ -1,15 +1,59 @@
from flask import Blueprint, jsonify, request
from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy
def generate_crud_blueprint(model, service):
bp = Blueprint(model.__name__.lower(), __name__)
@bp.get('/')
def list_items():
items = service.list(request.args)
args = request.args.to_dict(flat=True)
# legacy detection
legacy_offset = "offset" in args or "page" in args
# sane limit default
try:
return jsonify([item.as_dict() for item in items])
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit
if legacy_offset:
# Old behavior: honor limit/offset, same CRUDSpec goodies
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
# New behavior: keyset seek with cursors
key, backward = decode_cursor(args.get("cursor"))
window = service.seek_window(
args,
key=key,
backward=backward,
include_total=_is_truthy(args.get("include_total", "1")),
)
desc_flags = list(window.order.desc)
body = {
"items": [obj.as_dict() for obj in window.items],
"limit": window.limit,
"next_cursor": encode_cursor(window.last_key, desc_flags, backward=False),
"prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True),
"total": window.total,
}
resp = jsonify(body)
# Optional Link header
links = []
if body["next_cursor"]:
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
if body["prev_cursor"]:
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
if links:
resp.headers["Link"] = ", ".join(links)
return resp
@bp.get('/<int:id>')
def get_item(id):

View file

@ -1,11 +1,15 @@
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.engine import Engine, Connection
from sqlalchemy import inspect, text
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
@runtime_checkable
@ -61,6 +65,182 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly
return self.session.query(self.model), self.model
def _extract_order_spec(self, root_alias, given_order_by):
"""
SQLAlchemy 2.x only:
Normalize order_by into (cols, desc_flags). Supports plain columns and
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
"""
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
given = self._stable_order_by(root_alias, given_order_by)
cols, desc_flags = [], []
for ob in given:
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
elem = getattr(ob, "element", None)
col = elem if elem is not None else ob # don't use "or" with SA expressions
# Detect direction in SA 2.x
is_desc = False
dir_attr = getattr(ob, "_direction", None)
if dir_attr is not None:
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
elif isinstance(ob, UnaryExpression):
op = getattr(ob, "operator", None)
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
cols.append(col)
desc_flags.append(bool(is_desc))
from crudkit.core.types import OrderSpec
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
"""
Build lexicographic predicate for keyset seek.
For backward traversal, import comparisons.
"""
if not key_vals:
return None
conds = []
for i, col in enumerate(spec.cols):
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i]
if not backward:
op = col < key_vals[i] if is_desc else col > key_vals[i]
else:
op = col > key_vals[i] if is_desc else col < key_vals[i]
conds.append(and_(*ties, op))
return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
out = []
for c in spec.cols:
key = getattr(c, "key", None) or getattr(c, "name", None)
out.append(getattr(obj, key))
return out
def seek_window(
self,
params: dict | None = None,
*,
key: list[Any] | None = None,
backward: bool = False,
include_total: bool = True,
) -> "SeekWindow[T]":
"""
Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
- filters, includes, joins, field projection, eager loading, soft-delete
- deterministic ordering (user sort + PK tiebreakers)
- forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
"""
params = params or {}
query, root_alias = self.get_query()
spec = CRUDSpec(self.model, params, root_alias)
# Soft delete filter
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False)
# Parse filters first
filters = spec.parse_filters()
if filters:
query = query.filter(*filters)
# Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Fields/projection: load_only for root columns, eager loads for relationships
root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
# Order + limit
order_by = spec.parse_sort()
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination()
if not limit or limit <= 0:
limit = 50 # sensible default
# Keyset predicate
if key:
pred = self._key_predicate(order_spec, key, backward)
if pred is not None:
query = query.filter(pred)
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
if not backward:
clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
clauses.append(col.desc() if is_desc else col.asc())
query = query.order_by(*clauses).limit(limit)
items = query.all()
else:
inv_clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
inv_clauses.append(col.asc() if is_desc else col.desc())
query = query.order_by(*inv_clauses).limit(limit)
items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
for obj in items:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
# Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None
# Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None
if include_total:
base = self.session.query(getattr(root_alias, "id"))
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
base = base.filter(getattr(root_alias, "is_deleted") == False)
if filters:
base = base.filter(*filters)
# replicate the same joins used above
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True)
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow(
items=items,
limit=limit,
first_key=first_key,
last_key=last_key,
order=order_spec,
total=total,
)
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias):
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
@ -72,6 +252,25 @@ class CRUDService(Generic[T]):
cols.append(col)
return cols or [text("1")]
def _stable_order_by(self, root_alias, given_order_by):
"""
Ensure deterministic ordering by appending PK columns as tiebreakers.
If no order is provided, fall back to default primary-key order.
"""
order_by = list(given_order_by or [])
if not order_by:
return self._default_order_by(root_alias)
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
pk_cols = []
for col in mapper.primary_key:
try:
pk_cols.append(getattr(root_alias, col.key))
except ArithmeticError:
pk_cols.append(col)
return [*order_by, *pk_cols]
def get(self, id: int, params=None) -> T | None:
query, root_alias = self.get_query()

16
crudkit/core/types.py Normal file
View file

@ -0,0 +1,16 @@
from dataclasses import dataclass
from typing import Any, Sequence
@dataclass(frozen=True)
class OrderSpec:
cols: Sequence[Any]
desc: Sequence[bool]
@dataclass
class SeekWindow:
items: list[Any]
limit: int
first_key: list[Any] | None
last_key: list[Any] | None
order: OrderSpec
total: int | None = None

View file

@ -2,6 +2,7 @@ from flask import Blueprint, render_template, abort, request
import crudkit
from crudkit.api._cursor import decode_cursor, encode_cursor
from crudkit.ui.fragments import render_table
bp_listing = Blueprint("listing", __name__)
@ -89,14 +90,19 @@ def init_listing_routes(app):
{"when": {"field": "complete", "is": True}, "class": "table-success"},
{"when": {"field": "complete", "is": False}, "class": "table-danger"}
]
spec["limit"] = 15
spec["offset"] = (page_num - 1) * 15
limit = int(request.args.get("limit", 15))
cursor = request.args.get("cursor")
key, backward = decode_cursor(cursor)
service = crudkit.crud.get_service(cls)
rows = service.list(spec)
window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)
table = render_table(rows, columns=columns, opts={"object_class": model, "row_classes": row_classes})
return render_template("listing.html", model=model, table=table)
table = render_table(window.items, columns=columns, opts={"object_class": model, "row_classes": row_classes})
return render_template("listing.html", model=model, table=table, pagination={
"limit": window.limit,
"total": window.total,
"next_cursor": encode_cursor(window.last_key, list(window.order.desc), backward=False),
"prev_cursor": encode_cursor(window.first_key, list(window.order.desc), backward=True),
})
app.register_blueprint(bp_listing)