# inventory/crudkit/core/base.py

from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState
Base = declarative_base()
@lru_cache(maxsize=512)
def _column_names_for_model(cls: type) -> tuple[str, ...]:
    try:
        mapper = inspect(cls)
        return tuple(prop.key for prop in mapper.column_attrs)
    except Exception:
        names: list[str] = []
        for c in cls.__mro__:
            if hasattr(c, "__table__"):
                names.extend(col.name for col in c.__table__.columns)
        return tuple(dict.fromkeys(names))
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
    """Safely get the SQLAlchemy InstanceState (or None)."""
    try:
        st = inspect(obj)
        return cast(Optional[InstanceState[Any]], st)
    except Exception:
        return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
    """Safely get the Mapper for a mapped instance (or None)."""
    try:
        st = inspect(obj)
        mapper = getattr(st, "mapper", None)
        return cast(Optional[Mapper], mapper)
    except Exception:
        return None
def _safe_get_loaded_attr(obj, name):
    """Return the already-loaded value of mapped attribute 'name' (or None) without lazy-loading."""
    st = _sa_state(obj)
    if st is None:
        return None
    try:
        st_dict = getattr(st, "dict", {})
        if name in st_dict:
            return st_dict[name]
        attrs = getattr(st, "attrs", None)
        attr = None
        if attrs is not None:
            try:
                attr = attrs[name]
            except Exception:
                try:
                    get = getattr(attrs, "get", None)
                    if callable(get):
                        attr = get(name)
                except Exception:
                    attr = None
        if attr is not None:
            val = attr.loaded_value
            return None if val is NO_VALUE else val
        return None
    except Exception:
        return None
def _identity_key(obj) -> Tuple[type, Any]:
    try:
        st = inspect(obj)
        return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
    except Exception:
        return (type(obj), id(obj))
def _is_collection_rel(prop: RelationshipProperty) -> bool:
    try:
        return prop.uselist is True
    except Exception:
        return False
def _serialize_simple_obj(obj) -> Dict[str, Any]:
    """Columns only (no relationships)."""
    out: Dict[str, Any] = {}
    for name in _column_names_for_model(type(obj)):
        try:
            out[name] = getattr(obj, name)
        except Exception:
            out[name] = None
    return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """
    Serialize relationship 'name' already loaded on obj.
    - If in 'embed' (or depth > 0 for a depth-based walk), recurse.
    - Else, emit a shallow columns-only representation (don't lazy-load).
    """
    val = _safe_get_loaded_attr(obj, name)
    if val is None:
        return None
    # Decide whether to recurse into this relationship
    should_recurse = (depth > 0) or (name in embed)
    if isinstance(val, list):
        if not should_recurse:
            # Emit a light list of child primary data (id + a couple columns) without recursion.
            return [_serialize_simple_obj(child) for child in val]
        out = []
        for child in val:
            ik = _identity_key(child)
            if ik in seen:  # cycle guard
                out.append({"id": getattr(child, "id", None)})
                continue
            seen.add(ik)
            out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
        return out
    # Scalar relationship
    child = val
    if not should_recurse:
        return _serialize_simple_obj(child)
    ik = _identity_key(child)
    if ik in seen:
        return {"id": getattr(child, "id", None)}
    seen.add(ik)
    return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)
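# Illustrative behavior for a loaded one-to-many (hypothetical "updates" relationship):
# with depth=0 and the name not in embed, each child is emitted as its shallow column dict;
# with depth>0 or "updates" in embed, each child recurses via child.as_tree(), with 'seen'
# guarding against parent/child cycles.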
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
    """
    Split requested fields into:
    - scalars: ["label", "name"]
    - collections: {"updates": ["id", "timestamp", "content"], "owner": ["label"]}
    Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
    Bare tokens ("foo") land in scalars.
    """
    scalars: List[str] = []
    groups: Dict[str, List[str]] = {}
    for raw in fields:
        f = str(raw).strip()
        if not f:
            continue
        # bare token -> scalar
        if "." not in f:
            scalars.append(f)
            continue
        # dotted token -> group under root
        root, tail = f.split(".", 1)
        if not root or not tail:
            continue
        groups.setdefault(root, []).append(tail)
    return scalars, groups
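# Illustrative split (hypothetical field names, mirroring the docstring above):
#   _split_field_tokens(["label", "updates.id", "updates.timestamp", "owner.label"])
#   -> (["label"], {"updates": ["id", "timestamp"], "owner": ["label"]})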
def _deep_get_loaded(obj: Any, dotted: str) -> Any:
    """
    Deep get with no lazy loads:
    - For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr).
    - For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr()
      to allow computed properties/hybrids that rely on already-loaded columns.
    """
    parts = dotted.split(".")
    if not parts:
        return None
    cur = obj
    # Traverse up to the parent of the last token safely
    for part in parts[:-1]:
        if cur is None:
            return None
        cur = _safe_get_loaded_attr(cur, part)
    if cur is None:
        return None
    last = parts[-1]
    # Try safe fetch on the last hop first
    val = _safe_get_loaded_attr(cur, last)
    if val is not None:
        return val
    # Fall back to getattr for computed/hybrid attributes on an already-loaded object
    try:
        return getattr(cur, last, None)
    except Exception:
        return None
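# Illustrative deep get (hypothetical names): if obj.owner is already loaded,
#   _deep_get_loaded(obj, "owner.label") returns owner.label (the final hop may use
#   getattr so hybrids/@property can compute); if "owner" was never loaded, the call
#   returns None instead of triggering a lazy load.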
def _serialize_leaf(obj: Any) -> Any:
    """
    Leaf serialization for values we put into as_dict():
    - If the object has as_dict(), call as_dict() with no fields (caller controls field shapes).
    - Else return the value as-is (Flask/JSON encoder will handle datetimes, etc., via app config).
    """
    if obj is None:
        return None
    ad = getattr(obj, "as_dict", None)
    if callable(ad):
        try:
            return ad(None)
        except Exception:
            return str(obj)
    return obj
def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
    """
    Turn a collection of ORM objects into list[dict] with exactly requested_tails,
    where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load.
    """
    out: List[Dict[str, Any]] = []
    # Deduplicate while preserving order
    uniq_tails = list(dict.fromkeys(requested_tails))
    for child in (items or []):
        row: Dict[str, Any] = {}
        for tail in uniq_tails:
            row[tail] = _deep_get_loaded(child, tail)
        # ensure id is present if it exists and was not already requested
        try:
            if "id" not in row and hasattr(child, "id"):
                row["id"] = getattr(child, "id")
        except Exception:
            pass
        out.append(row)
    return out
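# Illustrative shape (hypothetical "updates" children with id/timestamp columns):
#   _serialize_collection(obj.updates, ["id", "timestamp"])
#   -> [{"id": 1, "timestamp": ...}, {"id": 2, "timestamp": ...}]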
@declarative_mixin
class CRUDMixin:
    id = Column(Integer, primary_key=True)
    created_at = Column(DateTime, default=func.now(), nullable=False)
    updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
    def as_tree(
        self,
        *,
        embed_depth: int = 0,
        embed: Iterable[str] | None = None,
        _seen: Set[Tuple[type, Any]] | None = None,
    ) -> Dict[str, Any]:
        """
        Recursive, NON-LAZY serializer.
        - Always includes mapped columns.
        - For relationships: only serializes those ALREADY LOADED.
        - Recurses either up to embed_depth or for specific names in 'embed'.
        - Keeps *_id columns alongside embedded objects.
        - Cycle-safe via _seen.
        """
        seen = _seen or set()
        ik = _identity_key(self)
        if ik in seen:
            return {"id": getattr(self, "id", None)}
        seen.add(ik)
        data = _serialize_simple_obj(self)
        # Determine which relationships to consider
        try:
            mapper = _sa_mapper(self)
            embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
            if mapper is None:
                return data
            st = _sa_state(self)
            if st is None:
                return data
            for name, prop in mapper.relationships.items():
                # Only touch relationships that are already loaded; never lazy-load here.
                rel_loaded = getattr(st, "attrs", {}).get(name)
                if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
                    continue
                data[name] = _serialize_loaded_rel(
                    self, name, depth=embed_depth, seen=seen, embed=embed_set
                )
        except Exception:
            # If inspection fails, we just return columns.
            pass
        return data
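    # Illustrative calls (hypothetical relationship names):
    #   obj.as_tree()                   -> columns + shallow dicts for already-loaded relationships
    #   obj.as_tree(embed_depth=1)      -> additionally recurses one level into loaded relationships
    #   obj.as_tree(embed=["updates"])  -> recurses only into the already-loaded "updates" relationship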
    def as_dict(self, fields: list[str] | None = None):
        """
        Serialize the instance.
        Behavior:
        - If 'fields' (possibly dotted) is provided, emit exactly those keys.
          * Bare tokens (e.g., "label", "owner") return the current loaded value.
          * Dotted tokens for one-to-many (e.g., "updates.id", "updates.timestamp")
            produce a single "updates" key containing a list of dicts with the requested child keys.
          * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label".
        - Else, if '__crudkit_projection__' is set on the instance, use that.
        - Else, fall back to all mapped columns on this class hierarchy.
        Always includes 'id' when the attribute exists, even if it was not explicitly requested.
        """
        req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)
        if req:
            # Normalize and split into (scalars, groups of dotted by root)
            req_list = [p for p in (str(x).strip() for x in req) if p]
            scalars, groups = _split_field_tokens(req_list)
            out: Dict[str, Any] = {}
            # Always include id; skip it here if the caller listed it (the scalars loop adds it)
            if "id" not in req_list and hasattr(self, "id"):
                try:
                    out["id"] = getattr(self, "id")
                except Exception:
                    pass
            # Handle scalar tokens (may be columns, hybrids/properties, or relationships)
            for name in scalars:
                # Try the loaded value first (never lazy-load)
                val = _safe_get_loaded_attr(self, name)
                # Final-hop getattr for root scalars (hybrids/@property) so they can compute.
                if val is None:
                    try:
                        val = getattr(self, name)
                    except Exception:
                        val = None
                # If it's a scalar ORM object (relationship), serialize its columns
                mapper = _sa_mapper(val)
                if mapper is not None:
                    out[name] = _serialize_simple_obj(val)
                    continue
                # If it's a collection and no subfields were requested, emit a light list
                if isinstance(val, (list, tuple)):
                    out[name] = [_serialize_leaf(v) for v in val]
                else:
                    out[name] = val
            # Handle dotted groups: root -> [tails]
            for root, tails in groups.items():
                root_val = _safe_get_loaded_attr(self, root)
                if isinstance(root_val, (list, tuple)):
                    # one-to-many collection -> list of dicts with the requested tails
                    out[root] = _serialize_collection(root_val, tails)
                else:
                    # many-to-one or scalar dotted; place each full dotted path as key
                    for tail in tails:
                        dotted = f"{root}.{tail}"
                        out[dotted] = _deep_get_loaded(self, dotted)
            # Return the projected dict built above.
            return out
        # Fallback: all mapped columns on this class hierarchy
        result: Dict[str, Any] = {}
        for cls in self.__class__.__mro__:
            if hasattr(cls, "__table__"):
                for column in cls.__table__.columns:
                    name = column.name
                    try:
                        result[name] = getattr(self, name)
                    except Exception:
                        result[name] = None
        return result
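# Usage sketch for field projection (hypothetical model and field names; CRUDMixin
# itself only supplies id/created_at/updated_at and the serializers above):
#   project.as_dict(fields=["label", "updates.id", "updates.timestamp", "owner.label"])
#   -> {"id": 7, "label": ..., "updates": [{"id": 1, "timestamp": ...}, ...],
#       "owner.label": ...}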
class Version(Base):
    __tablename__ = "versions"
    id = Column(Integer, primary_key=True)
    model_name = Column(String, nullable=False)
    object_id = Column(Integer, nullable=False)
    change_type = Column(String, nullable=False)
    data = Column(JSON, nullable=True)
    timestamp = Column(DateTime, default=func.now())
    actor = Column(String, nullable=True)
    meta = Column('metadata', JSON, nullable=True)
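if __name__ == "__main__":  # pragma: no cover
    # Minimal smoke-test sketch, not part of crudkit proper: the "Note" model, its
    # columns, and the in-memory SQLite database below are assumptions for illustration.
    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    class Note(CRUDMixin, Base):
        __tablename__ = "notes"
        label = Column(String, nullable=True)

    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        note = Note(label="demo")
        session.add(note)
        session.commit()
        # Field projection and the non-lazy tree serializer defined above.
        print(note.as_dict(fields=["label"]))
        print(note.as_tree())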