diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 653d1f9..ad6338f 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -5,7 +5,7 @@ from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression, ColumnElement @@ -49,7 +49,7 @@ def _unwrap_ob(ob): is_desc = False dir_attr = getattr(ob, "_direction", None) if dir_attr is not None: - is_desc = (dir_attr is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") elif isinstance(ob, UnaryExpression): op = getattr(ob, "operator", None) is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") @@ -231,7 +231,7 @@ class CRUDService(Generic[T]): # Parse all inputs so join_paths are populated filters = spec.parse_filters() order_by = spec.parse_sort() - root_fields, rel_field_names, root_field_names = spec.parse_fields() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() spec.parse_includes() join_paths = tuple(spec.get_join_paths()) @@ -243,12 +243,25 @@ class CRUDService(Generic[T]): if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # JOIN all resolved paths, hydrate from the join + # JOIN all resolved paths; for collections use selectinload (never join) used_contains_eager = False - for _base_alias, rel_attr, target_alias in join_paths: - query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) - query = query.options(contains_eager(rel_attr, alias=target_alias)) - used_contains_eager = True + for base_alias, rel_attr, target_alias in join_paths: + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + if is_collection: + opt = selectinload(rel_attr) + # narroe child columns it requested (e.g., updates.id,updates.timestamp) + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) + else: + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + query = query.options(contains_eager(rel_attr, alias=target_alias)) + used_contains_eager = True # Filters AFTER joins → no cartesian products if filters: @@ -346,8 +359,10 @@ class CRUDService(Generic[T]): base = session.query(getattr(root_alias, "id")) base = self._apply_not_deleted(base, root_alias, params) # same joins as above for correctness - for _base_alias, rel_attr, target_alias in join_paths: - base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + for base_alias, rel_attr, target_alias in join_paths: + # do not join collections for COUNT mirror + if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): + base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) if filters: base = base.filter(*filters) total = session.query(func.count()).select_from( @@ -428,7 +443,7 @@ class CRUDService(Generic[T]): filters = spec.parse_filters() # no ORDER BY for get() if params: - root_fields, rel_field_names, root_field_names = spec.parse_fields() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() spec.parse_includes() join_paths = tuple(spec.get_join_paths()) @@ -438,12 +453,24 @@ class CRUDService(Generic[T]): if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # JOIN all discovered paths up front; hydrate via contains_eager + # JOIN non-collections only; collections via selectinload used_contains_eager = False - for _base_alias, rel_attr, target_alias in join_paths: - query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) - query = query.options(contains_eager(rel_attr, alias=target_alias)) - used_contains_eager = True + for base_alias, rel_attr, target_alias in join_paths: + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + if is_collection: + opt = selectinload(rel_attr) + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) + else: + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + query = query.options(contains_eager(rel_attr, alias=target_alias)) + used_contains_eager = True # Apply filters (joins are in place → no cartesian products) if filters: diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 9c0e53b..d5c2480 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Set, Dict, Optional +from typing import List, Tuple, Set, Dict, Optional, Iterable from sqlalchemy import asc, desc from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased, selectinload @@ -20,10 +20,14 @@ class CRUDSpec: self.params = params self.root_alias = root_alias self.eager_paths: Set[Tuple[str, ...]] = set() + # (parent_alias. relationship_attr, alias_for_target) self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.alias_map: Dict[Tuple[str, ...], object] = {} self._root_fields: List[InstrumentedAttribute] = [] - self._rel_field_names: Dict[Tuple[str, ...], object] = {} + # dotted non-collection fields (MANYTOONE etc) + self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {} + # dotted collection fields (ONETOMANY) + self._collection_field_names: Dict[str, List[str]] = {} self.include_paths: Set[Tuple[str, ...]] = set() def _resolve_column(self, path: str): @@ -117,11 +121,12 @@ class CRUDSpec: Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD - Root fields become InstrumentedAttributes bound to root_alias. - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. - Returns (root_fields, rel_field_names). + - Collection (uselist=True) relationships record child names by relationship key. + Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel). """ raw = self.params.get('fields') if not raw: - return [], {}, {} + return [], {}, {}, {} if isinstance(raw, list): tokens = [] @@ -133,14 +138,36 @@ class CRUDSpec: root_fields: List[InstrumentedAttribute] = [] root_field_names: list[str] = [] rel_field_names: Dict[Tuple[str, ...], List[str]] = {} + collection_field_names: Dict[str, List[str]] = {} for token in tokens: col, join_path = self._resolve_column(token) if not col: continue if join_path: - rel_field_names.setdefault(join_path, []).append(col.key) - self.eager_paths.add(join_path) + # rel_field_names.setdefault(join_path, []).append(col.key) + # self.eager_paths.add(join_path) + try: + cur_cls = self.model + names = list(join_path) + last_name = names[-1] + for nm in names: + rel_attr = getattr(cur_cls, nm) + cur_cls = rel_attr.property.mapper.class_ + is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False)) + except Exception: + # Fallback: inspect the InstrumentedAttribute we recorded on join_paths + is_collection = False + for _pa, rel_attr, _al in self.join_paths: + if rel_attr.key == (join_path[-1] if join_path else ""): + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + break + + if is_collection: + collection_field_names.setdefault(join_path[-1], []).append(col.key) + else: + rel_field_names.setdefault(join_path, []).append(col.key) + self.eager_paths.add(join_path) else: root_fields.append(col) root_field_names.append(getattr(col, "key", token)) @@ -153,7 +180,11 @@ class CRUDSpec: self._root_fields = root_fields self._rel_field_names = rel_field_names - return root_fields, rel_field_names, root_field_names + # return root_fields, rel_field_names, root_field_names + for r, names in collection_field_names.items(): + seen3 = set() + collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))] + return root_field_names, rel_field_names, root_field_names, collection_field_names def get_eager_loads(self, root_alias, *, fields_map=None): loads = [] diff --git a/inventory/routes/entry.py b/inventory/routes/entry.py index 986a44f..dec7553 100644 --- a/inventory/routes/entry.py +++ b/inventory/routes/entry.py @@ -88,7 +88,7 @@ def init_entry_routes(app): {"name": "label", "order": 0}, {"name": "name", "order": 10, "attrs": {"class": "row"}}, {"name": "details", "order": 20, "attrs": {"class": "row mt-2"}}, - {"name": "checkboxes", "order": 30, "parent": "name", + {"name": "checkboxes", "order": 30, "parent": "details", "attrs": {"class": "col d-flex flex-column justify-content-end"}}, ]