Starting to get nested loading working. Still a WIP.
This commit is contained in:
parent
97891961e1
commit
4c56149f1b
2 changed files with 168 additions and 34 deletions
@ -1,6 +1,6 @@
from typing import Any, Dict, Iterable, List, Tuple
from typing import Any, Dict, Iterable, List, Tuple, Set
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty

Base = declarative_base()

@ -17,6 +17,69 @@ def _safe_get_loaded_attr(obj, name):
    except Exception:
        return None


def _identity_key(obj) -> Tuple[type, Any]:
    try:
        st = inspect(obj)
        return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
    except Exception:
        return (type(obj), id(obj))


def _is_collection_rel(prop: RelationshipProperty) -> bool:
    try:
        return prop.uselist is True
    except Exception:
        return False


def _serialize_simple_obj(obj) -> Dict[str, Any]:
    """Columns only (no relationships)."""
    out: Dict[str, Any] = {}
    for cls in obj.__class__.__mro__:
        if hasattr(cls, "__table__"):
            for col in cls.__table__.columns:
                name = col.name
                try:
                    out[name] = getattr(obj, name)
                except Exception:
                    out[name] = None
    return out


def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """
    Serialize relationship 'name' already loaded on obj.
    - If in 'embed' (or depth > 0 for a depth-based walk), recurse.
    - Else, emit columns only, without recursing (never lazy-load).
    """
    val = _safe_get_loaded_attr(obj, name)
    if val is None:
        return None

    # Decide whether to recurse into this relationship
    should_recurse = (depth > 0) or (name in embed)

    if isinstance(val, list):
        if not should_recurse:
            # Emit a light list of the children's column data without recursion.
            return [_serialize_simple_obj(child) for child in val]
        out = []
        for child in val:
            ik = _identity_key(child)
            if ik in seen:  # cycle guard
                out.append({"id": getattr(child, "id", None)})
                continue
            # as_tree registers the child in 'seen' itself; pre-adding it here would
            # make the recursive call short-circuit to just the child's id.
            out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
        return out

    # Scalar relationship
    child = val
    if not should_recurse:
        return _serialize_simple_obj(child)
    ik = _identity_key(child)
    if ik in seen:
        return {"id": getattr(child, "id", None)}
    return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)


def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
    """
    Split requested fields into:
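To illustrate how these helpers compose, here is a minimal sketch. The Project and Update models, their columns, and the sample values are hypothetical (not part of this commit), and the sketch assumes CRUDMixin supplies the "id" primary-key column:

# Hypothetical models, for illustration only.
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship

class Update(CRUDMixin, Base):
    __tablename__ = "updates"
    note = Column(String)
    project_id = Column(Integer, ForeignKey("projects.id"))

class Project(CRUDMixin, Base):
    __tablename__ = "projects"
    name = Column(String)
    updates = relationship(Update)

# With a Project whose "updates" collection is already loaded:
#   _serialize_simple_obj(project)
#       -> {"id": 1, "name": "Apollo", "created_at": ..., "updated_at": ...}   (columns only)
#   _serialize_loaded_rel(project, "updates", depth=0, seen=set(), embed=set())
#       -> light list of column-only dicts, one per Update, no recursion
#   _serialize_loaded_rel(project, "updates", depth=1, seen=set(), embed=set())
#       -> recurses one level into each child via Update.as_tree(...)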
@ -116,6 +179,49 @@ class CRUDMixin:
    created_at = Column(DateTime, default=func.now(), nullable=False)
    updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())

    def as_tree(
        self,
        *,
        embed_depth: int = 0,
        embed: Iterable[str] | None = None,
        _seen: Set[Tuple[type, Any]] | None = None,
    ) -> Dict[str, Any]:
        """
        Recursive, NON-LAZY serializer.
        - Always includes mapped columns.
        - For relationships: only serializes those ALREADY LOADED.
        - Recurses either up to embed_depth or for specific names in 'embed'.
        - Keeps *_id columns alongside embedded objects.
        - Cycle-safe via _seen.
        """
        seen = _seen or set()
        ik = _identity_key(self)
        if ik in seen:
            return {"id": getattr(self, "id", None)}
        seen.add(ik)

        data = _serialize_simple_obj(self)

        # Determine which relationships to consider
        try:
            st = inspect(self)
            mapper = st.mapper
            embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))  # top-level names
            for name, prop in mapper.relationships.items():
                # Only touch relationships that are already loaded; never lazy-load here.
                rel_loaded = st.attrs.get(name)
                if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
                    continue

                data[name] = _serialize_loaded_rel(
                    self, name, depth=embed_depth, seen=seen, embed=embed_set
                )
        except Exception:
            # If inspection fails, we just return columns.
            pass

        return data

    def as_dict(self, fields: list[str] | None = None):
        """
        Serialize the instance.
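A usage sketch for as_tree, reusing the hypothetical Project/Update models above (the in-memory engine and seed data are illustrative only). The key point from the docstring: relationships must be eager-loaded before serialization, because as_tree never triggers lazy loads.

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, selectinload

engine = create_engine("sqlite://")          # throwaway in-memory DB for the sketch
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Project(name="Apollo", updates=[Update(note="kickoff")]))
    session.commit()

    project = session.scalars(
        select(Project).options(selectinload(Project.updates))
    ).first()

    tree = project.as_tree(embed_depth=1)          # recurse one level into loaded relationships
    partial = project.as_tree(embed=["updates"])   # or embed only specific relationship names
    # Relationships that were not eager-loaded above are skipped; no extra SQL is emitted.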
@ -140,25 +246,34 @@ class CRUDMixin:

            out: Dict[str, Any] = {}

            # Always include id unless user explicitly listed fields and included id already
            # Always include id unless the caller explicitly listed fields containing id
            if "id" not in req_list and hasattr(self, "id"):
                try:
                    out["id"] = getattr(self, "id")
                except Exception:
                    pass

            # Handle scalar tokens (may be columns, hybrids/properties, or relationships)
            for name in scalars:
                # Try loaded value first (never lazy-load)
                val = _safe_get_loaded_attr(self, name)

                # if still None, allow a final-hop getattr for root scalars
                # so hybrids / @property can compute (they won't traverse relationships).
                # Final-hop getattr for root scalars (hybrids/@property) so they can compute.
                if val is None:
                    try:
                        val = getattr(self, name)
                    except Exception:
                        val = None

                # If it's a scalar ORM object (relationship), serialize its columns
                try:
                    st = inspect(val)  # will raise if not an ORM object
                    if getattr(st, "mapper", None) is not None:
                        out[name] = _serialize_simple_obj(val)
                        continue
                except Exception:
                    pass

                # If it's a collection and no subfields were requested, emit a light list
                if isinstance(val, (list, tuple)):
                    out[name] = [_serialize_leaf(v) for v in val]
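Roughly, the intended shapes for scalar tokens (a sketch only; field parsing, req_list/scalars, and _serialize_leaf live outside this hunk, and "display_label" plus all sample values are made up):

# obj.as_dict(fields=["name", "display_label", "contact", "updates"])
#   "name"          -> plain column value
#   "display_label" -> hybrid/@property computed on the final getattr hop
#   "contact"       -> loaded many-to-one flattened to its columns via _serialize_simple_obj
#   "updates"       -> loaded collection emitted as a light list via _serialize_leaf
# e.g. {"id": 1, "name": "Apollo", "display_label": "Apollo (active)",
#       "contact": {"id": 7, "name": "Dana"}, "updates": [...]}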
@ -169,7 +284,7 @@ class CRUDMixin:
            for root, tails in groups.items():
                root_val = _safe_get_loaded_attr(self, root)
                if isinstance(root_val, (list, tuple)):
                    # one-to-many collection
                    # one-to-many collection → list of dicts with the requested tails
                    out[root] = _serialize_collection(root_val, tails)
                else:
                    # many-to-one or scalar dotted; place each full dotted path as key
@ -177,6 +292,7 @@ class CRUDMixin:
                        dotted = f"{root}.{tail}"
                        out[dotted] = _deep_get_loaded(self, dotted)

            # ← This was the placeholder before. We return the dict we just built.
            return out

        # Fallback: all mapped columns on this class hierarchy
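For dotted tokens, the result shape depends on the relationship kind (a sketch; _serialize_collection and _deep_get_loaded are defined elsewhere in this file, and the field names and values are hypothetical):

# obj.as_dict(fields=["updates.id", "updates.note", "contact.name"])
#   "updates" is a loaded one-to-many -> nested under the root key:
#       {"updates": [{"id": 10, "note": "kickoff"}, ...]}
#   "contact" is a loaded many-to-one -> each full dotted path becomes a flat key:
#       {"contact.name": "Dana"}
# Either way, values come from the already-loaded state; nothing is lazy-loaded.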
@ -243,22 +243,32 @@ class CRUDService(Generic[T]):
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        # JOIN all resolved paths; for collections use selectinload (never join)
        # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor")
        nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }

        used_contains_eager = False
        for base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            if is_collection:
            is_nested_firsthop = rel_attr.key in nested_first_hops

            if is_collection or is_nested_firsthop:
                # Use selectinload so deeper hops can chain cleanly (and to avoid
                # contains_eager/loader conflicts on nested paths).
                opt = selectinload(rel_attr)
                # Narrow child columns if requested (e.g., updates.id, updates.timestamp)
                child_names = (collection_field_names or {}).get(rel_attr.key, [])
                if child_names:
                    target_cls = rel_attr.property.mapper.class_
                    cols = [getattr(target_cls, n, None) for n in child_names]
                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                    if cols:
                        opt = opt.load_only(*cols)

                # Narrow columns for collections if we know child scalar names
                if is_collection:
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)

                query = query.options(opt)
            else:
                # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
                query = query.options(contains_eager(rel_attr, alias=target_alias))
                used_contains_eager = True
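The loader strategy above, restated as a minimal sketch against hypothetical Project/Contact/Update models (this mirrors the pattern; it is not the service code): collections and first hops with deeper tails use selectinload, optionally narrowed with load_only, while a simple scalar first hop can be outer-joined and populated with contains_eager.

from sqlalchemy import select
from sqlalchemy.orm import aliased, contains_eager, selectinload

# Collection ("updates") or a first hop with deeper tails ("contact.supervisor"):
stmt = select(Project).options(
    selectinload(Project.updates).load_only(Update.id, Update.note),
    selectinload(Project.contact).selectinload(Contact.supervisor),
)

# Simple scalar first hop with no deeper tails: outer join + contains_eager,
# so the loaded object comes from the same joined row used for filtering/sorting.
contact_alias = aliased(Contact)
stmt = (
    select(Project)
    .outerjoin(Project.contact.of_type(contact_alias))
    .options(contains_eager(Project.contact.of_type(contact_alias)))
)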
@ -453,19 +463,23 @@ class CRUDService(Generic[T]):
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        # JOIN non-collections only; collections via selectinload
        nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }

        used_contains_eager = False
        for base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            if is_collection:
            is_nested_firsthop = rel_attr.key in nested_first_hops

            if is_collection or is_nested_firsthop:
                opt = selectinload(rel_attr)
                child_names = (collection_field_names or {}).get(rel_attr.key, [])
                if child_names:
                    target_cls = rel_attr.property.mapper.class_
                    cols = [getattr(target_cls, n, None) for n in child_names]
                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                    if cols:
                        opt = opt.load_only(*cols)
                if is_collection:
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)
                query = query.options(opt)
            else:
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
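The nested-first-hop detection reduces to the set comprehension below (this assumes, as the comprehension implies, that rel_field_names is keyed by tuples of relationship path segments; the sample paths are made up):

rel_field_names = {
    ("contact", "supervisor"): ["name"],   # deeper tail -> "contact" must use selectinload
    ("updates",): ["id", "note"],          # single hop  -> not a nested first hop (collections still use selectinload)
}
nested_first_hops = {path[0] for path in rel_field_names if len(path) > 1}
assert nested_first_hops == {"contact"}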
@ -548,19 +562,23 @@ class CRUDService(Generic[T]):
        if only_cols:
            query = query.options(Load(root_alias).load_only(*only_cols))

        # JOIN non-collection paths; selectinload for collections
        nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 }

        used_contains_eager = False
        for _base_alias, rel_attr, target_alias in join_paths:
            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
            if is_collection:
            is_nested_firsthop = rel_attr.key in nested_first_hops

            if is_collection or is_nested_firsthop:
                opt = selectinload(rel_attr)
                child_names = (collection_field_names or {}).get(rel_attr.key, [])
                if child_names:
                    target_cls = rel_attr.property.mapper.class_
                    cols = [getattr(target_cls, n, None) for n in child_names]
                    cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                    if cols:
                        opt = opt.load_only(*cols)
                if is_collection:
                    child_names = (collection_field_names or {}).get(rel_attr.key, [])
                    if child_names:
                        target_cls = rel_attr.property.mapper.class_
                        cols = [getattr(target_cls, n, None) for n in child_names]
                        cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
                        if cols:
                            opt = opt.load_only(*cols)
                query = query.options(opt)
            else:
                query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)