from __future__ import annotations

from typing import Iterable, List, Sequence, Set

from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty

# ``lazy`` settings whose attributes do not support object population, so a
# joined/selectin loader option cannot be applied to them.
_NON_EAGER_LAZY = frozenset({"dynamic", "write_only"})


class EagerConfig:
    """Settings that control how ``default_eager_policy`` treats bad paths."""

    def __init__(self, strict: bool = False, max_depth: int = 4):
        # When true, an over-deep or unresolvable path raises ValueError
        # instead of being skipped.
        self.strict = strict
        # Maximum number of dot-separated segments allowed in a single path.
        self.max_depth = max_depth


def _rel(cls, name: str) -> RelationshipProperty | None:
    """Return the relationship named *name* on mapped class *cls*, or None."""
    return inspect(cls).relationships.get(name)


def _is_expandable(rel: RelationshipProperty) -> bool:
    """Return True unless *rel* uses a lazy strategy that rejects eager loading.

    Only ``dynamic`` and ``write_only`` relationships are excluded; view-only
    relationships are left expandable.
    """
    return rel.lazy not in _NON_EAGER_LAZY


def _build_loader(Model, parts: Sequence[str]) -> Load | None:
    """Chain one loader option along *parts*; None if any hop is unusable."""
    model = Model
    loader: Load | None = None
    for name in parts:
        rel = _rel(model, name)
        if rel is None or not _is_expandable(rel):
            return None
        attr = getattr(model, name)
        if loader is None:
            loader = selectinload(attr) if rel.uselist else joinedload(attr)
        elif rel.uselist:
            loader = loader.selectinload(attr)
        else:
            loader = loader.joinedload(attr)
        # The next segment is resolved against the related class.
        model = rel.mapper.class_
    return loader


def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]:
    """
    Build loader options for the relationship paths named in *expand*.

    Heuristic:
    - many-to-one / one-to-one: joinedload
    - one-to-many / many-to-many: selectinload
    Accepts dotted paths like "author.publisher".

    Paths are stripped, blank entries are ignored, and duplicates collapse.
    A path that is a prefix of a longer usable path is dropped, because the
    longer chain already loads it. If the longer path turns out to be too deep
    or unresolvable and ``cfg.strict`` is false, the longest explicitly
    requested prefix that does resolve is used instead of loading nothing.

    Args:
        Model: mapped class the paths start from.
        expand: relationship paths; a bare string is treated as one path.
        cfg: optional settings; defaults to ``EagerConfig()``.

    Returns:
        One loader option per surviving path, ordered by path string.

    Raises:
        ValueError: in strict mode, when a path exceeds ``cfg.max_depth`` or
            names an unknown or non-expandable relationship.
    """
    if not expand:
        return []
    if isinstance(expand, str):
        # A lone string would otherwise be iterated character by character.
        expand = [expand]
    cfg = cfg or EagerConfig()

    requested: Set[str] = {p.strip() for p in expand if p and p.strip()}

    # Paths that are not a proper prefix of any other requested path.
    maximal: Set[str] = set(requested)
    for path in requested:
        segments = path.split(".")
        for i in range(1, len(segments)):
            maximal.discard(".".join(segments[:i]))

    # Maximal paths go first, so strict mode reports the same path it always
    # did. Requested prefixes follow, longest first, and act only as fallbacks
    # for maximal paths that were skipped.
    fallbacks = sorted(requested - maximal, key=lambda p: (-p.count("."), p))
    built: dict[tuple[str, ...], Load] = {}
    for path in sorted(maximal) + fallbacks:
        parts = tuple(path.split("."))
        if any(done[: len(parts)] == parts for done in built):
            # Already loaded as part of a longer chain.
            continue
        if len(parts) > cfg.max_depth:
            if cfg.strict:
                raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})")
            continue
        loader = _build_loader(Model, parts)
        if loader is None:
            if cfg.strict:
                raise ValueError(f"unknown or non-expandable relationship in expand path: {path}")
            continue
        built[parts] = loader

    # Keep the output ordered by the dotted path string.
    return [built[parts] for parts in sorted(built, key=".".join)]