Work begins on proper one to many support.
This commit is contained in:
parent
811b534b89
commit
ffa49f13e9
3 changed files with 82 additions and 24 deletions
|
|
@ -5,7 +5,7 @@ from flask import current_app
|
||||||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||||
from sqlalchemy import and_, func, inspect, or_, text
|
from sqlalchemy import and_, func, inspect, or_, text
|
||||||
from sqlalchemy.engine import Engine, Connection
|
from sqlalchemy.engine import Engine, Connection
|
||||||
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager
|
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from sqlalchemy.sql import operators
|
from sqlalchemy.sql import operators
|
||||||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||||||
|
|
@ -49,7 +49,7 @@ def _unwrap_ob(ob):
|
||||||
is_desc = False
|
is_desc = False
|
||||||
dir_attr = getattr(ob, "_direction", None)
|
dir_attr = getattr(ob, "_direction", None)
|
||||||
if dir_attr is not None:
|
if dir_attr is not None:
|
||||||
is_desc = (dir_attr is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
|
||||||
elif isinstance(ob, UnaryExpression):
|
elif isinstance(ob, UnaryExpression):
|
||||||
op = getattr(ob, "operator", None)
|
op = getattr(ob, "operator", None)
|
||||||
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
|
||||||
|
|
@ -231,7 +231,7 @@ class CRUDService(Generic[T]):
|
||||||
# Parse all inputs so join_paths are populated
|
# Parse all inputs so join_paths are populated
|
||||||
filters = spec.parse_filters()
|
filters = spec.parse_filters()
|
||||||
order_by = spec.parse_sort()
|
order_by = spec.parse_sort()
|
||||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
|
||||||
spec.parse_includes()
|
spec.parse_includes()
|
||||||
join_paths = tuple(spec.get_join_paths())
|
join_paths = tuple(spec.get_join_paths())
|
||||||
|
|
||||||
|
|
@ -243,9 +243,22 @@ class CRUDService(Generic[T]):
|
||||||
if only_cols:
|
if only_cols:
|
||||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
|
|
||||||
# JOIN all resolved paths, hydrate from the join
|
# JOIN all resolved paths; for collections use selectinload (never join)
|
||||||
used_contains_eager = False
|
used_contains_eager = False
|
||||||
for _base_alias, rel_attr, target_alias in join_paths:
|
for base_alias, rel_attr, target_alias in join_paths:
|
||||||
|
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
|
||||||
|
if is_collection:
|
||||||
|
opt = selectinload(rel_attr)
|
||||||
|
# narroe child columns it requested (e.g., updates.id,updates.timestamp)
|
||||||
|
child_names = (collection_field_names or {}).get(rel_attr.key, [])
|
||||||
|
if child_names:
|
||||||
|
target_cls = rel_attr.property.mapper.class_
|
||||||
|
cols = [getattr(target_cls, n, None) for n in child_names]
|
||||||
|
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
|
||||||
|
if cols:
|
||||||
|
opt = opt.load_only(*cols)
|
||||||
|
query = query.options(opt)
|
||||||
|
else:
|
||||||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||||
used_contains_eager = True
|
used_contains_eager = True
|
||||||
|
|
@ -346,7 +359,9 @@ class CRUDService(Generic[T]):
|
||||||
base = session.query(getattr(root_alias, "id"))
|
base = session.query(getattr(root_alias, "id"))
|
||||||
base = self._apply_not_deleted(base, root_alias, params)
|
base = self._apply_not_deleted(base, root_alias, params)
|
||||||
# same joins as above for correctness
|
# same joins as above for correctness
|
||||||
for _base_alias, rel_attr, target_alias in join_paths:
|
for base_alias, rel_attr, target_alias in join_paths:
|
||||||
|
# do not join collections for COUNT mirror
|
||||||
|
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
|
||||||
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
if filters:
|
if filters:
|
||||||
base = base.filter(*filters)
|
base = base.filter(*filters)
|
||||||
|
|
@ -428,7 +443,7 @@ class CRUDService(Generic[T]):
|
||||||
filters = spec.parse_filters()
|
filters = spec.parse_filters()
|
||||||
# no ORDER BY for get()
|
# no ORDER BY for get()
|
||||||
if params:
|
if params:
|
||||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields()
|
||||||
spec.parse_includes()
|
spec.parse_includes()
|
||||||
|
|
||||||
join_paths = tuple(spec.get_join_paths())
|
join_paths = tuple(spec.get_join_paths())
|
||||||
|
|
@ -438,9 +453,21 @@ class CRUDService(Generic[T]):
|
||||||
if only_cols:
|
if only_cols:
|
||||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
|
|
||||||
# JOIN all discovered paths up front; hydrate via contains_eager
|
# JOIN non-collections only; collections via selectinload
|
||||||
used_contains_eager = False
|
used_contains_eager = False
|
||||||
for _base_alias, rel_attr, target_alias in join_paths:
|
for base_alias, rel_attr, target_alias in join_paths:
|
||||||
|
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
|
||||||
|
if is_collection:
|
||||||
|
opt = selectinload(rel_attr)
|
||||||
|
child_names = (collection_field_names or {}).get(rel_attr.key, [])
|
||||||
|
if child_names:
|
||||||
|
target_cls = rel_attr.property.mapper.class_
|
||||||
|
cols = [getattr(target_cls, n, None) for n in child_names]
|
||||||
|
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
|
||||||
|
if cols:
|
||||||
|
opt = opt.load_only(*cols)
|
||||||
|
query = query.options(opt)
|
||||||
|
else:
|
||||||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||||
used_contains_eager = True
|
used_contains_eager = True
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Tuple, Set, Dict, Optional
|
from typing import List, Tuple, Set, Dict, Optional, Iterable
|
||||||
from sqlalchemy import asc, desc
|
from sqlalchemy import asc, desc
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import aliased, selectinload
|
from sqlalchemy.orm import aliased, selectinload
|
||||||
|
|
@ -20,10 +20,14 @@ class CRUDSpec:
|
||||||
self.params = params
|
self.params = params
|
||||||
self.root_alias = root_alias
|
self.root_alias = root_alias
|
||||||
self.eager_paths: Set[Tuple[str, ...]] = set()
|
self.eager_paths: Set[Tuple[str, ...]] = set()
|
||||||
|
# (parent_alias. relationship_attr, alias_for_target)
|
||||||
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
|
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
|
||||||
self.alias_map: Dict[Tuple[str, ...], object] = {}
|
self.alias_map: Dict[Tuple[str, ...], object] = {}
|
||||||
self._root_fields: List[InstrumentedAttribute] = []
|
self._root_fields: List[InstrumentedAttribute] = []
|
||||||
self._rel_field_names: Dict[Tuple[str, ...], object] = {}
|
# dotted non-collection fields (MANYTOONE etc)
|
||||||
|
self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
|
||||||
|
# dotted collection fields (ONETOMANY)
|
||||||
|
self._collection_field_names: Dict[str, List[str]] = {}
|
||||||
self.include_paths: Set[Tuple[str, ...]] = set()
|
self.include_paths: Set[Tuple[str, ...]] = set()
|
||||||
|
|
||||||
def _resolve_column(self, path: str):
|
def _resolve_column(self, path: str):
|
||||||
|
|
@ -117,11 +121,12 @@ class CRUDSpec:
|
||||||
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
|
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
|
||||||
- Root fields become InstrumentedAttributes bound to root_alias.
|
- Root fields become InstrumentedAttributes bound to root_alias.
|
||||||
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
|
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
|
||||||
Returns (root_fields, rel_field_names).
|
- Collection (uselist=True) relationships record child names by relationship key.
|
||||||
|
Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel).
|
||||||
"""
|
"""
|
||||||
raw = self.params.get('fields')
|
raw = self.params.get('fields')
|
||||||
if not raw:
|
if not raw:
|
||||||
return [], {}, {}
|
return [], {}, {}, {}
|
||||||
|
|
||||||
if isinstance(raw, list):
|
if isinstance(raw, list):
|
||||||
tokens = []
|
tokens = []
|
||||||
|
|
@ -133,12 +138,34 @@ class CRUDSpec:
|
||||||
root_fields: List[InstrumentedAttribute] = []
|
root_fields: List[InstrumentedAttribute] = []
|
||||||
root_field_names: list[str] = []
|
root_field_names: list[str] = []
|
||||||
rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
|
rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
|
||||||
|
collection_field_names: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
col, join_path = self._resolve_column(token)
|
col, join_path = self._resolve_column(token)
|
||||||
if not col:
|
if not col:
|
||||||
continue
|
continue
|
||||||
if join_path:
|
if join_path:
|
||||||
|
# rel_field_names.setdefault(join_path, []).append(col.key)
|
||||||
|
# self.eager_paths.add(join_path)
|
||||||
|
try:
|
||||||
|
cur_cls = self.model
|
||||||
|
names = list(join_path)
|
||||||
|
last_name = names[-1]
|
||||||
|
for nm in names:
|
||||||
|
rel_attr = getattr(cur_cls, nm)
|
||||||
|
cur_cls = rel_attr.property.mapper.class_
|
||||||
|
is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False))
|
||||||
|
except Exception:
|
||||||
|
# Fallback: inspect the InstrumentedAttribute we recorded on join_paths
|
||||||
|
is_collection = False
|
||||||
|
for _pa, rel_attr, _al in self.join_paths:
|
||||||
|
if rel_attr.key == (join_path[-1] if join_path else ""):
|
||||||
|
is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
|
||||||
|
break
|
||||||
|
|
||||||
|
if is_collection:
|
||||||
|
collection_field_names.setdefault(join_path[-1], []).append(col.key)
|
||||||
|
else:
|
||||||
rel_field_names.setdefault(join_path, []).append(col.key)
|
rel_field_names.setdefault(join_path, []).append(col.key)
|
||||||
self.eager_paths.add(join_path)
|
self.eager_paths.add(join_path)
|
||||||
else:
|
else:
|
||||||
|
|
@ -153,7 +180,11 @@ class CRUDSpec:
|
||||||
|
|
||||||
self._root_fields = root_fields
|
self._root_fields = root_fields
|
||||||
self._rel_field_names = rel_field_names
|
self._rel_field_names = rel_field_names
|
||||||
return root_fields, rel_field_names, root_field_names
|
# return root_fields, rel_field_names, root_field_names
|
||||||
|
for r, names in collection_field_names.items():
|
||||||
|
seen3 = set()
|
||||||
|
collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
|
||||||
|
return root_field_names, rel_field_names, root_field_names, collection_field_names
|
||||||
|
|
||||||
def get_eager_loads(self, root_alias, *, fields_map=None):
|
def get_eager_loads(self, root_alias, *, fields_map=None):
|
||||||
loads = []
|
loads = []
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ def init_entry_routes(app):
|
||||||
{"name": "label", "order": 0},
|
{"name": "label", "order": 0},
|
||||||
{"name": "name", "order": 10, "attrs": {"class": "row"}},
|
{"name": "name", "order": 10, "attrs": {"class": "row"}},
|
||||||
{"name": "details", "order": 20, "attrs": {"class": "row mt-2"}},
|
{"name": "details", "order": 20, "attrs": {"class": "row mt-2"}},
|
||||||
{"name": "checkboxes", "order": 30, "parent": "name",
|
{"name": "checkboxes", "order": 30, "parent": "details",
|
||||||
"attrs": {"class": "col d-flex flex-column justify-content-end"}},
|
"attrs": {"class": "col d-flex flex-column justify-content-end"}},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue