Optimization and refactoring pass.

This commit is contained in:
Yaro Kasear 2025-10-20 11:03:03 -05:00
parent 15ae0caf27
commit e829de9792
2 changed files with 130 additions and 54 deletions

View file

@ -1,18 +1,53 @@
from typing import Any, Dict, Iterable, List, Tuple, Set
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState
Base = declarative_base()
def _safe_get_loaded_attr(obj, name):
@lru_cache(maxsize=512)
def _column_names_for_model(cls: type) -> tuple[str, ...]:
try:
mapper = inspect(cls)
return tuple(prop.key for prop in mapper.column_attrs)
except Exception:
names: list[str] = []
for c in cls.__mro__:
if hasattr(c, "__table__"):
names.extend(col.name for col in c.__table__.columns)
return tuple(dict.fromkeys(names))
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
"""Safely get SQLAlchemy InstanceState (or None)."""
try:
st = inspect(obj)
attr = st.attrs.get(name)
if attr is not None:
return cast(Optional[InstanceState[Any]], st)
except Exception:
return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
"""Safely get Mapper for a maooed instance (or None)."""
try:
st = inspect(obj)
mapper = getattr(st, "mapper", None)
return cast(Optional[Mapper], mapper)
except Exception:
return None
def _safe_get_loaded_attr(obj, name):
st = _sa_state(obj)
if st is None:
return None
try:
attrs = getattr(st, "attrs", {}).get(name)
if attrs is not None and name in attrs:
attr = attrs[name]
val = attr.loaded_value
return None if val is NO_VALUE else val
if name in st.dict:
return st.dict.get(name)
st_dict = getattr(st, "dict", {})
if name in st_dict:
return st_dict.get(name)
return None
except Exception:
return None
@ -33,14 +68,11 @@ def _is_collection_rel(prop: RelationshipProperty) -> bool:
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for cls in obj.__class__.__mro__:
if hasattr(cls, "__table__"):
for col in cls.__table__.columns:
name = col.name
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
for name in _column_names_for_model(type(obj)):
try:
out[name] = getattr(obj, name)
except Exception:
out[name] = None
return out
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
@ -204,12 +236,16 @@ class CRUDMixin:
# Determine which relationships to consider
try:
st = inspect(self)
mapper = st.mapper
embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names
mapper = _sa_mapper(self)
embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
if mapper is None:
return data
st = _sa_state(self)
if st is None:
return data
for name, prop in mapper.relationships.items():
# Only touch relationships that are already loaded; never lazy-load here.
rel_loaded = st.attrs.get(name)
rel_loaded = getattr(st, "attrs", {}).get(name)
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
continue
@ -266,13 +302,10 @@ class CRUDMixin:
val = None
# If it's a scalar ORM object (relationship), serialize its columns
try:
st = inspect(val) # will raise if not an ORM object
if getattr(st, "mapper", None) is not None:
out[name] = _serialize_simple_obj(val)
continue
except Exception:
pass
mapper = _sa_mapper(val)
if mapper is not None:
out[name] = _serialize_simple_obj(val)
continue
# If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)):