inventory/crudkit/eager.py
2025-08-28 14:41:16 -05:00

75 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from typing import Iterable, List, Sequence, Set
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty
class EagerConfig:
    """Options consumed by ``default_eager_policy``.

    Attributes (plain instance attributes set in ``__init__``):
        strict: when true, bad expand paths raise ``ValueError`` instead of
            being silently skipped.
        max_depth: maximum number of dotted segments allowed in one path.
    """

    def __init__(self, strict: bool = False, max_depth: int = 4):
        # Assigned left to right: strict first, then max_depth.
        self.strict, self.max_depth = strict, max_depth
def _rel(cls, name: str) -> RelationshipProperty | None:
    """Return the relationship named ``name`` on mapped class ``cls``.

    Looks only in the mapper's ``relationships`` collection, so a plain
    column attribute or an unknown name yields ``None``.
    """
    relationships = inspect(cls).relationships
    return relationships.get(name)
def _is_expandable(rel: RelationshipProperty) -> bool:
    """Return True if ``rel`` may be given an eager-loader option.

    Relationships configured with ``lazy="dynamic"`` or (SQLAlchemy 2.0)
    ``lazy="write_only"`` never populate their collection, and per the
    SQLAlchemy docs they do not support joinedload/selectinload, so they are
    excluded. ``viewonly`` relationships load normally and are NOT excluded.
    """
    # "write_only" is simply never matched on SQLAlchemy versions that lack it.
    return rel.lazy not in ("dynamic", "write_only")
def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]:
    """
    Build SQLAlchemy loader options for the relationship paths in ``expand``.

    Heuristic:
    - many-to-one / one-to-one: joinedload
    - one-to-many / many-to-many: selectinload
    Accepts dotted paths like "author.publisher"; each hop is loaded with
    ``selectinload`` when that relationship is a collection (``uselist``)
    and ``joinedload`` otherwise, chained hop by hop.

    Args:
        Model: mapped class the paths start from.
        expand: relationship paths to eager-load. Blank entries are ignored,
            surrounding whitespace is stripped, and a single string is
            treated as one path.
        cfg: depth limit / strictness; defaults to ``EagerConfig()``.

    Returns:
        One loader option per distinct path, ordered by path. A path that is
        a prefix of another *valid* requested path is dropped, because the
        longer chain already loads it.

    Raises:
        ValueError: in strict mode, for a path with more than
            ``cfg.max_depth`` segments or one containing an unknown or
            non-expandable relationship. In non-strict mode such paths are
            skipped.
    """
    if not expand:
        return []
    cfg = cfg or EagerConfig()
    if isinstance(expand, str):
        # A bare string is one path, not a sequence of one-character paths.
        expand = [expand]

    # normalize and dedupe
    raw: Set[str] = {p.strip() for p in expand if p and p.strip()}

    # Validate every path BEFORE pruning prefixes. Pruning first meant that,
    # e.g., ["author", "author.bogus"] loaded nothing in non-strict mode: the
    # valid "author" was discarded in favour of a path that was then skipped.
    resolved: dict[str, list] = {}
    for path in sorted(raw):
        parts = path.split(".")
        if len(parts) > cfg.max_depth:
            if cfg.strict:
                raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})")
            continue
        steps = _resolve_expand_path(Model, parts)
        if steps is None:
            if cfg.strict:
                raise ValueError(f"unknown or non-expandable relationship in expand path: {path}")
            continue
        resolved[path] = steps

    # A prefix is redundant only when a longer path that survived validation
    # will load it (author, author.publisher -> keep only author.publisher).
    covered: Set[str] = set()
    for path in resolved:
        parts = path.split(".")
        covered.update(".".join(parts[:i]) for i in range(1, len(parts)))

    # `resolved` was filled in sorted order, so the output stays path-sorted.
    return [
        _chain_loaders(steps)
        for path, steps in resolved.items()
        if path not in covered
    ]


def _resolve_expand_path(Model, parts: Sequence[str]) -> List[tuple] | None:
    """Map each hop of ``parts`` to ``(attribute, uselist)``; None if any hop is not an expandable relationship."""
    steps: List[tuple] = []
    current_model = Model
    for name in parts:
        rel = _rel(current_model, name)
        if rel is None or not _is_expandable(rel):
            return None
        steps.append((getattr(current_model, name), rel.uselist))
        # Continue the walk from the relationship's target class.
        current_model = rel.mapper.class_
    return steps


def _chain_loaders(steps: Sequence[tuple]) -> Load:
    """Chain one loader per ``(attribute, uselist)`` hop: selectinload for collections, joinedload otherwise."""
    loader: Load | None = None
    for attr, uselist in steps:
        if loader is None:
            loader = selectinload(attr) if uselist else joinedload(attr)
        else:
            loader = loader.selectinload(attr) if uselist else loader.joinedload(attr)
    return loader