Make the CRUDService more reliable.
This commit is contained in:
parent
515eb27fe0
commit
1c5fa29943
1 changed files with 68 additions and 238 deletions
|
|
@ -1,50 +1,14 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||
from sqlalchemy import and_, func, inspect, or_, text
|
||||
from sqlalchemy.engine import Engine, Connection
|
||||
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty
|
||||
from sqlalchemy.orm import Load, Session, selectinload, with_polymorphic, Mapper, RelationshipProperty, ColumnProperty
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.base import NO_VALUE
|
||||
from sqlalchemy.orm.util import AliasedClass
|
||||
from sqlalchemy.sql import operators
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
from crudkit.core.types import OrderSpec, SeekWindow
|
||||
from crudkit.backend import BackendInfo, make_backend_info
|
||||
|
||||
def _expand_requires(model_cls, fields):
|
||||
out, seen = [], set()
|
||||
def add(f):
|
||||
if f not in seen:
|
||||
seen.add(f); out.append(f)
|
||||
|
||||
for f in fields:
|
||||
add(f)
|
||||
parts = f.split(".")
|
||||
cur_cls = model_cls
|
||||
prefix = []
|
||||
|
||||
for p in parts[:-1]:
|
||||
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
|
||||
if not rel:
|
||||
cur_cls = None
|
||||
break
|
||||
cur_cls = rel.mapper.class_
|
||||
prefix.append(p)
|
||||
|
||||
if cur_cls is None:
|
||||
continue
|
||||
|
||||
leaf = parts[-1]
|
||||
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
|
||||
if not deps:
|
||||
continue
|
||||
|
||||
pre = ".".join(prefix)
|
||||
for dep in deps:
|
||||
add(f"{pre + '.' if pre else ''}{dep}")
|
||||
return out
|
||||
from crudkit.projection import compile_projection
|
||||
|
||||
def _is_rel(model_cls, name: str) -> bool:
|
||||
try:
|
||||
|
|
@ -53,41 +17,6 @@ def _is_rel(model_cls, name: str) -> bool:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_instrumented_column(attr) -> bool:
|
||||
try:
|
||||
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
|
||||
"""
|
||||
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
|
||||
and only fetch the related PK. This is enough for preselecting <select> inputs
|
||||
without projecting the FK column on the root model.
|
||||
"""
|
||||
opts: list[Load] = []
|
||||
if not fields:
|
||||
return opts
|
||||
|
||||
mapper = class_mapper(model_cls)
|
||||
for name in fields:
|
||||
prop = mapper.relationships.get(name)
|
||||
if not isinstance(prop, RelationshipProperty):
|
||||
continue
|
||||
if prop.direction.name != "MANYTOONE":
|
||||
continue
|
||||
|
||||
rel_attr = getattr(root_alias, name)
|
||||
target_cls = prop.mapper.class_
|
||||
# load_only PK if present; else just selectinload
|
||||
id_attr = getattr(target_cls, "id", None)
|
||||
if id_attr is not None:
|
||||
opts.append(selectinload(rel_attr).load_only(id_attr))
|
||||
else:
|
||||
opts.append(selectinload(rel_attr))
|
||||
|
||||
return opts
|
||||
|
||||
@runtime_checkable
|
||||
class _HasID(Protocol):
|
||||
id: int
|
||||
|
|
@ -141,59 +70,6 @@ class CRUDService(Generic[T]):
|
|||
return self.session.query(poly), poly
|
||||
return self.session.query(self.model), self.model
|
||||
|
||||
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
|
||||
"""
|
||||
For each dotted path like ("location"), -> ["label"], look up the target
|
||||
model's __crudkit_field_requires__ for the terminal field and produce
|
||||
selectinload options prefixed with the relationship path, e.g.:
|
||||
Room.__crudkit_field_requires__['label'] = ['room_function']
|
||||
=> selectinload(root.location).selectinload(Room.room_function)
|
||||
"""
|
||||
opts: List[Any] = []
|
||||
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||||
|
||||
for path, names in (rel_field_names or {}).items():
|
||||
if not path:
|
||||
continue
|
||||
|
||||
current_mapper = root_mapper
|
||||
rel_props: List[RelationshipProperty] = []
|
||||
|
||||
valid = True
|
||||
for step in path:
|
||||
rel = current_mapper.relationships.get(step)
|
||||
if not isinstance(rel, RelationshipProperty):
|
||||
valid = False
|
||||
break
|
||||
rel_props.append(rel)
|
||||
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
|
||||
if not valid or not rel_props:
|
||||
continue
|
||||
|
||||
first = rel_props[0]
|
||||
base_loader = selectinload(getattr(root_alias, first.key))
|
||||
for i in range(1, len(rel_props)):
|
||||
prev_target_cls = rel_props[i - 1].mapper.class_
|
||||
hop_attr = getattr(prev_target_cls, rel_props[i].key)
|
||||
base_loader = base_loader.selectinload(hop_attr)
|
||||
|
||||
target_cls = rel_props[-1].mapper.class_
|
||||
|
||||
requires = getattr(target_cls, "__crudkit_field_requires__", None)
|
||||
if not isinstance(requires, dict):
|
||||
continue
|
||||
|
||||
for field_name in names:
|
||||
needed: Iterable[str] = requires.get(field_name, []) or []
|
||||
for rel_need in needed:
|
||||
rel_prop2 = target_cls.__mapper__.relationships.get(rel_need)
|
||||
if not isinstance(rel_prop2, RelationshipProperty):
|
||||
continue
|
||||
dep_attr = getattr(target_cls, rel_prop2.key)
|
||||
opts.append(base_loader.selectinload(dep_attr))
|
||||
|
||||
return opts
|
||||
|
||||
def _extract_order_spec(self, root_alias, given_order_by):
|
||||
"""
|
||||
SQLAlchemy 2.x only:
|
||||
|
|
@ -266,11 +142,11 @@ class CRUDService(Generic[T]):
|
|||
- forward/backward seek via `key` and `backward`
|
||||
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
||||
"""
|
||||
fields = list(params.get("fields", []))
|
||||
if fields:
|
||||
fields = _expand_requires(self.model, fields)
|
||||
params = {**params, "fields": fields}
|
||||
fields = list((params or {}).get("fields", []))
|
||||
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
|
||||
query, root_alias = self.get_query()
|
||||
if proj_opts:
|
||||
query = query.options(*proj_opts)
|
||||
|
||||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||
|
||||
|
|
@ -279,21 +155,8 @@ class CRUDService(Generic[T]):
|
|||
|
||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||
|
||||
seen_rel_roots = set()
|
||||
for path, names in (rel_field_names or {}).items():
|
||||
if not path:
|
||||
continue
|
||||
rel_name = path[0]
|
||||
if rel_name in seen_rel_roots:
|
||||
continue
|
||||
if _is_rel(self.model, rel_name):
|
||||
rel_attr = getattr(root_alias, rel_name, None)
|
||||
if rel_attr is not None:
|
||||
query = query.options(selectinload(rel_attr))
|
||||
seen_rel_roots.add(rel_name)
|
||||
|
||||
# Soft delete filter
|
||||
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
||||
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
|
||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||
|
||||
# Parse filters first
|
||||
|
|
@ -302,7 +165,7 @@ class CRUDService(Generic[T]):
|
|||
|
||||
# Includes + joins (so relationship fields like brand.name, location.label work)
|
||||
spec.parse_includes()
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
for _, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
||||
|
|
@ -311,11 +174,6 @@ class CRUDService(Generic[T]):
|
|||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
if only_cols:
|
||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||
# query = query.options(eager)
|
||||
|
||||
for opt in self._resolve_required_includes(root_alias, rel_field_names):
|
||||
query = query.options(opt)
|
||||
|
||||
# Order + limit
|
||||
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
|
||||
|
|
@ -344,6 +202,9 @@ class CRUDService(Generic[T]):
|
|||
items = list(reversed(query.all()))
|
||||
|
||||
# Tag projection so your renderer knows what fields were requested
|
||||
if expanded_fields:
|
||||
proj = list(expanded_fields)
|
||||
else:
|
||||
proj = []
|
||||
if root_field_names:
|
||||
proj.extend(root_field_names)
|
||||
|
|
@ -355,6 +216,7 @@ class CRUDService(Generic[T]):
|
|||
proj.append(f"{prefix}.{n}")
|
||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||
proj.insert(0, "id")
|
||||
|
||||
if proj:
|
||||
for obj in items:
|
||||
try:
|
||||
|
|
@ -375,11 +237,12 @@ class CRUDService(Generic[T]):
|
|||
if filters:
|
||||
base = base.filter(*filters)
|
||||
# replicate the same joins used above
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
for _, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
base = base.join(target, rel_attr.of_type(target), isouter=True)
|
||||
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
|
||||
print(f"!!! QUERY !!! -> {str(query)}")
|
||||
|
||||
from crudkit.core.types import SeekWindow # avoid circulars at module top
|
||||
return SeekWindow(
|
||||
|
|
@ -439,7 +302,7 @@ class CRUDService(Generic[T]):
|
|||
|
||||
spec.parse_includes()
|
||||
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
for _, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
||||
|
|
@ -447,39 +310,20 @@ class CRUDService(Generic[T]):
|
|||
if params:
|
||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||
|
||||
if rel_field_names:
|
||||
seen_rel_roots = set()
|
||||
for path, names in rel_field_names.items():
|
||||
if not path:
|
||||
continue
|
||||
rel_name = path[0]
|
||||
if rel_name in seen_rel_roots:
|
||||
continue
|
||||
if _is_rel(self.model, rel_name):
|
||||
rel_attr = getattr(root_alias, rel_name, None)
|
||||
if rel_attr is not None:
|
||||
query = query.options(selectinload(rel_attr))
|
||||
seen_rel_roots.add(rel_name)
|
||||
|
||||
fields = (params or {}).get("fields") if isinstance(params, dict) else None
|
||||
if fields:
|
||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
||||
query = query.options(opt)
|
||||
req_fields = list((params or {}).get("fields", []))
|
||||
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||||
if proj_opts:
|
||||
query = query.options(*proj_opts)
|
||||
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
if only_cols:
|
||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||
|
||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||
# query = query.options(eager)
|
||||
|
||||
if params:
|
||||
fields = params.get("fields") or []
|
||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
||||
query = query.options(opt)
|
||||
|
||||
obj = query.first()
|
||||
|
||||
if expanded_fields:
|
||||
proj = list(expanded_fields)
|
||||
else:
|
||||
proj = []
|
||||
if root_field_names:
|
||||
proj.extend(root_field_names)
|
||||
|
|
@ -489,16 +333,16 @@ class CRUDService(Generic[T]):
|
|||
prefix = ".".join(path)
|
||||
for n in names:
|
||||
proj.append(f"{prefix}.{n}")
|
||||
|
||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||
proj.insert(0, "id")
|
||||
|
||||
if proj:
|
||||
if proj and obj is not None:
|
||||
try:
|
||||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"!!! QUERY !!! -> {str(query)}")
|
||||
return obj or None
|
||||
|
||||
def list(self, params=None) -> list[T]:
|
||||
|
|
@ -520,7 +364,7 @@ class CRUDService(Generic[T]):
|
|||
limit, offset = spec.parse_pagination()
|
||||
spec.parse_includes()
|
||||
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
for _, relationship_attr, target_alias in spec.get_join_paths():
|
||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||
target = cast(Any, target_alias)
|
||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
||||
|
|
@ -528,32 +372,10 @@ class CRUDService(Generic[T]):
|
|||
if params:
|
||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||
|
||||
if rel_field_names:
|
||||
seen_rel_roots = set()
|
||||
for path, names in rel_field_names.items():
|
||||
if not path:
|
||||
continue
|
||||
rel_name = path[0]
|
||||
if rel_name in seen_rel_roots:
|
||||
continue
|
||||
if _is_rel(self.model, rel_name):
|
||||
rel_attr = getattr(root_alias, rel_name, None)
|
||||
if rel_attr is not None:
|
||||
query = query.options(selectinload(rel_attr))
|
||||
seen_rel_roots.add(rel_name)
|
||||
|
||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||
if only_cols:
|
||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||
|
||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||
# query = query.options(eager)
|
||||
|
||||
if params:
|
||||
fields = params.get("fields") or []
|
||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
||||
query = query.options(opt)
|
||||
|
||||
if filters:
|
||||
query = query.filter(*filters)
|
||||
|
||||
|
|
@ -571,8 +393,16 @@ class CRUDService(Generic[T]):
|
|||
if limit is not None and limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
req_fields = list((params or {}).get("fields", []))
|
||||
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||||
if proj_opts:
|
||||
query = query.options(*proj_opts)
|
||||
|
||||
rows = query.all()
|
||||
|
||||
if expanded_fields:
|
||||
proj = list(expanded_fields)
|
||||
else:
|
||||
proj = []
|
||||
if root_field_names:
|
||||
proj.extend(root_field_names)
|
||||
|
|
@ -582,7 +412,6 @@ class CRUDService(Generic[T]):
|
|||
prefix = ".".join(path)
|
||||
for n in names:
|
||||
proj.append(f"{prefix}.{n}")
|
||||
|
||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||
proj.insert(0, "id")
|
||||
|
||||
|
|
@ -593,6 +422,7 @@ class CRUDService(Generic[T]):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
print(f"!!! QUERY !!! -> {str(query)}")
|
||||
return rows
|
||||
|
||||
def create(self, data: dict, actor=None) -> T:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue