Pagination support!!!
This commit is contained in:
parent
3f677fceee
commit
a64c64e828
5 changed files with 298 additions and 12 deletions
|
|
@ -1,11 +1,15 @@
|
|||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||
from sqlalchemy import and_, func, inspect, or_, text
|
||||
from sqlalchemy.engine import Engine, Connection
|
||||
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.util import AliasedClass
|
||||
from sqlalchemy.engine import Engine, Connection
|
||||
from sqlalchemy import inspect, text
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
from crudkit.core.types import OrderSpec, SeekWindow
|
||||
from crudkit.backend import BackendInfo, make_backend_info
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -61,6 +65,182 @@ class CRUDService(Generic[T]):
|
|||
return self.session.query(poly), poly
|
||||
return self.session.query(self.model), self.model
|
||||
|
||||
def _extract_order_spec(self, root_alias, given_order_by):
|
||||
"""
|
||||
SQLAlchemy 2.x only:
|
||||
Normalize order_by into (cols, desc_flags). Supports plain columns and
|
||||
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
|
||||
"""
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
|
||||
given = self._stable_order_by(root_alias, given_order_by)
|
||||
|
||||
cols, desc_flags = [], []
|
||||
for ob in given:
|
||||
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
|
||||
elem = getattr(ob, "element", None)
|
||||
col = elem if elem is not None else ob # don't use "or" with SA expressions
|
||||
|
||||
# Detect direction in SA 2.x
|
||||
is_desc = False
|
||||
dir_attr = getattr(ob, "_direction", None)
|
||||
if dir_attr is not None:
|
||||
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
|
||||
elif isinstance(ob, UnaryExpression):
|
||||
op = getattr(ob, "operator", None)
|
||||
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
||||
|
||||
cols.append(col)
|
||||
desc_flags.append(bool(is_desc))
|
||||
|
||||
from crudkit.core.types import OrderSpec
|
||||
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
||||
|
||||
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
|
||||
"""
|
||||
Build lexicographic predicate for keyset seek.
|
||||
For backward traversal, import comparisons.
|
||||
"""
|
||||
if not key_vals:
|
||||
return None
|
||||
conds = []
|
||||
for i, col in enumerate(spec.cols):
|
||||
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
|
||||
is_desc = spec.desc[i]
|
||||
if not backward:
|
||||
op = col < key_vals[i] if is_desc else col > key_vals[i]
|
||||
else:
|
||||
op = col > key_vals[i] if is_desc else col < key_vals[i]
|
||||
conds.append(and_(*ties, op))
|
||||
return or_(*conds)
|
||||
|
||||
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
|
||||
out = []
|
||||
for c in spec.cols:
|
||||
key = getattr(c, "key", None) or getattr(c, "name", None)
|
||||
out.append(getattr(obj, key))
|
||||
return out
|
||||
|
||||
def seek_window(
|
||||
self,
|
||||
params: dict | None = None,
|
||||
*,
|
||||
key: list[Any] | None = None,
|
||||
backward: bool = False,
|
||||
include_total: bool = True,
|
||||
) -> "SeekWindow[T]":
|
||||
"""
|
||||
Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
|
||||
- filters, includes, joins, field projection, eager loading, soft-delete
|
||||
- deterministic ordering (user sort + PK tiebreakers)
|
||||
- forward/backward seek via `key` and `backward`
|
||||
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
||||
"""
|
||||
params = params or {}
|
||||
query, root_alias = self.get_query()
|
||||
|
||||
spec = CRUDSpec(self.model, params, root_alias)
|
||||
|
||||
# Soft delete filter
|
||||
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||
|
||||
# Parse filters first
|
||||
filters = spec.parse_filters()
|
||||
if filters:
|
||||
query = query.filter(*filters)
|
||||
|
||||
# Includes + joins (so relationship fields like brand.name, location.label work)
|
||||
spec.parse_includes()
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
||||
|
||||
# Fields/projection: load_only for root columns, eager loads for relationships
|
||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
if only_cols:
|
||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||
query = query.options(eager)
|
||||
|
||||
# Order + limit
|
||||
order_by = spec.parse_sort()
|
||||
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
|
||||
limit, _ = spec.parse_pagination()
|
||||
if not limit or limit <= 0:
|
||||
limit = 50 # sensible default
|
||||
|
||||
# Keyset predicate
|
||||
if key:
|
||||
pred = self._key_predicate(order_spec, key, backward)
|
||||
if pred is not None:
|
||||
query = query.filter(pred)
|
||||
|
||||
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
|
||||
if not backward:
|
||||
clauses = []
|
||||
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
||||
clauses.append(col.desc() if is_desc else col.asc())
|
||||
query = query.order_by(*clauses).limit(limit)
|
||||
items = query.all()
|
||||
else:
|
||||
inv_clauses = []
|
||||
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
||||
inv_clauses.append(col.asc() if is_desc else col.desc())
|
||||
query = query.order_by(*inv_clauses).limit(limit)
|
||||
items = list(reversed(query.all()))
|
||||
|
||||
# Tag projection so your renderer knows what fields were requested
|
||||
proj = []
|
||||
if root_field_names:
|
||||
proj.extend(root_field_names)
|
||||
if root_fields:
|
||||
proj.extend(c.key for c in root_fields)
|
||||
for path, names in (rel_field_names or {}).items():
|
||||
prefix = ".".join(path)
|
||||
for n in names:
|
||||
proj.append(f"{prefix}.{n}")
|
||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||
proj.insert(0, "id")
|
||||
if proj:
|
||||
for obj in items:
|
||||
try:
|
||||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Boundary keys for cursor encoding in the API layer
|
||||
first_key = self._pluck_key(items[0], order_spec) if items else None
|
||||
last_key = self._pluck_key(items[-1], order_spec) if items else None
|
||||
|
||||
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
||||
total = None
|
||||
if include_total:
|
||||
base = self.session.query(getattr(root_alias, "id"))
|
||||
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
||||
base = base.filter(getattr(root_alias, "is_deleted") == False)
|
||||
if filters:
|
||||
base = base.filter(*filters)
|
||||
# replicate the same joins used above
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
base = base.join(target, rel_attr.of_type(target), isouter=True)
|
||||
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
|
||||
|
||||
from crudkit.core.types import SeekWindow # avoid circulars at module top
|
||||
return SeekWindow(
|
||||
items=items,
|
||||
limit=limit,
|
||||
first_key=first_key,
|
||||
last_key=last_key,
|
||||
order=order_spec,
|
||||
total=total,
|
||||
)
|
||||
|
||||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||
def _default_order_by(self, root_alias):
|
||||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||
|
|
@ -72,6 +252,25 @@ class CRUDService(Generic[T]):
|
|||
cols.append(col)
|
||||
return cols or [text("1")]
|
||||
|
||||
def _stable_order_by(self, root_alias, given_order_by):
|
||||
"""
|
||||
Ensure deterministic ordering by appending PK columns as tiebreakers.
|
||||
If no order is provided, fall back to default primary-key order.
|
||||
"""
|
||||
order_by = list(given_order_by or [])
|
||||
if not order_by:
|
||||
return self._default_order_by(root_alias)
|
||||
|
||||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||
pk_cols = []
|
||||
for col in mapper.primary_key:
|
||||
try:
|
||||
pk_cols.append(getattr(root_alias, col.key))
|
||||
except ArithmeticError:
|
||||
pk_cols.append(col)
|
||||
|
||||
return [*order_by, *pk_cols]
|
||||
|
||||
def get(self, id: int, params=None) -> T | None:
|
||||
query, root_alias = self.get_query()
|
||||
|
||||
|
|
|
|||
16
crudkit/core/types.py
Normal file
16
crudkit/core/types.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Sequence
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OrderSpec:
|
||||
cols: Sequence[Any]
|
||||
desc: Sequence[bool]
|
||||
|
||||
@dataclass
|
||||
class SeekWindow:
|
||||
items: list[Any]
|
||||
limit: int
|
||||
first_key: list[Any] | None
|
||||
last_key: list[Any] | None
|
||||
order: OrderSpec
|
||||
total: int | None = None
|
||||
Loading…
Add table
Add a link
Reference in a new issue