Work begins on proper one-to-many support.

This commit is contained in:
Yaro Kasear 2025-09-26 11:24:49 -05:00
parent 811b534b89
commit ffa49f13e9
3 changed files with 82 additions and 24 deletions

View file

@ -5,7 +5,7 @@ from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
@ -49,7 +49,7 @@ def _unwrap_ob(ob):
is_desc = False is_desc = False
dir_attr = getattr(ob, "_direction", None) dir_attr = getattr(ob, "_direction", None)
if dir_attr is not None: if dir_attr is not None:
is_desc = (dir_attr is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
elif isinstance(ob, UnaryExpression): elif isinstance(ob, UnaryExpression):
op = getattr(ob, "operator", None) op = getattr(ob, "operator", None)
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
@ -231,7 +231,7 @@ class CRUDService(Generic[T]):
# Parse all inputs so join_paths are populated # Parse all inputs so join_paths are populated
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
@ -243,12 +243,25 @@ class CRUDService(Generic[T]):
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# JOIN all resolved paths, hydrate from the join # JOIN non-collection paths; for collections use selectinload (never join)
used_contains_eager = False used_contains_eager = False
for _base_alias, rel_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
query = query.options(contains_eager(rel_attr, alias=target_alias)) if is_collection:
used_contains_eager = True opt = selectinload(rel_attr)
# narrow child columns if requested (e.g., updates.id,updates.timestamp)
child_names = (collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
# Filters AFTER joins → no cartesian products # Filters AFTER joins → no cartesian products
if filters: if filters:
@ -346,8 +359,10 @@ class CRUDService(Generic[T]):
base = session.query(getattr(root_alias, "id")) base = session.query(getattr(root_alias, "id"))
base = self._apply_not_deleted(base, root_alias, params) base = self._apply_not_deleted(base, root_alias, params)
# same joins as above for correctness # same joins as above for correctness
for _base_alias, rel_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) # do not join collections for COUNT mirror
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
total = session.query(func.count()).select_from( total = session.query(func.count()).select_from(
@ -428,7 +443,7 @@ class CRUDService(Generic[T]):
filters = spec.parse_filters() filters = spec.parse_filters()
# no ORDER BY for get() # no ORDER BY for get()
if params: if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
@ -438,12 +453,24 @@ class CRUDService(Generic[T]):
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# JOIN all discovered paths up front; hydrate via contains_eager # JOIN non-collections only; collections via selectinload
used_contains_eager = False used_contains_eager = False
for _base_alias, rel_attr, target_alias in join_paths: for base_alias, rel_attr, target_alias in join_paths:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
query = query.options(contains_eager(rel_attr, alias=target_alias)) if is_collection:
used_contains_eager = True opt = selectinload(rel_attr)
child_names = (collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
else:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
# Apply filters (joins are in place → no cartesian products) # Apply filters (joins are in place → no cartesian products)
if filters: if filters:

View file

@ -1,4 +1,4 @@
from typing import List, Tuple, Set, Dict, Optional from typing import List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import asc, desc from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm import aliased, selectinload
@ -20,10 +20,14 @@ class CRUDSpec:
self.params = params self.params = params
self.root_alias = root_alias self.root_alias = root_alias
self.eager_paths: Set[Tuple[str, ...]] = set() self.eager_paths: Set[Tuple[str, ...]] = set()
# (parent_alias, relationship_attr, alias_for_target)
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {} self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = [] self._root_fields: List[InstrumentedAttribute] = []
self._rel_field_names: Dict[Tuple[str, ...], object] = {} # dotted non-collection fields (MANYTOONE etc)
self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
# dotted collection fields (ONETOMANY)
self._collection_field_names: Dict[str, List[str]] = {}
self.include_paths: Set[Tuple[str, ...]] = set() self.include_paths: Set[Tuple[str, ...]] = set()
def _resolve_column(self, path: str): def _resolve_column(self, path: str):
@ -117,11 +121,12 @@ class CRUDSpec:
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias. - Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
Returns (root_fields, rel_field_names). - Collection (uselist=True) relationships record child names by relationship key.
Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel).
""" """
raw = self.params.get('fields') raw = self.params.get('fields')
if not raw: if not raw:
return [], {}, {} return [], {}, {}, {}
if isinstance(raw, list): if isinstance(raw, list):
tokens = [] tokens = []
@ -133,14 +138,36 @@ class CRUDSpec:
root_fields: List[InstrumentedAttribute] = [] root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = [] root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {} rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
collection_field_names: Dict[str, List[str]] = {}
for token in tokens: for token in tokens:
col, join_path = self._resolve_column(token) col, join_path = self._resolve_column(token)
if not col: if not col:
continue continue
if join_path: if join_path:
rel_field_names.setdefault(join_path, []).append(col.key) # rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path) # self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
for nm in names:
rel_attr = getattr(cur_cls, nm)
cur_cls = rel_attr.property.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
# Fallback: inspect the InstrumentedAttribute we recorded on join_paths
is_collection = False
for _pa, rel_attr, _al in self.join_paths:
if rel_attr.key == (join_path[-1] if join_path else ""):
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
break
if is_collection:
collection_field_names.setdefault(join_path[-1], []).append(col.key)
else:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else: else:
root_fields.append(col) root_fields.append(col)
root_field_names.append(getattr(col, "key", token)) root_field_names.append(getattr(col, "key", token))
@ -153,7 +180,11 @@ class CRUDSpec:
self._root_fields = root_fields self._root_fields = root_fields
self._rel_field_names = rel_field_names self._rel_field_names = rel_field_names
return root_fields, rel_field_names, root_field_names # return root_fields, rel_field_names, root_field_names
for r, names in collection_field_names.items():
seen3 = set()
collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
return root_field_names, rel_field_names, root_field_names, collection_field_names
def get_eager_loads(self, root_alias, *, fields_map=None): def get_eager_loads(self, root_alias, *, fields_map=None):
loads = [] loads = []

View file

@ -88,7 +88,7 @@ def init_entry_routes(app):
{"name": "label", "order": 0}, {"name": "label", "order": 0},
{"name": "name", "order": 10, "attrs": {"class": "row"}}, {"name": "name", "order": 10, "attrs": {"class": "row"}},
{"name": "details", "order": 20, "attrs": {"class": "row mt-2"}}, {"name": "details", "order": 20, "attrs": {"class": "row mt-2"}},
{"name": "checkboxes", "order": 30, "parent": "name", {"name": "checkboxes", "order": 30, "parent": "details",
"attrs": {"class": "col d-flex flex-column justify-content-end"}}, "attrs": {"class": "col d-flex flex-column justify-content-end"}},
] ]