From 4c56149f1b32f311af6785f4941bb8608ed9d6ea Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Fri, 26 Sep 2025 15:50:35 -0500 Subject: [PATCH] Starting to get nested loading working. Still a WIP. --- crudkit/core/base.py | 128 ++++++++++++++++++++++++++++++++++++++-- crudkit/core/service.py | 74 ++++++++++++++--------- 2 files changed, 168 insertions(+), 34 deletions(-) diff --git a/crudkit/core/base.py b/crudkit/core/base.py index 51d9b71..c42b90e 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple, Set from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect -from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE +from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty Base = declarative_base() @@ -17,6 +17,69 @@ def _safe_get_loaded_attr(obj, name): except Exception: return None +def _identity_key(obj) -> Tuple[type, Any]: + try: + st = inspect(obj) + return (type(obj), st.identity_key[1] if st.identity_key else id(obj)) + except Exception: + return (type(obj), id(obj)) + +def _is_collection_rel(prop: RelationshipProperty) -> bool: + try: + return prop.uselist is True + except Exception: + return False + +def _serialize_simple_obj(obj) -> Dict[str, Any]: + """Columns only (no relationships).""" + out: Dict[str, Any] = {} + for cls in obj.__class__.__mro__: + if hasattr(cls, "__table__"): + for col in cls.__table__.columns: + name = col.name + try: + out[name] = getattr(obj, name) + except Exception: + out[name] = None + return out + +def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any: + """ + Serialize relationship 'name' already loaded on obj. + - If in 'embed' (or depth > 0 for depth-based walk), recurse. + - Else, emit the loaded value's columns only; never lazy-load. 
+ """ + val = _safe_get_loaded_attr(obj, name) + if val is None: + return None + + # Decide whether to recurse into this relationship + should_recurse = (depth > 0) or (name in embed) + + if isinstance(val, list): + if not should_recurse: + # Not recursing: emit each child's mapped columns only. + return [_serialize_simple_obj(child) for child in val] + out = [] + for child in val: + ik = _identity_key(child) + if ik in seen: # cycle guard + out.append({"id": getattr(child, "id", None)}) + continue + # No seen.add(ik) here: as_tree() registers the child itself; pre-adding made it return only an id stub. + out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)) + return out + + # Scalar relationship + child = val + if not should_recurse: + return _serialize_simple_obj(child) + ik = _identity_key(child) + if ik in seen: + return {"id": getattr(child, "id", None)} + # No seen.add(ik) here: as_tree() registers the child itself; pre-adding made it return only an id stub. + return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen) + def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]: """ Split requested fields into: @@ -116,6 +179,49 @@ class CRUDMixin: created_at = Column(DateTime, default=func.now(), nullable=False) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) + def as_tree( + self, + *, + embed_depth: int = 0, + embed: Iterable[str] | None = None, + _seen: Set[Tuple[type, Any]] | None = None, + ) -> Dict[str, Any]: + """ + Recursive, NON-LAZY serializer. + - Always includes mapped columns. + - For relationships: only serializes those ALREADY LOADED. + - Recurses either up to embed_depth or for specific names in 'embed'. + - Keeps *_id columns alongside embedded objects. + - Cycle-safe via _seen. 
+ """ + seen = set() if _seen is None else _seen + ik = _identity_key(self) + if ik in seen: + return {"id": getattr(self, "id", None)} + seen.add(ik) + + data = _serialize_simple_obj(self) + + # Determine which relationships to consider + try: + st = inspect(self) + mapper = st.mapper + embed_set = {str(x).split(".", 1)[0] for x in (embed or [])} # top-level names + for name, prop in mapper.relationships.items(): + # Only touch relationships that are already loaded; never lazy-load here. + rel_loaded = st.attrs.get(name) + if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE: + continue + + data[name] = _serialize_loaded_rel( + self, name, depth=embed_depth, seen=seen, embed=embed_set + ) + except Exception: + # If inspection fails, we just return columns. + pass + + return data + def as_dict(self, fields: list[str] | None = None): """ Serialize the instance. @@ -140,25 +246,34 @@ class CRUDMixin: out: Dict[str, Any] = {} - # Always include id unless user explicitly listed fields and included id already + # Always include id unless the caller explicitly listed fields containing id if "id" not in req_list and hasattr(self, "id"): try: out["id"] = getattr(self, "id") except Exception: pass + # Handle scalar tokens (may be columns, hybrids/properties, or relationships) for name in scalars: # Try loaded value first (never lazy-load) val = _safe_get_loaded_attr(self, name) - # if still None, allow a final-hop getattr for root scalars - # so hybrids / @property can compute (they won't traverse relationships). + # Final-hop getattr for root scalars (hybrids/@property) so they can compute. 
if val is None: try: val = getattr(self, name) except Exception: val = None + # If it's a scalar ORM object (relationship), serialize its columns + try: + st = inspect(val) # will raise if not an ORM object + if getattr(st, "mapper", None) is not None: + out[name] = _serialize_simple_obj(val) + continue + except Exception: + pass + # If it's a collection and no subfields were requested, emit a light list if isinstance(val, (list, tuple)): out[name] = [_serialize_leaf(v) for v in val] @@ -169,7 +284,7 @@ class CRUDMixin: for root, tails in groups.items(): root_val = _safe_get_loaded_attr(self, root) if isinstance(root_val, (list, tuple)): - # one-to-many collection + # one-to-many collection → list of dicts with the requested tails out[root] = _serialize_collection(root_val, tails) else: # many-to-one or scalar dotted; place each full dotted path as key @@ -177,6 +292,7 @@ class CRUDMixin: dotted = f"{root}.{tail}" out[dotted] = _deep_get_loaded(self, dotted) + # ← This was the placeholder before. We return the dict we just built. return out # Fallback: all mapped columns on this class hierarchy diff --git a/crudkit/core/service.py b/crudkit/core/service.py index e2e1ab4..b4d7036 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -243,22 +243,32 @@ class CRUDService(Generic[T]): if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # JOIN all resolved paths; for collections use selectinload (never join) + # Detect first hops that have deeper, nested tails requested (e.g. 
"contact.supervisor") + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + used_contains_eager = False for base_alias, rel_attr, target_alias in join_paths: is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) - if is_collection: + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: + # Use selectinload so deeper hops can chain cleanly (and to avoid + # contains_eager/loader conflicts on nested paths). opt = selectinload(rel_attr) - # narroe child columns it requested (e.g., updates.id,updates.timestamp) - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = rel_attr.property.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) + + # Narrow columns for collections if we know child scalar names + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) else: + # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True @@ -453,19 +463,23 @@ class CRUDService(Generic[T]): if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # JOIN non-collections only; collections via selectinload + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + used_contains_eager = False for base_alias, rel_attr, target_alias in join_paths: is_collection = 
bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) - if is_collection: + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: opt = selectinload(rel_attr) - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = rel_attr.property.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) query = query.options(opt) else: query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) @@ -548,19 +562,23 @@ class CRUDService(Generic[T]): if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # JOIN non-collection paths; selectinload for collections + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + used_contains_eager = False for _base_alias, rel_attr, target_alias in join_paths: is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) - if is_collection: + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: opt = selectinload(rel_attr) - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = rel_attr.property.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = 
rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) query = query.options(opt) else: query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)