Added pagination.

This commit is contained in:
Yaro Kasear 2025-09-16 13:42:34 -05:00
parent a64c64e828
commit 27431a7150
3 changed files with 78 additions and 7 deletions

View file

@ -1,7 +1,7 @@
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic, Mapper
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
@ -65,6 +65,55 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly
return self.session.query(self.model), self.model
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
"""
For each dotted path like ("location"), -> ["label"], look up the target
model's __crudkit_field_requires__ for the terminal field and produce
selectinload options prefixed with the relationship path, e.g.:
Room.__crudkit_field_requires__['label'] = ['room_function']
=> selectinload(root.location).selectinload(Room.room_function)
"""
opts: List[Any] = []
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for path, names in (rel_field_names or {}).items():
if not path:
continue
current_alias = root_alias
current_mapper = root_mapper
rel_props: List[RelationshipProperty] = []
valid = True
for step in path:
rel = current_mapper.relationships.get(step)
if rel is None:
valid = False
break
rel_props.append(rel)
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
if not valid:
continue
target_cls = current_mapper.class_
requires = getattr(target_cls, "__crudkit_field_requires__", None)
if not isinstance(requires, dict):
continue
for field_name in names:
needed: Iterable[str] = requires.get(field_name, [])
for rel_need in needed:
loader = selectinload(getattr(root_alias, rel_props[0].key))
for rp in rel_props[1:]:
loader = loader.selectinload(getattr(getattr(root_alias, rp.parent.class_.__name__.lower(), None) or rp.parent.class_, rp.key))
loader = loader.selectinload(getattr(target_cls, rel_need))
opts.append(loader)
return opts
def _extract_order_spec(self, root_alias, given_order_by):
"""
SQLAlchemy 2.x only:
@ -137,17 +186,28 @@ class CRUDService(Generic[T]):
- forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
"""
params = params or {}
session = self.session
query, root_alias = self.get_query()
spec = CRUDSpec(self.model, params, root_alias)
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
root_fields, rel_field_names, root_field_names = spec.parse_fields()
for path, names in (rel_field_names or {}).items():
if "label" in names:
rel_name = path[0]
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
# Soft delete filter
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False)
# Parse filters first
filters = spec.parse_filters()
if filters:
query = query.filter(*filters)
@ -159,15 +219,16 @@ class CRUDService(Generic[T]):
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Fields/projection: load_only for root columns, eager loads for relationships
root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
for opt in self._resolve_required_includes(root_alias, rel_field_names):
query = query.options(opt)
# Order + limit
order_by = spec.parse_sort()
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination()
if not limit or limit <= 0: