diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 60e3aef..48e0e56 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eag from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression, ColumnElement +from sqlalchemy.sql.visitors import Visitable from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version @@ -42,6 +43,45 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): T = TypeVar("T", bound=_CRUDModelProto) +def _collect_tables_from_filters(filters) -> set: + """ + Walk SQLAlchemy filter expressions and collect the Table/Alias objects + that appear, so we can detect which relationships are *actually used* + by filters and must be JOINed (not just selectinloaded). + """ + seen = set() + def visit(node): + if node is None: + return + # record table / selectable if present + tbl = getattr(node, "table", None) + if tbl is not None: + # include the selectable and its base element (alias -> base table) + cur = tbl + while cur is not None: + seen.add(cur) + cur = getattr(cur, "element", None) + # generic children walker + try: + children = list(node.get_children()) + except Exception: + children = [] + for ch in children: + visit(ch) + # also inspect common attributes if present + for attr in ("left", "right", "element", "clause", "clauses"): + val = getattr(node, attr, None) + if val is None: + continue + if isinstance(val, (list, tuple)): + for v in val: + visit(v) + else: + visit(val) + for f in (filters or []): + visit(f) + return seen + def _unwrap_ob(ob): """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).""" col = getattr(ob, "element", None) @@ -239,6 +279,9 @@ class CRUDService(Generic[T]): # Soft delete query = self._apply_not_deleted(query, root_alias, params) + # Which related tables are referenced by filters? + filter_tables = _collect_tables_from_filters(filters) + # Root column projection (load_only) only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: @@ -247,35 +290,47 @@ class CRUDService(Generic[T]): # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor") nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - # IMPORTANT: - # - Only attach loader options for first-hop relations from the root. - # - Always use selectinload here (avoid contains_eager joins). - # - Let compile_projections() supply deep chained options. + # First-hop strategy: + # - If a non-collection relation's target table is used in filters -> plain JOIN (no explicit alias) + # - Otherwise -> selectinload (keeps row counts sane) + used_contains_eager = False # we purposely avoid contains_eager here for base_alias, rel_attr, target_alias in join_paths: is_firsthop_from_root = (base_alias is root_alias) if not is_firsthop_from_root: - # Deeper hops are handled by proj_opts below + # deeper hops handled separately (via proj_opts) continue prop = getattr(rel_attr, "property", None) is_collection = bool(getattr(prop, "uselist", False)) - is_nested_firsthop = rel_attr.key in nested_first_hops + _is_nested_firsthop = rel_attr.key in nested_first_hops - opt = selectinload(rel_attr) - # Optional narrowng for collections - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) + target_selectable = getattr(target_alias, "selectable", None) + target_base = getattr(target_selectable, "element", None) or target_selectable + needed_for_filter = (target_selectable in filter_tables) or (target_base in filter_tables) - # Filters AFTER joins → no cartesian products + if needed_for_filter and not is_collection: + # Join via the relationship attribute so SA reuses the same FROM + # that filter expressions reference (avoids cartesian products). + query = query.join(rel_attr, isouter=True) + else: + opt = selectinload(rel_attr) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = prop.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) + + # Filters AFTER relations wiring → no cartesian products if filters: query = query.filter(*filters) + # Apply deep projection loader options (only when we didn't use contains_eager) + if proj_opts and not used_contains_eager: + query = query.options(*proj_opts) + # Order spec (with PK tie-breakers for stability) order_spec = self._extract_order_spec(root_alias, order_by) limit, _ = spec.parse_pagination() @@ -363,10 +418,6 @@ class CRUDService(Generic[T]): last_key = None # Count DISTINCT ids with mirrored joins - - # Apply deep projection loader options (safe: we avoided contains_eager) - if proj_opts: - query = query.options(*proj_opts) total = None if include_total: base = session.query(getattr(root_alias, "id")) @@ -375,7 +426,7 @@ class CRUDService(Generic[T]): for base_alias, rel_attr, target_alias in join_paths: # do not join collections for COUNT mirror if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): - base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + base = base.join(rel_attr, isouter=True) if filters: base = base.filter(*filters) total = session.query(func.count()).select_from( @@ -468,7 +519,8 @@ class CRUDService(Generic[T]): nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - # First-hop only; use selectinload (no contains_eager) + # First-hop strategy (same as peek_window): join if filter needs it, else selectinload + use_contains_eager = False # we avoid contains_eager here, too for base_alias, rel_attr, target_alias in join_paths: is_firsthop_from_root = (base_alias is root_alias) if not is_firsthop_from_root: @@ -477,16 +529,23 @@ class CRUDService(Generic[T]): is_collection = bool(getattr(prop, "uselist", False)) _is_nested_firsthop = rel_attr.key in nested_first_hops - opt = selectinload(rel_attr) - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) + target_selectable = getattr(target_alias, "selectable", None) + target_base = getattr(target_selectable, "element", None) or target_selectable + needed_for_filter = target_selectable in _collect_tables_from_filters(filters) or (target_base in _collect_tables_from_filters(filters)) + + if needed_for_filter and not is_collection: + query = query.join(rel_attr, isouter=True) + else: + opt = selectinload(rel_attr) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = prop.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) # Apply filters (joins are in place → no cartesian products) if filters: @@ -498,7 +557,7 @@ class CRUDService(Generic[T]): # Projection loader options compiled from requested fields. # Skip if we used contains_eager to avoid loader-strategy conflicts. expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: + if proj_opts and not use_contains_eager: query = query.options(*proj_opts) obj = query.first() @@ -566,7 +625,8 @@ class CRUDService(Generic[T]): nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - # First-hop only; use selectinload + # First-hop strategy (same as seek_window): join if filter needs it, else selectinload + used_contains_eager = False # We avoid contains_eager here as well for base_alias, rel_attr, target_alias in join_paths: is_firsthop_from_root = (base_alias is root_alias) if not is_firsthop_from_root: @@ -575,16 +635,24 @@ class CRUDService(Generic[T]): is_collection = bool(getattr(prop, "uselist", False)) _is_nested_firsthop = rel_attr.key in nested_first_hops - opt = selectinload(rel_attr) - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) + target_selectable = getattr(target_alias, "selectable", None) + target_base = getattr(target_selectable, "element", None) or target_selectable + ftables = _collect_tables_from_filters(filters) + needed_for_filter = (target_selectable in ftables) or (target_base in ftables) + + if needed_for_filter and not is_collection: + query = query.join(rel_attr, isouter=True) + else: + opt = selectinload(rel_attr) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = prop.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) # Filters AFTER joins → no cartesian products if filters: @@ -608,7 +676,7 @@ class CRUDService(Generic[T]): # Projection loaders only if we didn’t use contains_eager expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: + if proj_opts and not used_contains_eager: query = query.options(*proj_opts) else: diff --git a/inventory/templates/update_list.html b/inventory/templates/update_list.html index 92f04b8..e0e2848 100644 --- a/inventory/templates/update_list.html +++ b/inventory/templates/update_list.html @@ -25,6 +25,15 @@ {% endfor %} +