from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast

from sqlalchemy import Column, Integer, DateTime, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, Mapper
from sqlalchemy.orm.state import InstanceState

from crudkit.core.meta import column_names_for_model

Base = declarative_base()


def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
    """Safely get the SQLAlchemy InstanceState for a mapped instance (or None).

    Returns None when ``obj`` cannot be inspected, or when inspection yields
    something other than an InstanceState (e.g. a Mapper for a mapped class).
    """
    try:
        st = inspect(obj)
    except Exception:
        return None
    if isinstance(st, InstanceState):
        return st
    return None


def _sa_mapper(obj: Any) -> Optional[Mapper]:
    """Safely get the Mapper for a mapped instance (or None)."""
    st = _sa_state(obj)
    if st is None:
        return None
    try:
        return cast(Optional[Mapper], getattr(st, "mapper", None))
    except Exception:
        return None


def _safe_get_loaded_attr(obj, name):
    """Return attribute ``name`` of ``obj`` only if it is already loaded.

    Never triggers a lazy load. Returns None when ``obj`` is not a mapped
    instance, when the attribute is not loaded, or on any error. A loaded
    value of None is indistinguishable from "not loaded" for callers.
    """
    st = _sa_state(obj)
    if st is None:
        return None
    try:
        # Fast path: the value is present in the instance state dict.
        st_dict = getattr(st, "dict", {})
        if name in st_dict:
            return st_dict[name]

        # Look up the per-attribute state; try item access first, then .get().
        attrs = getattr(st, "attrs", None)
        attr = None
        if attrs is not None:
            try:
                attr = attrs[name]
            except Exception:
                try:
                    get = getattr(attrs, "get", None)
                    if callable(get):
                        attr = get(name)
                except Exception:
                    attr = None

        if attr is not None:
            val = attr.loaded_value
            return None if val is NO_VALUE else val

        # In rare cases, state.dict may be stale. Check __dict__ membership
        # BEFORE touching getattr so that properties/loaders are never run
        # for names that are not already materialized on the instance.
        inst_dict = getattr(obj, "__dict__", None)
        if not inst_dict or name not in inst_dict:
            return None
        try:
            from sqlalchemy.orm.attributes import InstrumentedAttribute as _Instr

            got = getattr(obj, name, None)
            # Reject descriptors: only real values are returned.
            if got is not None and not isinstance(got, _Instr):
                return got
        except Exception:
            pass
        return None
    except Exception:
        return None


def _identity_key(obj) -> Tuple[type, Any]:
    """Return a hashable key identifying ``obj`` for cycle detection.

    Uses the full primary-key tuple from the instance's identity key when one
    exists (so composite keys do not collide); otherwise falls back to
    ``id(obj)``.
    """
    try:
        ident = inspect(obj).identity_key
        if ident:
            return (type(obj), ident[1])
    except Exception:
        pass
    return (type(obj), id(obj))


def _serialize_simple_obj(obj) -> Dict[str, Any]:
    """Columns only (no relationships).

    A column whose getattr raises is emitted as None.
    """
    out: Dict[str, Any] = {}
    for name in column_names_for_model(type(obj)):
        try:
            out[name] = getattr(obj, name)
        except Exception:
            out[name] = None
    return out


def _embed_child(child, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """Recursively serialize one related object, guarding against cycles."""
    if _identity_key(child) in seen:
        # Cycle guard: already serialized somewhere up the walk.
        return {"id": getattr(child, "id", None)}
    as_tree = getattr(child, "as_tree", None)
    if not callable(as_tree):
        # Child does not provide as_tree; emit its columns only.
        return _serialize_simple_obj(child)
    # as_tree registers the child in `seen` itself. Adding it here first would
    # make as_tree see the child as "already visited" and return only a stub.
    return as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)


def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """
    Serialize relationship 'name' already loaded on obj.

    - If the relationship is not loaded (or is None), return None; never
      lazy-load.
    - If 'name' is in 'embed' (or depth > 0 for a depth-based walk), recurse
      into the related object(s) via as_tree.
    - Otherwise, emit the related object(s) as column-only dicts.
    """
    val = _safe_get_loaded_attr(obj, name)
    if val is None:
        return None

    # Decide whether to recurse into this relationship
    should_recurse = (depth > 0) or (name in embed)

    if isinstance(val, list):
        if not should_recurse:
            # Light list of child column data, without recursion.
            return [_serialize_simple_obj(child) for child in val]
        return [_embed_child(child, depth=depth, seen=seen, embed=embed) for child in val]

    # Scalar relationship
    if not should_recurse:
        return _serialize_simple_obj(val)
    return _embed_child(val, depth=depth, seen=seen, embed=embed)


def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
    """
    Split requested fields into:
      - scalars: ["label", "name"]
      - collections: {"updates": ["id", "timestamp", "content"], "owner": ["label"]}

    Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
    Bare tokens ("foo") land in scalars. Blank tokens and tokens with an empty
    root or tail (".x", "x.") are dropped.
    """
    scalars: List[str] = []
    groups: Dict[str, List[str]] = {}
    for raw in fields:
        f = str(raw).strip()
        if not f:
            continue
        # bare token -> scalar
        if "." not in f:
            scalars.append(f)
            continue
        # dotted token -> group under root
        root, tail = f.split(".", 1)
        if not root or not tail:
            continue
        groups.setdefault(root, []).append(tail)
    return scalars, groups


def _deep_get_loaded(obj: Any, dotted: str) -> Any:
    """
    Deep get with no lazy loads:

    - Walk intermediate hops via _safe_get_loaded_attr only.
    - Final hop: use _safe_get_loaded_attr; if that returns an ORM object or a
      list/tuple containing ORM objects, serialize those to column-only dicts;
      else return the plain value.
    - Any hop that is not loaded yields None.
    """
    *parents, last = dotted.split(".")

    cur = obj
    # Traverse up to the parent of the last token without lazy-loads
    for part in parents:
        if cur is None:
            return None
        cur = _safe_get_loaded_attr(cur, part)
    if cur is None:
        return None

    val = _safe_get_loaded_attr(cur, last)
    if val is None:
        # Do NOT lazy load. If it isn't loaded, treat as absent.
        return None

    # If the final hop is an ORM object or collection, serialize it
    if _sa_mapper(val) is not None:
        return _serialize_simple_obj(val)
    if isinstance(val, (list, tuple)):
        return [
            _serialize_simple_obj(v) if _sa_mapper(v) is not None else v
            for v in val
        ]

    # Plain scalar/computed value
    return val


def _serialize_leaf(obj: Any) -> Any:
    """
    Leaf serialization for non-dotted scalar fields:

    - If it's an ORM object with a callable as_dict(), use it (called with no
      arguments); if that call raises, fall back to columns only.
    - Else if it's an ORM object, serialize columns only.
    - Else return the value as-is.
    """
    if obj is None:
        return None
    if _sa_mapper(obj) is None:
        return obj
    ad = getattr(obj, "as_dict", None)
    if callable(ad):
        try:
            return ad()
        except Exception:
            pass
    return _serialize_simple_obj(obj)


def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
    """
    Turn a collection of ORM objects into list[dict] with exactly requested_tails,
    where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load.

    Each row also gets an "id" key when the child has an ``id`` attribute and
    "id" was not among the requested tails.
    """
    out: List[Dict[str, Any]] = []
    # Deduplicate while preserving order
    uniq_tails = list(dict.fromkeys(requested_tails))
    for child in (items or []):
        row: Dict[str, Any] = {tail: _deep_get_loaded(child, tail) for tail in uniq_tails}
        # ensure id present if exists and not already requested
        try:
            if "id" not in row and hasattr(child, "id"):
                row["id"] = getattr(child, "id")
        except Exception:
            pass
        out.append(row)
    return out


@declarative_mixin
class CRUDMixin:
    """Mixin providing id/created_at/updated_at columns and dict serializers."""

    id = Column(Integer, primary_key=True)
    created_at = Column(DateTime, default=func.now(), nullable=False)
    updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())

    def as_tree(
        self,
        *,
        embed_depth: int = 0,
        embed: Iterable[str] | None = None,
        _seen: Set[Tuple[type, Any]] | None = None,
    ) -> Dict[str, Any]:
        """
        Recursive, NON-LAZY serializer.

        - Always includes mapped columns.
        - For relationships: only serializes those ALREADY LOADED.
        - Recurses either up to embed_depth or for specific names in 'embed'
          (only the first dotted segment of each 'embed' entry is used).
        - Keeps *_id columns alongside embedded objects.
        - Cycle-safe via _seen: an instance already in _seen is emitted as
          ``{"id": ...}`` only.
        """
        # `is None` (not truthiness): an empty set passed by the caller must be
        # reused so the whole walk shares one visited-set.
        seen = set() if _seen is None else _seen
        ik = _identity_key(self)
        if ik in seen:
            return {"id": getattr(self, "id", None)}
        seen.add(ik)

        data = _serialize_simple_obj(self)

        # Determine which relationships to consider
        try:
            mapper = _sa_mapper(self)
            embed_set = {str(x).split(".", 1)[0] for x in (embed or [])}
            if mapper is None:
                return data
            st = _sa_state(self)
            if st is None:
                return data
            for name, _prop in mapper.relationships.items():
                # Only touch relationships that are already loaded; never lazy-load here.
                rel_loaded = getattr(st, "attrs", {}).get(name)
                if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
                    continue
                data[name] = _serialize_loaded_rel(
                    self, name, depth=embed_depth, seen=seen, embed=embed_set
                )
        except Exception:
            # If inspection fails, we just return what has been built so far.
            pass
        return data

    def as_dict(self, fields: list[str] | None = None) -> Dict[str, Any]:
        """
        Serialize the instance.

        Behavior:
        - If 'fields' (possibly dotted) is provided and non-empty, emit those keys.
          * Bare tokens (e.g., "label", "owner") return the current loaded value,
            falling back to a plain getattr when nothing is loaded.
          * Dotted tokens for one-to-many (e.g., "updates.id", "updates.timestamp")
            produce a single "updates" key containing a list of dicts with the
            requested child keys.
          * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit
            the scalar under "owner.label".
        - Else, if '__crudkit_projection__' is set on the instance, use that.
        - Else, fall back to all mapped columns on this class hierarchy.

        'id' is added to a projected result when the instance has an ``id``
        attribute, even if 'id' was not listed in the requested fields.
        """
        req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)

        if req:
            # Normalize and split into (scalars, groups of dotted by root)
            req_list = [p for p in (str(x).strip() for x in req) if p]
            scalars, groups = _split_field_tokens(req_list)

            out: Dict[str, Any] = {}

            # Add id when the request did not list it itself
            if "id" not in req_list and hasattr(self, "id"):
                try:
                    out["id"] = getattr(self, "id")
                except Exception:
                    pass

            # Handle scalar tokens (may be columns, hybrids/properties, or relationships)
            for name in scalars:
                # Try loaded value first
                val = _safe_get_loaded_attr(self, name)

                # Final-hop getattr for root scalars (hybrids/@property) so they can compute.
                # NOTE(review): for an unloaded relationship named as a bare
                # token this getattr may trigger a lazy load — confirm intended.
                if val is None:
                    try:
                        val = getattr(self, name)
                    except Exception:
                        val = None

                # If it's a scalar ORM object (relationship), serialize its columns
                if _sa_mapper(val) is not None:
                    out[name] = _serialize_simple_obj(val)
                elif isinstance(val, (list, tuple)):
                    out[name] = [_serialize_leaf(v) for v in val]
                else:
                    out[name] = val

            # Handle dotted groups: root -> [tails]
            for root, tails in groups.items():
                root_val = _safe_get_loaded_attr(self, root)
                if isinstance(root_val, (list, tuple)):
                    # one-to-many collection -> list of dicts with the requested tails
                    out[root] = _serialize_collection(root_val, tails)
                else:
                    # many-to-one or scalar dotted; place each full dotted path as key
                    for tail in tails:
                        dotted = f"{root}.{tail}"
                        out[dotted] = _deep_get_loaded(self, dotted)

            return out

        # Fallback: all mapped columns on this class hierarchy
        result: Dict[str, Any] = {}
        for cls in self.__class__.__mro__:
            if hasattr(cls, "__table__"):
                for column in cls.__table__.columns:
                    name = column.name
                    try:
                        result[name] = getattr(self, name)
                    except Exception:
                        result[name] = None
        return result


class Version(Base):
    """Row in the "versions" table describing one change to a model object."""

    __tablename__ = "versions"

    id = Column(Integer, primary_key=True)
    model_name = Column(String, nullable=False)
    object_id = Column(Integer, nullable=False)
    change_type = Column(String, nullable=False)
    data = Column(JSON, nullable=True)
    timestamp = Column(DateTime, default=func.now())

    actor = Column(String, nullable=True)
    # Stored in the "metadata" column; exposed as ``meta`` on the class.
    meta = Column('metadata', JSON, nullable=True)