Starting to get nested loading working. Still a WIP.

This commit is contained in:
Yaro Kasear 2025-09-26 15:50:35 -05:00
parent 97891961e1
commit 4c56149f1b
2 changed files with 168 additions and 34 deletions

View file

@ -1,6 +1,6 @@
from typing import Any, Dict, Iterable, List, Tuple from typing import Any, Dict, Iterable, List, Tuple, Set
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty
Base = declarative_base() Base = declarative_base()
@ -17,6 +17,69 @@ def _safe_get_loaded_attr(obj, name):
except Exception: except Exception:
return None return None
def _identity_key(obj) -> Tuple[type, Any]:
try:
st = inspect(obj)
return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
except Exception:
return (type(obj), id(obj))
def _is_collection_rel(prop: RelationshipProperty) -> bool:
try:
return prop.uselist is True
except Exception:
return False
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for cls in obj.__class__.__mro__:
if hasattr(cls, "__table__"):
for col in cls.__table__.columns:
name = col.name
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
"""
Serialize relationship 'name' already loaded on obj.
- If in 'embed' (or depth > 0 for depth-based walk), recurse.
- Else, return None (dont lazy-load).
"""
val = _safe_get_loaded_attr(obj, name)
if val is None:
return None
# Decide whether to recurse into this relationship
should_recurse = (depth > 0) or (name in embed)
if isinstance(val, list):
if not should_recurse:
# Emit a light list of child primary data (id + a couple columns) without recursion.
return [_serialize_simple_obj(child) for child in val]
out = []
for child in val:
ik = _identity_key(child)
if ik in seen: # cycle guard
out.append({"id": getattr(child, "id", None)})
continue
seen.add(ik)
out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
return out
# Scalar relationship
child = val
if not should_recurse:
return _serialize_simple_obj(child)
ik = _identity_key(child)
if ik in seen:
return {"id": getattr(child, "id", None)}
seen.add(ik)
return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]: def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
""" """
Split requested fields into: Split requested fields into:
@ -116,6 +179,49 @@ class CRUDMixin:
created_at = Column(DateTime, default=func.now(), nullable=False) created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
def as_tree(
self,
*,
embed_depth: int = 0,
embed: Iterable[str] | None = None,
_seen: Set[Tuple[type, Any]] | None = None,
) -> Dict[str, Any]:
"""
Recursive, NON-LAZY serializer.
- Always includes mapped columns.
- For relationships: only serializes those ALREADY LOADED.
- Recurses either up to embed_depth or for specific names in 'embed'.
- Keeps *_id columns alongside embedded objects.
- Cycle-safe via _seen.
"""
seen = _seen or set()
ik = _identity_key(self)
if ik in seen:
return {"id": getattr(self, "id", None)}
seen.add(ik)
data = _serialize_simple_obj(self)
# Determine which relationships to consider
try:
st = inspect(self)
mapper = st.mapper
embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names
for name, prop in mapper.relationships.items():
# Only touch relationships that are already loaded; never lazy-load here.
rel_loaded = st.attrs.get(name)
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
continue
data[name] = _serialize_loaded_rel(
self, name, depth=embed_depth, seen=seen, embed=embed_set
)
except Exception:
# If inspection fails, we just return columns.
pass
return data
def as_dict(self, fields: list[str] | None = None): def as_dict(self, fields: list[str] | None = None):
""" """
Serialize the instance. Serialize the instance.
@ -140,25 +246,34 @@ class CRUDMixin:
out: Dict[str, Any] = {} out: Dict[str, Any] = {}
# Always include id unless user explicitly listed fields and included id already # Always include id unless the caller explicitly listed fields containing id
if "id" not in req_list and hasattr(self, "id"): if "id" not in req_list and hasattr(self, "id"):
try: try:
out["id"] = getattr(self, "id") out["id"] = getattr(self, "id")
except Exception: except Exception:
pass pass
# Handle scalar tokens (may be columns, hybrids/properties, or relationships)
for name in scalars: for name in scalars:
# Try loaded value first (never lazy-load) # Try loaded value first (never lazy-load)
val = _safe_get_loaded_attr(self, name) val = _safe_get_loaded_attr(self, name)
# if still None, allow a final-hop getattr for root scalars # Final-hop getattr for root scalars (hybrids/@property) so they can compute.
# so hybrids / @property can compute (they won't traverse relationships).
if val is None: if val is None:
try: try:
val = getattr(self, name) val = getattr(self, name)
except Exception: except Exception:
val = None val = None
# If it's a scalar ORM object (relationship), serialize its columns
try:
st = inspect(val) # will raise if not an ORM object
if getattr(st, "mapper", None) is not None:
out[name] = _serialize_simple_obj(val)
continue
except Exception:
pass
# If it's a collection and no subfields were requested, emit a light list # If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)): if isinstance(val, (list, tuple)):
out[name] = [_serialize_leaf(v) for v in val] out[name] = [_serialize_leaf(v) for v in val]
@ -169,7 +284,7 @@ class CRUDMixin:
for root, tails in groups.items(): for root, tails in groups.items():
root_val = _safe_get_loaded_attr(self, root) root_val = _safe_get_loaded_attr(self, root)
if isinstance(root_val, (list, tuple)): if isinstance(root_val, (list, tuple)):
# one-to-many collection # one-to-many collection → list of dicts with the requested tails
out[root] = _serialize_collection(root_val, tails) out[root] = _serialize_collection(root_val, tails)
else: else:
# many-to-one or scalar dotted; place each full dotted path as key # many-to-one or scalar dotted; place each full dotted path as key
@ -177,6 +292,7 @@ class CRUDMixin:
dotted = f"{root}.{tail}" dotted = f"{root}.{tail}"
out[dotted] = _deep_get_loaded(self, dotted) out[dotted] = _deep_get_loaded(self, dotted)
# ← This was the placeholder before. We return the dict we just built.
return out return out
# Fallback: all mapped columns on this class hierarchy # Fallback: all mapped columns on this class hierarchy

View file

@ -243,22 +243,32 @@ class CRUDService(Generic[T]):
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# JOIN all resolved paths; for collections use selectinload (never join) # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
used_contains_eager = False used_contains_eager = False
for base_alias, rel_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
if is_collection: is_nested_firsthop = rel_attr.key in nested_first_hops
if is_collection or is_nested_firsthop:
# Use selectinload so deeper hops can chain cleanly (and to avoid
# contains_eager/loader conflicts on nested paths).
opt = selectinload(rel_attr) opt = selectinload(rel_attr)
# narroe child columns it requested (e.g., updates.id,updates.timestamp)
child_names = (collection_field_names or {}).get(rel_attr.key, []) # Narrow columns for collections if we know child scalar names
if child_names: if is_collection:
target_cls = rel_attr.property.mapper.class_ child_names = (collection_field_names or {}).get(rel_attr.key, [])
cols = [getattr(target_cls, n, None) for n in child_names] if child_names:
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] target_cls = rel_attr.property.mapper.class_
if cols: cols = [getattr(target_cls, n, None) for n in child_names]
opt = opt.load_only(*cols) cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt) query = query.options(opt)
else: else:
# Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias)) query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True used_contains_eager = True
@ -453,19 +463,23 @@ class CRUDService(Generic[T]):
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# JOIN non-collections only; collections via selectinload nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
used_contains_eager = False used_contains_eager = False
for base_alias, rel_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
if is_collection: is_nested_firsthop = rel_attr.key in nested_first_hops
if is_collection or is_nested_firsthop:
opt = selectinload(rel_attr) opt = selectinload(rel_attr)
child_names = (collection_field_names or {}).get(rel_attr.key, []) if is_collection:
if child_names: child_names = (collection_field_names or {}).get(rel_attr.key, [])
target_cls = rel_attr.property.mapper.class_ if child_names:
cols = [getattr(target_cls, n, None) for n in child_names] target_cls = rel_attr.property.mapper.class_
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] cols = [getattr(target_cls, n, None) for n in child_names]
if cols: cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
opt = opt.load_only(*cols) if cols:
opt = opt.load_only(*cols)
query = query.options(opt) query = query.options(opt)
else: else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
@ -548,19 +562,23 @@ class CRUDService(Generic[T]):
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# JOIN non-collection paths; selectinload for collections nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }
used_contains_eager = False used_contains_eager = False
for _base_alias, rel_attr, target_alias in join_paths: for _base_alias, rel_attr, target_alias in join_paths:
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
if is_collection: is_nested_firsthop = rel_attr.key in nested_first_hops
if is_collection or is_nested_firsthop:
opt = selectinload(rel_attr) opt = selectinload(rel_attr)
child_names = (collection_field_names or {}).get(rel_attr.key, []) if is_collection:
if child_names: child_names = (collection_field_names or {}).get(rel_attr.key, [])
target_cls = rel_attr.property.mapper.class_ if child_names:
cols = [getattr(target_cls, n, None) for n in child_names] target_cls = rel_attr.property.mapper.class_
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] cols = [getattr(target_cls, n, None) for n in child_names]
if cols: cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
opt = opt.load_only(*cols) if cols:
opt = opt.load_only(*cols)
query = query.options(opt) query = query.options(opt)
else: else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)