Hooray for pager working now.

This commit is contained in:
Yaro Kasear 2025-08-28 14:41:16 -05:00
parent 52bd0d4c91
commit 8100d221a1
5 changed files with 186 additions and 119 deletions

View file

@ -1,8 +1,21 @@
from typing import List
from __future__ import annotations
from typing import Iterable, List, Sequence, Set
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Load, joinedload, selectinload
from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty
def default_eager_policy(Model, expand: List[str]) -> List[Load]:
class EagerConfig:
def __init__(self, strict: bool = False, max_depth: int = 4):
self.strict = strict
self.max_depth = max_depth
def _rel(cls, name: str) -> RelationshipProperty | None:
return inspect(cls).relationships.get(name)
def _is_expandable(rel: RelationshipProperty) -> bool:
# Skip dynamic or viewonly collections; they dont support eagerload
return rel.lazy != "dynamic"
def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]:
"""
Heuristic:
- many-to-one / one-to-one: joinedload
@ -12,31 +25,51 @@ def default_eager_policy(Model, expand: List[str]) -> List[Load]:
if not expand:
return []
cfg = cfg or EagerConfig()
# normalize, dedupe, and prefer longer paths over their prefixes
raw: Set[str] = {p.strip() for p in expand if p and p.strip()}
# drop prefixes if a longer path exists (author, author.publisher -> keep only author.publisher)
pruned: Set[str] = set(raw)
for p in raw:
parts = p.split(".")
for i in range(1, len(parts)):
pruned.discard(".".join(parts[:i]))
opts: List[Load] = []
seen: Set[tuple] = set()
for path in expand:
for path in sorted(pruned):
parts = path.split(".")
if len(parts) > cfg.max_depth:
if cfg.strict:
raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})")
continue
current_model = Model
current_inspect = inspect(current_model)
# build the chain incrementally
loader: Load | None = None
ok = True
# first hop
rel = current_inspect.relationships.get(parts[0])
if not rel:
continue # silently skip bad names
attr = getattr(current_model, parts[0])
loader: Load = selectinload(attr) if rel.uselist else joinedload(attr)
current_model = rel.mapper.class_
# nested hops, if any
for name in parts[1:]:
current_inspect = inspect(current_model)
rel = current_inspect.relationships.get(name)
if not rel:
for i, name in enumerate(parts):
rel = _rel(current_model, name)
if not rel or not _is_expandable(rel):
ok = False
break
attr = getattr(current_model, name)
loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr)
if loader is None:
loader = selectinload(attr) if rel.uselist else joinedload(attr)
else:
loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr)
current_model = rel.mapper.class_
opts.append(loader)
if not ok:
if cfg.strict:
raise ValueError(f"unknown or non-expandable relationship in expand path: {path}")
continue
key = (tuple(parts),)
if loader is not None and key not in seen:
opts.append(loader)
seen.add(key)
return opts