# SQLAlchemy serialization helpers: CRUDMixin (non-lazy serializers) and Version audit model.
# Standard library.
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast

# SQLAlchemy core and ORM.
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState

# Shared declarative base for models in this module.
Base = declarative_base()
@lru_cache(maxsize=512)
|
||
def _column_names_for_model(cls: type) -> tuple[str, ...]:
|
||
try:
|
||
mapper = inspect(cls)
|
||
return tuple(prop.key for prop in mapper.column_attrs)
|
||
except Exception:
|
||
names: list[str] = []
|
||
for c in cls.__mro__:
|
||
if hasattr(c, "__table__"):
|
||
names.extend(col.name for col in c.__table__.columns)
|
||
return tuple(dict.fromkeys(names))
|
||
|
||
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
    """Return the SQLAlchemy ``InstanceState`` for *obj*, or None.

    None is returned when *obj* is not an ORM-mapped instance (or when
    inspection raises for any other reason).
    """
    try:
        return cast(Optional[InstanceState[Any]], inspect(obj))
    except Exception:
        return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
    """Return the Mapper for a mapped instance, or None.

    None is returned when *obj* is not ORM-mapped or inspection fails.
    """
    try:
        state = inspect(obj)
        return cast(Optional[Mapper], getattr(state, "mapper", None))
    except Exception:
        return None
def _safe_get_loaded_attr(obj: Any, name: str) -> Any:
    """Return the already-loaded value of attribute *name* on *obj*.

    Never triggers a lazy load. Returns None when *obj* is not ORM-mapped,
    the attribute is unloaded (``NO_VALUE``), or anything raises.
    """
    st = _sa_state(obj)
    if st is None:
        return None
    try:
        # BUG FIX: the previous code did `st.attrs.get(name)` and then
        # indexed the result with `name` a second time; that lookup always
        # failed, tripping the broad `except` and skipping the `st.dict`
        # fallback below.  Look the attribute state up exactly once.
        attrs = getattr(st, "attrs", None)
        if attrs is not None and name in attrs:
            val = attrs[name].loaded_value
            return None if val is NO_VALUE else val
        # Fallback: the instance dict (covers values present without
        # instrumented attribute state).
        st_dict = getattr(st, "dict", {})
        if name in st_dict:
            return st_dict.get(name)
        return None
    except Exception:
        return None
def _identity_key(obj) -> Tuple[type, Any]:
|
||
try:
|
||
st = inspect(obj)
|
||
return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj))
|
||
except Exception:
|
||
return (type(obj), id(obj))
|
||
|
||
def _is_collection_rel(prop: RelationshipProperty) -> bool:
    """True when the relationship maps a collection (``uselist``)."""
    try:
        uses_list = prop.uselist
    except Exception:
        return False
    return uses_list is True
def _serialize_simple_obj(obj) -> Dict[str, Any]:
    """Serialize only the mapped columns of *obj*; relationships excluded.

    Attribute reads that raise produce None for that key.
    """
    result: Dict[str, Any] = {}
    for column_name in _column_names_for_model(type(obj)):
        try:
            value = getattr(obj, column_name)
        except Exception:
            value = None
        result[column_name] = value
    return result
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
    """Serialize relationship *name*, which must already be loaded on *obj*.

    Recurses into children (via their ``as_tree``) when ``depth > 0`` or
    *name* is listed in ``embed``; otherwise emits a flat, columns-only
    view of each child.  Never triggers a lazy load: the value is fetched
    through _safe_get_loaded_attr, so unloaded relationships come back as
    None.

    ``seen`` holds identity keys of objects already serialized on this
    walk; revisited objects collapse to ``{"id": ...}`` to break cycles.
    """
    val = _safe_get_loaded_attr(obj, name)
    if val is None:
        # Unloaded (or genuinely None) relationship: nothing to emit.
        return None

    # Decide whether to recurse into this relationship.
    should_recurse = (depth > 0) or (name in embed)

    if isinstance(val, list):
        if not should_recurse:
            # Shallow mode: a columns-only dict per child, no recursion.
            return [_serialize_simple_obj(child) for child in val]
        out = []
        for child in val:
            ik = _identity_key(child)
            if ik in seen:  # cycle guard
                out.append({"id": getattr(child, "id", None)})
                continue
            seen.add(ik)
            # Depth is clamped at 0 so it never goes negative on recursion.
            out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen))
        return out

    # Scalar relationship (many-to-one / one-to-one).
    child = val
    if not should_recurse:
        return _serialize_simple_obj(child)
    ik = _identity_key(child)
    if ik in seen:
        # Cycle guard for the scalar case.
        return {"id": getattr(child, "id", None)}
    seen.add(ik)
    return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)
def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]:
|
||
"""
|
||
Split requested fields into:
|
||
- scalars: ["label", "name"]
|
||
- collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]}
|
||
Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path").
|
||
Bare tokens ("foo") land in scalars.
|
||
"""
|
||
scalars: List[str] = []
|
||
groups: Dict[str, List[str]] = {}
|
||
for raw in fields:
|
||
f = str(raw).strip()
|
||
if not f:
|
||
continue
|
||
# bare token -> scalar
|
||
if "." not in f:
|
||
scalars.append(f)
|
||
continue
|
||
# dotted token -> group under root
|
||
root, tail = f.split(".", 1)
|
||
if not root or not tail:
|
||
continue
|
||
groups.setdefault(root, []).append(tail)
|
||
return scalars, groups
|
||
|
||
def _deep_get_loaded(obj: Any, dotted: str) -> Any:
    """Resolve a dotted attribute path without triggering lazy loads.

    Intermediate hops use only already-loaded mapped state
    (_safe_get_loaded_attr).  The final hop additionally falls back to
    plain ``getattr`` so computed properties / hybrids backed by loaded
    columns can still evaluate.  Returns None whenever any hop is missing
    or unloaded.
    """
    parts = dotted.split(".")
    if not parts:
        return None

    *intermediate, last = parts
    current = obj
    # Walk safely down to the parent of the final token.
    for hop in intermediate:
        if current is None:
            return None
        current = _safe_get_loaded_attr(current, hop)
        if current is None:
            return None

    # Final hop: loaded state first, then getattr for computed attributes.
    loaded = _safe_get_loaded_attr(current, last)
    if loaded is not None:
        return loaded
    try:
        return getattr(current, last, None)
    except Exception:
        return None
def _serialize_leaf(obj: Any) -> Any:
|
||
"""
|
||
Lead serialization for values we put into as_dict():
|
||
- If object has as_dict(), call as_dict() with no args (caller controls field shapes).
|
||
- Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config).
|
||
"""
|
||
if obj is None:
|
||
return None
|
||
ad = getattr(obj, "as_dict", None)
|
||
if callable(ad):
|
||
try:
|
||
return ad(None)
|
||
except Exception:
|
||
return str(obj)
|
||
return obj
|
||
|
||
def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]:
    """Project each ORM object in *items* onto exactly *requested_tails*.

    Tails may themselves be dotted (e.g. ``"author.label"``) and are
    resolved via _deep_get_loaded, so nothing is lazy-loaded.  An ``"id"``
    key is added to each row when the child has one and it was not
    explicitly requested.
    """
    # Deduplicate tails while preserving request order.
    tails = list(dict.fromkeys(requested_tails))
    rows: List[Dict[str, Any]] = []
    for child in (items or []):
        row: Dict[str, Any] = {tail: _deep_get_loaded(child, tail) for tail in tails}
        try:
            if "id" not in row and hasattr(child, "id"):
                row["id"] = getattr(child, "id")
        except Exception:
            pass
        rows.append(row)
    return rows
@declarative_mixin
class CRUDMixin:
    """Declarative mixin adding a surrogate key, timestamps, and two
    non-lazy serializers (``as_tree`` and ``as_dict``)."""

    # Surrogate primary key shared by all models using this mixin.
    id = Column(Integer, primary_key=True)
    # Creation / last-modification timestamps, set DB-side via now().
    created_at = Column(DateTime, default=func.now(), nullable=False)
    updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())

    def as_tree(
        self,
        *,
        embed_depth: int = 0,
        embed: Iterable[str] | None = None,
        _seen: Set[Tuple[type, Any]] | None = None,
    ) -> Dict[str, Any]:
        """
        Recursive, NON-LAZY serializer.

        - Always includes mapped columns.
        - For relationships: only serializes those ALREADY LOADED.
        - Recurses either up to embed_depth or for specific names in 'embed'.
        - Keeps *_id columns alongside embedded objects.
        - Cycle-safe via _seen (revisited objects collapse to {"id": ...}).
        """
        seen = _seen or set()
        ik = _identity_key(self)
        if ik in seen:
            # Already serialized on this walk: break the cycle.
            return {"id": getattr(self, "id", None)}
        seen.add(ik)

        # Column-only baseline; relationships are layered on top below.
        data = _serialize_simple_obj(self)

        # Determine which relationships to consider.
        try:
            mapper = _sa_mapper(self)
            # Only the roots of dotted embed tokens matter at this level.
            embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
            if mapper is None:
                return data
            st = _sa_state(self)
            if st is None:
                return data
            for name, prop in mapper.relationships.items():
                # Only touch relationships that are already loaded; never
                # lazy-load here.
                # NOTE(review): st.attrs may not expose .get(); if it
                # raises, the outer except returns columns only — confirm
                # against the SQLAlchemy version in use.
                rel_loaded = getattr(st, "attrs", {}).get(name)
                if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
                    continue

                data[name] = _serialize_loaded_rel(
                    self, name, depth=embed_depth, seen=seen, embed=embed_set
                )
        except Exception:
            # If inspection fails, we just return columns.
            pass

        return data

    def as_dict(self, fields: list[str] | None = None) -> Dict[str, Any]:
        """
        Serialize the instance.

        Behavior:
        - If 'fields' (possibly dotted) is provided, emit exactly those keys.
          * Bare tokens (e.g., "label", "owner") return the current loaded value.
          * Dotted tokens for one-to-many (e.g., "updates.id", "updates.timestamp")
            produce a single "updates" key containing a list of dicts with the
            requested child keys.
          * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit
            the scalar under the full dotted key "owner.label".
        - Else, if '__crudkit_projection__' is set on the instance, use that.
        - Else, fall back to all mapped columns on this class hierarchy.

        When a projection is used, 'id' is added to the output even if it was
        not requested (best-effort; read failures are ignored).
        """
        # Explicit fields win over any instance-level default projection.
        req = fields if fields is not None else getattr(self, "__crudkit_projection__", None)

        if req:
            # Normalize and split into (scalars, groups of dotted by root).
            req_list = [p for p in (str(x).strip() for x in req) if p]
            scalars, groups = _split_field_tokens(req_list)

            out: Dict[str, Any] = {}

            # Add id when it was not requested explicitly.
            if "id" not in req_list and hasattr(self, "id"):
                try:
                    out["id"] = getattr(self, "id")
                except Exception:
                    pass

            # Handle scalar tokens (may be columns, hybrids/properties, or
            # relationships).
            for name in scalars:
                # Try loaded value first (never lazy-load).
                val = _safe_get_loaded_attr(self, name)

                # Final-hop getattr for root scalars (hybrids/@property) so
                # they can compute from already-loaded columns.
                if val is None:
                    try:
                        val = getattr(self, name)
                    except Exception:
                        val = None

                # If it's a scalar ORM object (relationship), serialize its
                # columns only.
                mapper = _sa_mapper(val)
                if mapper is not None:
                    out[name] = _serialize_simple_obj(val)
                    continue

                # If it's a collection and no subfields were requested, emit
                # a light list of leaf-serialized items.
                if isinstance(val, (list, tuple)):
                    out[name] = [_serialize_leaf(v) for v in val]
                else:
                    out[name] = val

            # Handle dotted groups: root -> [tails].
            for root, tails in groups.items():
                root_val = _safe_get_loaded_attr(self, root)
                if isinstance(root_val, (list, tuple)):
                    # One-to-many collection -> list of dicts with the
                    # requested tails.
                    out[root] = _serialize_collection(root_val, tails)
                else:
                    # Many-to-one or scalar dotted; place each full dotted
                    # path as its own key.
                    for tail in tails:
                        dotted = f"{root}.{tail}"
                        out[dotted] = _deep_get_loaded(self, dotted)

            # Projected serialization complete.
            return out

        # Fallback: all mapped columns on this class hierarchy (read failures
        # become None so one bad column doesn't sink the whole dict).
        result: Dict[str, Any] = {}
        for cls in self.__class__.__mro__:
            if hasattr(cls, "__table__"):
                for column in cls.__table__.columns:
                    name = column.name
                    try:
                        result[name] = getattr(self, name)
                    except Exception:
                        result[name] = None
        return result
class Version(Base):
    """History row recording a single change made to a tracked model."""

    __tablename__ = "versions"

    id = Column(Integer, primary_key=True)
    # Name of the model the change applies to — presumably the mapped
    # class/table name; confirm against the writer of these rows.
    model_name = Column(String, nullable=False)
    # Primary key of the changed row in that model's table.
    object_id = Column(Integer, nullable=False)
    # Kind of change (values not visible here — TODO confirm, e.g.
    # create/update/delete).
    change_type = Column(String, nullable=False)
    # JSON payload of the change; exact structure set by the caller.
    data = Column(JSON, nullable=True)
    timestamp = Column(DateTime, default=func.now())

    # Who performed the change, when recorded.
    actor = Column(String, nullable=True)
    # Stored under DB column name "metadata" ('metadata' is reserved as a
    # Python attribute on declarative classes, hence the rename).
    meta = Column('metadata', JSON, nullable=True)