Pagination/cursor feature added.
This commit is contained in:
parent
daf0684ebe
commit
7aefefdec6
4 changed files with 349 additions and 8 deletions
21
crudkit/api/_cursor.py
Normal file
21
crudkit/api/_cursor.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
import base64, json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
payload = {"v": values, "d": desc_flags, "b": backward}
|
||||||
|
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
|
||||||
|
|
||||||
|
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
|
||||||
|
if not token:
|
||||||
|
return None, False
|
||||||
|
try:
|
||||||
|
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
|
||||||
|
vals = obj.get("v")
|
||||||
|
backward = bool(obj.get("b", False))
|
||||||
|
if isinstance(vals, list):
|
||||||
|
return vals, backward
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None, False
|
||||||
|
|
@ -1,15 +1,59 @@
|
||||||
from flask import Blueprint, jsonify, request
|
from flask import Blueprint, jsonify, request
|
||||||
|
|
||||||
|
from crudkit.api._cursor import encode_cursor, decode_cursor
|
||||||
|
from crudkit.core.service import _is_truthy
|
||||||
|
|
||||||
def generate_crud_blueprint(model, service):
|
def generate_crud_blueprint(model, service):
|
||||||
bp = Blueprint(model.__name__.lower(), __name__)
|
bp = Blueprint(model.__name__.lower(), __name__)
|
||||||
|
|
||||||
@bp.get('/')
|
@bp.get('/')
|
||||||
def list_items():
|
def list_items():
|
||||||
items = service.list(request.args)
|
args = request.args.to_dict(flat=True)
|
||||||
|
|
||||||
|
# legacy detection
|
||||||
|
legacy_offset = "offset" in args or "page" in args
|
||||||
|
|
||||||
|
# sane limit default
|
||||||
try:
|
try:
|
||||||
return jsonify([item.as_dict() for item in items])
|
limit = int(args.get("limit", 50))
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return jsonify({"status": "error", "error": str(e)})
|
limit = 50
|
||||||
|
args["limit"] = limit
|
||||||
|
|
||||||
|
if legacy_offset:
|
||||||
|
# Old behavior: honor limit/offset, same CRUDSpec goodies
|
||||||
|
items = service.list(args)
|
||||||
|
return jsonify([obj.as_dict() for obj in items])
|
||||||
|
|
||||||
|
# New behavior: keyset seek with cursors
|
||||||
|
key, backward = decode_cursor(args.get("cursor"))
|
||||||
|
|
||||||
|
window = service.seek_window(
|
||||||
|
args,
|
||||||
|
key=key,
|
||||||
|
backward=backward,
|
||||||
|
include_total=_is_truthy(args.get("include_total", "1")),
|
||||||
|
)
|
||||||
|
|
||||||
|
desc_flags = list(window.order.desc)
|
||||||
|
body = {
|
||||||
|
"items": [obj.as_dict() for obj in window.items],
|
||||||
|
"limit": window.limit,
|
||||||
|
"next_cursor": encode_cursor(window.last_key, desc_flags, backward=False),
|
||||||
|
"prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True),
|
||||||
|
"total": window.total,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = jsonify(body)
|
||||||
|
# Optional Link header
|
||||||
|
links = []
|
||||||
|
if body["next_cursor"]:
|
||||||
|
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
|
||||||
|
if body["prev_cursor"]:
|
||||||
|
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
|
||||||
|
if links:
|
||||||
|
resp.headers["Link"] = ", ".join(links)
|
||||||
|
return resp
|
||||||
|
|
||||||
@bp.get('/<int:id>')
|
@bp.get('/<int:id>')
|
||||||
def get_item(id):
|
def get_item(id):
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,15 @@
|
||||||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||||
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
|
from sqlalchemy import and_, func, inspect, or_, text
|
||||||
|
from sqlalchemy.engine import Engine, Connection
|
||||||
|
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from sqlalchemy.orm.util import AliasedClass
|
from sqlalchemy.orm.util import AliasedClass
|
||||||
from sqlalchemy.engine import Engine, Connection
|
from sqlalchemy.sql import operators
|
||||||
from sqlalchemy import inspect, text
|
from sqlalchemy.sql.elements import UnaryExpression
|
||||||
|
|
||||||
from crudkit.core.base import Version
|
from crudkit.core.base import Version
|
||||||
from crudkit.core.spec import CRUDSpec
|
from crudkit.core.spec import CRUDSpec
|
||||||
|
from crudkit.core.types import OrderSpec, SeekWindow
|
||||||
from crudkit.backend import BackendInfo, make_backend_info
|
from crudkit.backend import BackendInfo, make_backend_info
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
|
@ -61,6 +65,243 @@ class CRUDService(Generic[T]):
|
||||||
return self.session.query(poly), poly
|
return self.session.query(poly), poly
|
||||||
return self.session.query(self.model), self.model
|
return self.session.query(self.model), self.model
|
||||||
|
|
||||||
|
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
|
||||||
|
"""
|
||||||
|
For each dotted path like ("location"), -> ["label"], look up the target
|
||||||
|
model's __crudkit_field_requires__ for the terminal field and produce
|
||||||
|
selectinload options prefixed with the relationship path, e.g.:
|
||||||
|
Room.__crudkit_field_requires__['label'] = ['room_function']
|
||||||
|
=> selectinload(root.location).selectinload(Room.room_function)
|
||||||
|
"""
|
||||||
|
opts: List[Any] = []
|
||||||
|
|
||||||
|
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||||
|
|
||||||
|
for path, names in (rel_field_names or {}).items():
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_alias = root_alias
|
||||||
|
current_mapper = root_mapper
|
||||||
|
rel_props: List[RelationshipProperty] = []
|
||||||
|
|
||||||
|
valid = True
|
||||||
|
for step in path:
|
||||||
|
rel = current_mapper.relationships.get(step)
|
||||||
|
if rel is None:
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
rel_props.append(rel)
|
||||||
|
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
|
||||||
|
if not valid:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_cls = current_mapper.class_
|
||||||
|
|
||||||
|
requires = getattr(target_cls, "__crudkit_field_requires__", None)
|
||||||
|
if not isinstance(requires, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for field_name in names:
|
||||||
|
needed: Iterable[str] = requires.get(field_name, [])
|
||||||
|
for rel_need in needed:
|
||||||
|
loader = selectinload(getattr(root_alias, rel_props[0].key))
|
||||||
|
for rp in rel_props[1:]:
|
||||||
|
loader = loader.selectinload(getattr(getattr(root_alias, rp.parent.class_.__name__.lower(), None) or rp.parent.class_, rp.key))
|
||||||
|
|
||||||
|
loader = loader.selectinload(getattr(target_cls, rel_need))
|
||||||
|
opts.append(loader)
|
||||||
|
|
||||||
|
return opts
|
||||||
|
|
||||||
|
def _extract_order_spec(self, root_alias, given_order_by):
|
||||||
|
"""
|
||||||
|
SQLAlchemy 2.x only:
|
||||||
|
Normalize order_by into (cols, desc_flags). Supports plain columns and
|
||||||
|
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
|
||||||
|
"""
|
||||||
|
from sqlalchemy.sql import operators
|
||||||
|
from sqlalchemy.sql.elements import UnaryExpression
|
||||||
|
|
||||||
|
given = self._stable_order_by(root_alias, given_order_by)
|
||||||
|
|
||||||
|
cols, desc_flags = [], []
|
||||||
|
for ob in given:
|
||||||
|
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
|
||||||
|
elem = getattr(ob, "element", None)
|
||||||
|
col = elem if elem is not None else ob # don't use "or" with SA expressions
|
||||||
|
|
||||||
|
# Detect direction in SA 2.x
|
||||||
|
is_desc = False
|
||||||
|
dir_attr = getattr(ob, "_direction", None)
|
||||||
|
if dir_attr is not None:
|
||||||
|
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
|
||||||
|
elif isinstance(ob, UnaryExpression):
|
||||||
|
op = getattr(ob, "operator", None)
|
||||||
|
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
||||||
|
|
||||||
|
cols.append(col)
|
||||||
|
desc_flags.append(bool(is_desc))
|
||||||
|
|
||||||
|
from crudkit.core.types import OrderSpec
|
||||||
|
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
||||||
|
|
||||||
|
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
|
||||||
|
"""
|
||||||
|
Build lexicographic predicate for keyset seek.
|
||||||
|
For backward traversal, import comparisons.
|
||||||
|
"""
|
||||||
|
if not key_vals:
|
||||||
|
return None
|
||||||
|
conds = []
|
||||||
|
for i, col in enumerate(spec.cols):
|
||||||
|
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
|
||||||
|
is_desc = spec.desc[i]
|
||||||
|
if not backward:
|
||||||
|
op = col < key_vals[i] if is_desc else col > key_vals[i]
|
||||||
|
else:
|
||||||
|
op = col > key_vals[i] if is_desc else col < key_vals[i]
|
||||||
|
conds.append(and_(*ties, op))
|
||||||
|
return or_(*conds)
|
||||||
|
|
||||||
|
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
|
||||||
|
out = []
|
||||||
|
for c in spec.cols:
|
||||||
|
key = getattr(c, "key", None) or getattr(c, "name", None)
|
||||||
|
out.append(getattr(obj, key))
|
||||||
|
return out
|
||||||
|
|
||||||
|
def seek_window(
|
||||||
|
self,
|
||||||
|
params: dict | None = None,
|
||||||
|
*,
|
||||||
|
key: list[Any] | None = None,
|
||||||
|
backward: bool = False,
|
||||||
|
include_total: bool = True,
|
||||||
|
) -> "SeekWindow[T]":
|
||||||
|
"""
|
||||||
|
Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
|
||||||
|
- filters, includes, joins, field projection, eager loading, soft-delete
|
||||||
|
- deterministic ordering (user sort + PK tiebreakers)
|
||||||
|
- forward/backward seek via `key` and `backward`
|
||||||
|
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
||||||
|
"""
|
||||||
|
session = self.session
|
||||||
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
|
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||||
|
|
||||||
|
filters = spec.parse_filters()
|
||||||
|
order_by = spec.parse_sort()
|
||||||
|
|
||||||
|
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||||
|
|
||||||
|
for path, names in (rel_field_names or {}).items():
|
||||||
|
if "label" in names:
|
||||||
|
rel_name = path[0]
|
||||||
|
rel_attr = getattr(root_alias, rel_name, None)
|
||||||
|
if rel_attr is not None:
|
||||||
|
query = query.options(selectinload(rel_attr))
|
||||||
|
|
||||||
|
# Soft delete filter
|
||||||
|
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
||||||
|
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||||
|
|
||||||
|
# Parse filters first
|
||||||
|
if filters:
|
||||||
|
query = query.filter(*filters)
|
||||||
|
|
||||||
|
# Includes + joins (so relationship fields like brand.name, location.label work)
|
||||||
|
spec.parse_includes()
|
||||||
|
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||||
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
|
target = cast(Any, target_alias)
|
||||||
|
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
||||||
|
|
||||||
|
# Fields/projection: load_only for root columns, eager loads for relationships
|
||||||
|
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||||
|
if only_cols:
|
||||||
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
|
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||||
|
query = query.options(eager)
|
||||||
|
|
||||||
|
for opt in self._resolve_required_includes(root_alias, rel_field_names):
|
||||||
|
query = query.options(opt)
|
||||||
|
|
||||||
|
# Order + limit
|
||||||
|
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
|
||||||
|
limit, _ = spec.parse_pagination()
|
||||||
|
if not limit or limit <= 0:
|
||||||
|
limit = 50 # sensible default
|
||||||
|
|
||||||
|
# Keyset predicate
|
||||||
|
if key:
|
||||||
|
pred = self._key_predicate(order_spec, key, backward)
|
||||||
|
if pred is not None:
|
||||||
|
query = query.filter(pred)
|
||||||
|
|
||||||
|
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
|
||||||
|
if not backward:
|
||||||
|
clauses = []
|
||||||
|
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
||||||
|
clauses.append(col.desc() if is_desc else col.asc())
|
||||||
|
query = query.order_by(*clauses).limit(limit)
|
||||||
|
items = query.all()
|
||||||
|
else:
|
||||||
|
inv_clauses = []
|
||||||
|
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
||||||
|
inv_clauses.append(col.asc() if is_desc else col.desc())
|
||||||
|
query = query.order_by(*inv_clauses).limit(limit)
|
||||||
|
items = list(reversed(query.all()))
|
||||||
|
|
||||||
|
# Tag projection so your renderer knows what fields were requested
|
||||||
|
proj = []
|
||||||
|
if root_field_names:
|
||||||
|
proj.extend(root_field_names)
|
||||||
|
if root_fields:
|
||||||
|
proj.extend(c.key for c in root_fields)
|
||||||
|
for path, names in (rel_field_names or {}).items():
|
||||||
|
prefix = ".".join(path)
|
||||||
|
for n in names:
|
||||||
|
proj.append(f"{prefix}.{n}")
|
||||||
|
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||||
|
proj.insert(0, "id")
|
||||||
|
if proj:
|
||||||
|
for obj in items:
|
||||||
|
try:
|
||||||
|
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Boundary keys for cursor encoding in the API layer
|
||||||
|
first_key = self._pluck_key(items[0], order_spec) if items else None
|
||||||
|
last_key = self._pluck_key(items[-1], order_spec) if items else None
|
||||||
|
|
||||||
|
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
||||||
|
total = None
|
||||||
|
if include_total:
|
||||||
|
base = self.session.query(getattr(root_alias, "id"))
|
||||||
|
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
||||||
|
base = base.filter(getattr(root_alias, "is_deleted") == False)
|
||||||
|
if filters:
|
||||||
|
base = base.filter(*filters)
|
||||||
|
# replicate the same joins used above
|
||||||
|
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||||
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
|
target = cast(Any, target_alias)
|
||||||
|
base = base.join(target, rel_attr.of_type(target), isouter=True)
|
||||||
|
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
|
||||||
|
|
||||||
|
from crudkit.core.types import SeekWindow # avoid circulars at module top
|
||||||
|
return SeekWindow(
|
||||||
|
items=items,
|
||||||
|
limit=limit,
|
||||||
|
first_key=first_key,
|
||||||
|
last_key=last_key,
|
||||||
|
order=order_spec,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||||
def _default_order_by(self, root_alias):
|
def _default_order_by(self, root_alias):
|
||||||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||||
|
|
@ -72,6 +313,25 @@ class CRUDService(Generic[T]):
|
||||||
cols.append(col)
|
cols.append(col)
|
||||||
return cols or [text("1")]
|
return cols or [text("1")]
|
||||||
|
|
||||||
|
def _stable_order_by(self, root_alias, given_order_by):
|
||||||
|
"""
|
||||||
|
Ensure deterministic ordering by appending PK columns as tiebreakers.
|
||||||
|
If no order is provided, fall back to default primary-key order.
|
||||||
|
"""
|
||||||
|
order_by = list(given_order_by or [])
|
||||||
|
if not order_by:
|
||||||
|
return self._default_order_by(root_alias)
|
||||||
|
|
||||||
|
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||||
|
pk_cols = []
|
||||||
|
for col in mapper.primary_key:
|
||||||
|
try:
|
||||||
|
pk_cols.append(getattr(root_alias, col.key))
|
||||||
|
except ArithmeticError:
|
||||||
|
pk_cols.append(col)
|
||||||
|
|
||||||
|
return [*order_by, *pk_cols]
|
||||||
|
|
||||||
def get(self, id: int, params=None) -> T | None:
|
def get(self, id: int, params=None) -> T | None:
|
||||||
query, root_alias = self.get_query()
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
|
|
|
||||||
16
crudkit/core/types.py
Normal file
16
crudkit/core/types.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class OrderSpec:
|
||||||
|
cols: Sequence[Any]
|
||||||
|
desc: Sequence[bool]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SeekWindow:
|
||||||
|
items: list[Any]
|
||||||
|
limit: int
|
||||||
|
first_key: list[Any] | None
|
||||||
|
last_key: list[Any] | None
|
||||||
|
order: OrderSpec
|
||||||
|
total: int | None = None
|
||||||
Loading…
Add table
Add a link
Reference in a new issue