Make the CRUDService more reliable.

This commit is contained in:
Yaro Kasear 2025-09-24 08:28:02 -05:00
parent 515eb27fe0
commit 1c5fa29943

View file

@ -1,50 +1,14 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty from sqlalchemy.orm import Load, Session, selectinload, with_polymorphic, Mapper, RelationshipProperty, ColumnProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection
def _expand_requires(model_cls, fields):
out, seen = [], set()
def add(f):
if f not in seen:
seen.add(f); out.append(f)
for f in fields:
add(f)
parts = f.split(".")
cur_cls = model_cls
prefix = []
for p in parts[:-1]:
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
if not rel:
cur_cls = None
break
cur_cls = rel.mapper.class_
prefix.append(p)
if cur_cls is None:
continue
leaf = parts[-1]
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
if not deps:
continue
pre = ".".join(prefix)
for dep in deps:
add(f"{pre + '.' if pre else ''}{dep}")
return out
def _is_rel(model_cls, name: str) -> bool: def _is_rel(model_cls, name: str) -> bool:
try: try:
@ -53,41 +17,6 @@ def _is_rel(model_cls, name: str) -> bool:
except Exception: except Exception:
return False return False
def _is_instrumented_column(attr) -> bool:
try:
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
except Exception:
return False
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
"""
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
and only fetch the related PK. This is enough for preselecting <select> inputs
without projecting the FK column on the root model.
"""
opts: list[Load] = []
if not fields:
return opts
mapper = class_mapper(model_cls)
for name in fields:
prop = mapper.relationships.get(name)
if not isinstance(prop, RelationshipProperty):
continue
if prop.direction.name != "MANYTOONE":
continue
rel_attr = getattr(root_alias, name)
target_cls = prop.mapper.class_
# load_only PK if present; else just selectinload
id_attr = getattr(target_cls, "id", None)
if id_attr is not None:
opts.append(selectinload(rel_attr).load_only(id_attr))
else:
opts.append(selectinload(rel_attr))
return opts
@runtime_checkable @runtime_checkable
class _HasID(Protocol): class _HasID(Protocol):
id: int id: int
@ -141,59 +70,6 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly return self.session.query(poly), poly
return self.session.query(self.model), self.model return self.session.query(self.model), self.model
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
"""
For each dotted path like ("location"), -> ["label"], look up the target
model's __crudkit_field_requires__ for the terminal field and produce
selectinload options prefixed with the relationship path, e.g.:
Room.__crudkit_field_requires__['label'] = ['room_function']
=> selectinload(root.location).selectinload(Room.room_function)
"""
opts: List[Any] = []
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for path, names in (rel_field_names or {}).items():
if not path:
continue
current_mapper = root_mapper
rel_props: List[RelationshipProperty] = []
valid = True
for step in path:
rel = current_mapper.relationships.get(step)
if not isinstance(rel, RelationshipProperty):
valid = False
break
rel_props.append(rel)
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
if not valid or not rel_props:
continue
first = rel_props[0]
base_loader = selectinload(getattr(root_alias, first.key))
for i in range(1, len(rel_props)):
prev_target_cls = rel_props[i - 1].mapper.class_
hop_attr = getattr(prev_target_cls, rel_props[i].key)
base_loader = base_loader.selectinload(hop_attr)
target_cls = rel_props[-1].mapper.class_
requires = getattr(target_cls, "__crudkit_field_requires__", None)
if not isinstance(requires, dict):
continue
for field_name in names:
needed: Iterable[str] = requires.get(field_name, []) or []
for rel_need in needed:
rel_prop2 = target_cls.__mapper__.relationships.get(rel_need)
if not isinstance(rel_prop2, RelationshipProperty):
continue
dep_attr = getattr(target_cls, rel_prop2.key)
opts.append(base_loader.selectinload(dep_attr))
return opts
def _extract_order_spec(self, root_alias, given_order_by): def _extract_order_spec(self, root_alias, given_order_by):
""" """
SQLAlchemy 2.x only: SQLAlchemy 2.x only:
@ -266,11 +142,11 @@ class CRUDService(Generic[T]):
- forward/backward seek via `key` and `backward` - forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
""" """
fields = list(params.get("fields", [])) fields = list((params or {}).get("fields", []))
if fields: expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
fields = _expand_requires(self.model, fields)
params = {**params, "fields": fields}
query, root_alias = self.get_query() query, root_alias = self.get_query()
if proj_opts:
query = query.options(*proj_opts)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
@ -279,21 +155,8 @@ class CRUDService(Generic[T]):
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
seen_rel_roots = set()
for path, names in (rel_field_names or {}).items():
if not path:
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
# Soft delete filter # Soft delete filter
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")): if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False) query = query.filter(getattr(root_alias, "is_deleted") == False)
# Parse filters first # Parse filters first
@ -302,7 +165,7 @@ class CRUDService(Generic[T]):
# Includes + joins (so relationship fields like brand.name, location.label work) # Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes() spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True) query = query.join(target, rel_attr.of_type(target), isouter=True)
@ -311,11 +174,6 @@ class CRUDService(Generic[T]):
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
# query = query.options(eager)
for opt in self._resolve_required_includes(root_alias, rel_field_names):
query = query.options(opt)
# Order + limit # Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
@ -344,17 +202,21 @@ class CRUDService(Generic[T]):
items = list(reversed(query.all())) items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested # Tag projection so your renderer knows what fields were requested
proj = [] if expanded_fields:
if root_field_names: proj = list(expanded_fields)
proj.extend(root_field_names) else:
if root_fields: proj = []
proj.extend(c.key for c in root_fields) if root_field_names:
for path, names in (rel_field_names or {}).items(): proj.extend(root_field_names)
prefix = ".".join(path) if root_fields:
for n in names: proj.extend(c.key for c in root_fields)
proj.append(f"{prefix}.{n}") for path, names in (rel_field_names or {}).items():
if proj and "id" not in proj and hasattr(self.model, "id"): prefix = ".".join(path)
proj.insert(0, "id") for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj: if proj:
for obj in items: for obj in items:
try: try:
@ -375,11 +237,12 @@ class CRUDService(Generic[T]):
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
# replicate the same joins used above # replicate the same joins used above
for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True) base = base.join(target, rel_attr.of_type(target), isouter=True)
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0 total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
print(f"!!! QUERY !!! -> {str(query)}")
from crudkit.core.types import SeekWindow # avoid circulars at module top from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow( return SeekWindow(
@ -439,7 +302,7 @@ class CRUDService(Generic[T]):
spec.parse_includes() spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True) query = query.join(target, rel_attr.of_type(target), isouter=True)
@ -447,58 +310,39 @@ class CRUDService(Generic[T]):
if params: if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
if rel_field_names: req_fields = list((params or {}).get("fields", []))
seen_rel_roots = set() expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
for path, names in rel_field_names.items(): if proj_opts:
if not path: query = query.options(*proj_opts)
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
fields = (params or {}).get("fields") if isinstance(params, dict) else None
if fields:
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
# query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
obj = query.first() obj = query.first()
proj = [] if expanded_fields:
if root_field_names: proj = list(expanded_fields)
proj.extend(root_field_names) else:
if root_fields: proj = []
proj.extend(c.key for c in root_fields) if root_field_names:
for path, names in (rel_field_names or {}).items(): proj.extend(root_field_names)
prefix = ".".join(path) if root_fields:
for n in names: proj.extend(c.key for c in root_fields)
proj.append(f"{prefix}.{n}") for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj and "id" not in proj and hasattr(self.model, "id"): if proj and obj is not None:
proj.insert(0, "id")
if proj:
try: try:
setattr(obj, "__crudkit_projection__", tuple(proj)) setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception: except Exception:
pass pass
print(f"!!! QUERY !!! -> {str(query)}")
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
@ -520,7 +364,7 @@ class CRUDService(Generic[T]):
limit, offset = spec.parse_pagination() limit, offset = spec.parse_pagination()
spec.parse_includes() spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr) rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias) target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True) query = query.join(target, rel_attr.of_type(target), isouter=True)
@ -528,32 +372,10 @@ class CRUDService(Generic[T]):
if params: if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
if rel_field_names:
seen_rel_roots = set()
for path, names in rel_field_names.items():
if not path:
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
# query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
if filters: if filters:
query = query.filter(*filters) query = query.filter(*filters)
@ -571,20 +393,27 @@ class CRUDService(Generic[T]):
if limit is not None and limit > 0: if limit is not None and limit > 0:
query = query.limit(limit) query = query.limit(limit)
req_fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
rows = query.all() rows = query.all()
proj = [] if expanded_fields:
if root_field_names: proj = list(expanded_fields)
proj.extend(root_field_names) else:
if root_fields: proj = []
proj.extend(c.key for c in root_fields) if root_field_names:
for path, names in (rel_field_names or {}).items(): proj.extend(root_field_names)
prefix = ".".join(path) if root_fields:
for n in names: proj.extend(c.key for c in root_fields)
proj.append(f"{prefix}.{n}") for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
if proj and "id" not in proj and hasattr(self.model, "id"): for n in names:
proj.insert(0, "id") proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj: if proj:
for obj in rows: for obj in rows:
@ -593,6 +422,7 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
print(f"!!! QUERY !!! -> {str(query)}")
return rows return rows
def create(self, data: dict, actor=None) -> T: def create(self, data: dict, actor=None) -> T: