75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
from __future__ import annotations
|
||
from typing import Iterable, List, Sequence, Set
|
||
from sqlalchemy.inspection import inspect
|
||
from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty
|
||
|
||
class EagerConfig:
    """Tunable knobs for building eager-load options from expand paths.

    Attributes are plain instance fields set once by the constructor:

    * ``strict`` -- stored as given; defaults to ``False``.
    * ``max_depth`` -- stored as given; defaults to ``4``.
    """

    def __init__(self, strict: bool = False, max_depth: int = 4) -> None:
        # Values are stored verbatim; no validation or coercion is applied.
        self.strict, self.max_depth = strict, max_depth
|
||
|
||
def _rel(cls, name: str) -> RelationshipProperty | None:
    """Look up the relationship called *name* on *cls*.

    Runs ``inspect(cls)`` and queries its ``relationships`` collection,
    returning whatever ``.get(name)`` yields (``None`` when there is no
    relationship under that name).
    """
    inspected = inspect(cls)
    known_relationships = inspected.relationships
    return known_relationships.get(name)
|
||
|
||
def _is_expandable(rel: RelationshipProperty) -> bool:
|
||
# Skip dynamic or viewonly collections; they don’t support eagerload
|
||
return rel.lazy != "dynamic"
|
||
|
||
def _build_eager_loader(Model, path: str, cfg) -> Load | None:
    """Build the chained loader option for one dotted *path* starting at *Model*.

    Returns ``None`` when the path is skipped (too deep, or containing an
    unknown / non-expandable relationship) and ``cfg.strict`` is false;
    raises ``ValueError`` for the same conditions when ``cfg.strict`` is true.
    """
    parts = path.split(".")
    if len(parts) > cfg.max_depth:
        if cfg.strict:
            raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})")
        return None

    current_model = Model
    loader: Load | None = None
    for name in parts:
        rel = _rel(current_model, name)
        if rel is None or not _is_expandable(rel):
            if cfg.strict:
                raise ValueError(f"unknown or non-expandable relationship in expand path: {path}")
            return None

        attr = getattr(current_model, name)
        # Collections (uselist) use selectinload; scalar references use joinedload.
        if loader is None:
            loader = selectinload(attr) if rel.uselist else joinedload(attr)
        elif rel.uselist:
            loader = loader.selectinload(attr)
        else:
            loader = loader.joinedload(attr)
        # Descend: the next segment is resolved against the related class.
        current_model = rel.mapper.class_

    return loader


def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]:
    """Translate expand paths into SQLAlchemy eager-load options.

    Heuristic:
      - many-to-one / one-to-one (``uselist`` false): joinedload
      - one-to-many / many-to-many (``uselist`` true): selectinload

    Accepts dotted paths like "author.publisher". Paths are stripped,
    de-duplicated and processed in sorted order, so the result is
    deterministic. When both a path and one of its prefixes resolve
    successfully, only the longer path is emitted (its loader chain already
    covers the prefix).

    Args:
        Model: the class the first path segment is resolved against.
        expand: relationship paths to load; a single bare string is treated
            as one path rather than being iterated character by character.
        cfg: depth limit and strictness; a default ``EagerConfig()`` is used
            when omitted.

    Returns:
        One loader option per surviving path; empty when nothing is
        requested or nothing resolves.

    Raises:
        ValueError: in strict mode, when a path exceeds ``cfg.max_depth`` or
            names an unknown / non-expandable relationship.
    """
    if not expand:
        return []
    if isinstance(expand, str):
        # A lone string is itself a Sequence[str]; iterating it would treat
        # every character as a separate relationship name.
        expand = [expand]

    cfg = cfg or EagerConfig()

    # Normalize and dedupe; blank entries are ignored.
    raw: Set[str] = {p.strip() for p in expand if p and p.strip()}

    # Resolve every requested path first. Prefix pruning is deliberately done
    # afterwards, so a valid "author" is not lost just because a longer path
    # sharing that prefix turned out to be invalid or too deep.
    resolved: dict[str, Load] = {}
    for path in sorted(raw):
        loader = _build_eager_loader(Model, path, cfg)
        if loader is not None:
            resolved[path] = loader

    # A resolved path makes each of its proper prefixes redundant
    # (author, author.publisher -> keep only author.publisher).
    shadowed: Set[str] = set()
    for path in resolved:
        parts = path.split(".")
        shadowed.update(".".join(parts[:i]) for i in range(1, len(parts)))

    return [loader for path, loader in resolved.items() if path not in shadowed]
|