diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 86a4542..9eb6d8b 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,10 +1,13 @@ from __future__ import annotations +from collections.abc import Iterable from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text +from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, RelationshipProperty +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import ColumnElement from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec @@ -15,13 +18,6 @@ from crudkit.projection import compile_projection import logging log = logging.getLogger("crudkit.service") -def _is_rel(model_cls, name: str) -> bool: - try: - prop = model_cls.__mapper__.relationships.get(name) - return isinstance(prop, RelationshipProperty) - except Exception: - return False - @runtime_checkable class _HasID(Protocol): id: int @@ -44,9 +40,65 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): T = TypeVar("T", bound=_CRUDModelProto) +def _belongs_to_alias(col: Any, alias: Any) -> bool: + # Try to detect if a column/expression ultimately comes from this alias. + # Works for most ORM columns; complex expressions may need more. 
+ t = getattr(col, "table", None) + selectable = getattr(alias, "selectable", None) + return t is not None and selectable is not None and t is selectable + +def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]: + paths: set[tuple[str, ...]] = set() + # Sort columns + for ob in order_by or []: + col = getattr(ob, "element", ob) # unwrap UnaryExpression + for path, _rel_attr, target_alias in join_paths: + if _belongs_to_alias(col, target_alias): + paths.add(tuple(path)) + # Filter columns (best-effort) + # Walk simple binary expressions + def _extract_cols(expr: Any) -> Iterable[Any]: + if isinstance(expr, ColumnElement): + yield expr + for ch in getattr(expr, "get_children", lambda: [])(): + yield from _extract_cols(ch) + elif hasattr(expr, "clauses"): + for ch in expr.clauses: + yield from _extract_cols(ch) + + for flt in filters or []: + for col in _extract_cols(flt): + for path, _rel_attr, target_alias in join_paths: + if _belongs_to_alias(col, target_alias): + paths.add(tuple(path)) + return paths + +def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]: + out: set[tuple[str, ...]] = set() + for f in req_fields: + if "." in f: + parts = tuple(f.split(".")[:-1]) + if parts: + out.add(parts) + return out + def _is_truthy(val): return str(val).lower() in ('1', 'true', 'yes', 'on') +def _normalize_fields_param(params: dict | None) -> list[str]: + if not params: + return [] + raw = params.get("fields") + if isinstance(raw, (list, tuple)): + out: list[str] = [] + for item in raw: + if isinstance(item, str): + out.extend([p for p in (s.strip() for s in item.split(",")) if p]) + return out + if isinstance(raw, str): + return [p for p in (s.strip() for s in raw.split(",")) if p] + return [] + class CRUDService(Generic[T]): def __init__( self, @@ -86,8 +138,6 @@ class CRUDService(Generic[T]): Normalize order_by into (cols, desc_flags). 
Supports plain columns and col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. """ - from sqlalchemy.sql import operators - from sqlalchemy.sql.elements import UnaryExpression given = self._stable_order_by(root_alias, given_order_by) @@ -109,7 +159,6 @@ class CRUDService(Generic[T]): cols.append(col) desc_flags.append(bool(is_desc)) - from crudkit.core.types import OrderSpec return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): @@ -156,49 +205,100 @@ class CRUDService(Generic[T]): Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. """ session = self.session - fields = list((params or {}).get("fields", [])) - expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) query, root_alias = self.get_query() - if proj_opts: - query = query.options(*proj_opts) + + # Normalize requested fields and compile projection (may skip later to avoid conflicts) + fields = _normalize_fields_param(params) + expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) spec = CRUDSpec(self.model, params or {}, root_alias) filters = spec.parse_filters() order_by = spec.parse_sort() + # Field parsing for root load_only fallback root_fields, rel_field_names, root_field_names = spec.parse_fields() # Soft delete filter - # if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): - # query = query.filter(getattr(root_alias, "is_deleted") == False) query = self._apply_not_deleted(query, root_alias, params) - # Parse filters first + # Apply filters first if filters: query = query.filter(*filters) - # Includes + joins (so relationship fields like brand.name, location.label work) + # Includes + join paths (dotted fields etc.) 
spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) + join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias) - for _, relationship_attr, target_alias in spec.get_join_paths(): - rel_attr = cast(InstrumentedAttribute, relationship_attr) - target = cast(Any, target_alias) - query = query.join(target, rel_attr.of_type(target), isouter=True) + # Decide which relationship *names* are required for SQL (filters/sort) vs display-only + def _belongs_to_alias(col: Any, alias: Any) -> bool: + t = getattr(col, "table", None) + selectable = getattr(alias, "selectable", None) + return t is not None and selectable is not None and t is selectable - # Fields/projection: load_only for root columns, eager loads for relationships + # 1) which relationship aliases are referenced by sort/filter + sql_hops: set[str] = set() + for path, relationship_attr, target_alias in join_paths: + # If any ORDER BY column comes from this alias, mark it + for ob in (order_by or []): + col = getattr(ob, "element", ob) # unwrap UnaryExpression + if _belongs_to_alias(col, target_alias): + sql_hops.add(relationship_attr.key) + break + # If any filter expr touches this alias, mark it (best effort) + if relationship_attr.key not in sql_hops: + def _walk_cols(expr: Any): + # Primitive walker for ColumnElement trees + from sqlalchemy.sql.elements import ColumnElement + if isinstance(expr, ColumnElement): + yield expr + for ch in getattr(expr, "get_children", lambda: [])(): + yield from _walk_cols(ch) + elif hasattr(expr, "clauses"): + for ch in expr.clauses: + yield from _walk_cols(ch) + for flt in (filters or []): + if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)): + sql_hops.add(relationship_attr.key) + break + + # 2) first-hop relationship names implied by dotted projection fields + proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." 
in f} + + # Root column projection + from sqlalchemy.orm import Load # local import to match your style only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) + # Relationship handling per path (avoid loader strategy conflicts) + used_contains_eager = False + for path, relationship_attr, target_alias in join_paths: + rel_attr = cast(InstrumentedAttribute, relationship_attr) + name = relationship_attr.key + if name in sql_hops: + # Needed for WHERE/ORDER BY: join + hydrate from that join + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + query = query.options(contains_eager(rel_attr, alias=target_alias)) + used_contains_eager = True + elif name in proj_hops: + # Display-only: bulk-load efficiently, no join + query = query.options(selectinload(rel_attr)) + else: + # Not needed + pass + + # Apply projection loader options only if they won't conflict with contains_eager + if proj_opts and not used_contains_eager: + query = query.options(*proj_opts) + # Order + limit - order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper + order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper limit, _ = spec.parse_pagination() if limit is None: effective_limit = 50 elif limit == 0: - effective_limit = None + effective_limit = None # unlimited else: effective_limit = limit @@ -222,16 +322,17 @@ class CRUDService(Generic[T]): query = query.limit(effective_limit) items = list(reversed(query.all())) - # Tag projection so your renderer knows what fields were requested - if expanded_fields: - proj = list(expanded_fields) + if fields: + proj = list(dict.fromkeys(fields)) # dedupe, preserve order + if "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: - proj.extend(c.key for c in root_fields) + proj.extend(c.key for c 
in root_fields if hasattr(c, "key")) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: @@ -248,7 +349,7 @@ class CRUDService(Generic[T]): # Boundary keys for cursor encoding in the API layer first_key = self._pluck_key(items[0], order_spec) if items else None - last_key = self._pluck_key(items[-1], order_spec) if items else None + last_key = self._pluck_key(items[-1], order_spec) if items else None # Optional total that’s safe under JOINs (COUNT DISTINCT ids) total = None @@ -257,10 +358,11 @@ class CRUDService(Generic[T]): base = self._apply_not_deleted(base, root_alias, params) if filters: base = base.filter(*filters) - for _, relationship_attr, target_alias in join_paths: # reuse - rel_attr = cast(InstrumentedAttribute, relationship_attr) - target = cast(Any, target_alias) - base = base.join(target, rel_attr.of_type(target), isouter=True) + # Mirror join structure for any SQL-needed relationships + for path, relationship_attr, target_alias in join_paths: + if relationship_attr.key in sql_hops: + rel_attr = cast(InstrumentedAttribute, relationship_attr) + base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) total = session.query(func.count()).select_from( base.order_by(None).distinct().subquery() ).scalar() or 0 @@ -270,7 +372,6 @@ class CRUDService(Generic[T]): if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) - from crudkit.core.types import SeekWindow # avoid circulars at module top return SeekWindow( items=items, limit=window_limit_for_body, @@ -311,50 +412,81 @@ class CRUDService(Generic[T]): return [*order_by, *pk_cols] def get(self, id: int, params=None) -> T | None: + """Fetch a single row by id with conflict-free eager loading and clean projection.""" query, root_alias = self.get_query() - include_deleted = False - root_fields = [] - root_field_names = {} - rel_field_names = {} + # Defaults so we can build a projection even if params is None + root_fields: list[Any] = 
[] + root_field_names: dict[str, str] = {} + rel_field_names: dict[tuple[str, ...], list[str]] = {} + req_fields: list[str] = _normalize_fields_param(params) + + # Soft-delete guard + query = self._apply_not_deleted(query, root_alias, params) spec = CRUDSpec(self.model, params or {}, root_alias) - if params: - if self.supports_soft_delete: - include_deleted = _is_truthy(params.get('include_deleted')) - if self.supports_soft_delete and not include_deleted: - query = query.filter(getattr(root_alias, "is_deleted") == False) + + # Optional extra filters (in addition to id); keep parity with list() + filters = spec.parse_filters() + if filters: + query = query.filter(*filters) + + # Always filter by id query = query.filter(getattr(root_alias, "id") == id) + # Includes + join paths we may need spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) - for _, relationship_attr, target_alias in spec.get_join_paths(): - rel_attr = cast(InstrumentedAttribute, relationship_attr) - target = cast(Any, target_alias) - query = query.join(target, rel_attr.of_type(target), isouter=True) - + # Field parsing to enable root load_only if params: root_fields, rel_field_names, root_field_names = spec.parse_fields() - req_fields = list((params or {}).get("fields", [])) - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) + # Decide which relationship paths are needed for SQL vs display-only + # For get(), there is no ORDER BY; only filters might force SQL use. 
+ sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) + proj_paths = _paths_from_fields(req_fields) + # Root column projection only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) + # Relationship handling per path: avoid loader strategy conflicts + used_contains_eager = False + for path, relationship_attr, target_alias in join_paths: + rel_attr = cast(InstrumentedAttribute, relationship_attr) + ptuple = tuple(path) + if ptuple in sql_paths: + # Needed in WHERE: join + hydrate from the join + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + query = query.options(contains_eager(rel_attr, alias=target_alias)) + used_contains_eager = True + elif ptuple in proj_paths: + # Display-only: bulk-load efficiently + query = query.options(selectinload(rel_attr)) + else: + pass + + # Projection loader options compiled from requested fields. + # Skip if we used contains_eager to avoid strategy conflicts. 
+ expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) + if proj_opts and not used_contains_eager: + query = query.options(*proj_opts) + obj = query.first() - if expanded_fields: - proj = list(expanded_fields) + # Emit exactly what the client requested (plus id), or a reasonable fallback + if req_fields: + proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order + if "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: - proj.extend(c.key for c in root_fields) + proj.extend(c.key for c in root_fields if hasattr(c, "key")) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: @@ -374,40 +506,60 @@ class CRUDService(Generic[T]): return obj or None def list(self, params=None) -> list[T]: + """Offset/limit listing with smart relationship loading and clean projection.""" query, root_alias = self.get_query() - root_fields = [] - root_field_names = {} - rel_field_names = {} + # Defaults so we can reference them later even if params is None + root_fields: list[Any] = [] + root_field_names: dict[str, str] = {} + rel_field_names: dict[tuple[str, ...], list[str]] = {} + req_fields: list[str] = _normalize_fields_param(params) if params: - if self.supports_soft_delete: - include_deleted = _is_truthy(params.get('include_deleted')) - if not include_deleted: - query = query.filter(getattr(root_alias, "is_deleted") == False) + query = self._apply_not_deleted(query, root_alias, params) spec = CRUDSpec(self.model, params or {}, root_alias) filters = spec.parse_filters() order_by = spec.parse_sort() limit, offset = spec.parse_pagination() + + # Includes + join paths we might need spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) - for _, relationship_attr, target_alias in spec.get_join_paths(): - rel_attr = cast(InstrumentedAttribute, relationship_attr) - target = cast(Any, 
target_alias) - query = query.join(target, rel_attr.of_type(target), isouter=True) - - if params: - root_fields, rel_field_names, root_field_names = spec.parse_fields() - - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) + # Field parsing for load_only on root columns + root_fields, rel_field_names, root_field_names = spec.parse_fields() if filters: query = query.filter(*filters) - # MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset). + # Determine which relationship paths are needed for SQL vs display-only + sql_paths = _paths_needed_for_sql(order_by, filters, join_paths) + proj_paths = _paths_from_fields(req_fields) + + # Root column projection + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) + + # Relationship handling per path + used_contains_eager = False + for path, relationship_attr, target_alias in join_paths: + rel_attr = cast(InstrumentedAttribute, relationship_attr) + ptuple = tuple(path) + if ptuple in sql_paths: + # Needed for WHERE/ORDER BY: join + hydrate from the join + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + query = query.options(contains_eager(rel_attr, alias=target_alias)) + used_contains_eager = True + elif ptuple in proj_paths: + # Display-only: no join, bulk-load efficiently + query = query.options(selectinload(rel_attr)) + else: + # Not needed at all; do nothing + pass + + # MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) paginating = (limit is not None) or (offset is not None and offset != 0) if paginating and not order_by and self.backend.requires_order_by_for_offset: order_by = self._default_order_by(root_alias) @@ -415,27 +567,37 @@ class CRUDService(Generic[T]): if order_by: query = query.order_by(*order_by) - # 
Only apply offset/limit when not None. + # Only apply offset/limit when not None and not zero if offset is not None and offset != 0: query = query.offset(offset) if limit is not None and limit > 0: query = query.limit(limit) - req_fields = list((params or {}).get("fields", [])) - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) + # Projection loader options compiled from requested fields. + # Skip if we used contains_eager to avoid loader-strategy conflicts. + expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) + if proj_opts and not used_contains_eager: + query = query.options(*proj_opts) + + else: + # No params means no filters/sorts/limits; still honor projection loaders if any + expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) + if proj_opts: + query = query.options(*proj_opts) rows = query.all() - if expanded_fields: - proj = list(expanded_fields) + # Emit exactly what the client requested (plus id), or a reasonable fallback + if req_fields: + proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order + if "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") else: proj = [] if root_field_names: proj.extend(root_field_names) if root_fields: - proj.extend(c.key for c in root_fields) + proj.extend(c.key for c in root_fields if hasattr(c, "key")) for path, names in (rel_field_names or {}).items(): prefix = ".".join(path) for n in names: