diff --git a/crudkit/core/base.py b/crudkit/core/base.py index 5501e19..c42b90e 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, Iterable, List, Tuple, Set from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect -from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE +from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty Base = declarative_base() @@ -16,45 +17,296 @@ def _safe_get_loaded_attr(obj, name): except Exception: return None +def _identity_key(obj) -> Tuple[type, Any]: + try: + st = inspect(obj) + return (type(obj), st.identity_key[1][0] if st.identity_key else id(obj)) + except Exception: + return (type(obj), id(obj)) + +def _is_collection_rel(prop: RelationshipProperty) -> bool: + try: + return prop.uselist is True + except Exception: + return False + +def _serialize_simple_obj(obj) -> Dict[str, Any]: + """Columns only (no relationships).""" + out: Dict[str, Any] = {} + for cls in obj.__class__.__mro__: + if hasattr(cls, "__table__"): + for col in cls.__table__.columns: + name = col.name + try: + out[name] = getattr(obj, name) + except Exception: + out[name] = None + return out + +def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any: + """ + Serialize relationship 'name' already loaded on obj. + - If in 'embed' (or depth > 0 for depth-based walk), recurse. + - Else, return None (don’t lazy-load). + """ + val = _safe_get_loaded_attr(obj, name) + if val is None: + return None + + # Decide whether to recurse into this relationship + should_recurse = (depth > 0) or (name in embed) + + if isinstance(val, list): + if not should_recurse: + # Emit a light list of child primary data (id + a couple columns) without recursion. 
+ return [_serialize_simple_obj(child) for child in val] + out = [] + for child in val: + ik = _identity_key(child) + if ik in seen: # cycle guard + out.append({"id": getattr(child, "id", None)}) + continue + seen.add(ik) + out.append(child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen)) + return out + + # Scalar relationship + child = val + if not should_recurse: + return _serialize_simple_obj(child) + ik = _identity_key(child) + if ik in seen: + return {"id": getattr(child, "id", None)} + seen.add(ik) + return child.as_tree(embed_depth=max(depth - 1, 0), embed=embed, _seen=seen) + +def _split_field_tokens(fields: Iterable[str]) -> Tuple[List[str], Dict[str, List[str]]]: + """ + Split requested fields into: + - scalars: ["label", "name"] + - collections: {"updates": ["id", "timestamp","content"], "owner": ["label"]} + Any dotted token "root.rest.of.path" becomes collections[root].append("rest.of.path"). + Bare tokens ("foo") land in scalars. + """ + scalars: List[str] = [] + groups: Dict[str, List[str]] = {} + for raw in fields: + f = str(raw).strip() + if not f: + continue + # bare token -> scalar + if "." not in f: + scalars.append(f) + continue + # dotted token -> group under root + root, tail = f.split(".", 1) + if not root or not tail: + continue + groups.setdefault(root, []).append(tail) + return scalars, groups + +def _deep_get_loaded(obj: Any, dotted: str) -> Any: + """ + Deep get with no lazy loads: + - For all but the final hop, use _safe_get_loaded_attr (mapped-only, no getattr). + - For the final hop, try _safe_get_loaded_attr first; if None, fall back to getattr() + to allow computed properties/hybrids that rely on already-loaded columns. 
+ """ + parts = dotted.split(".") + if not parts: + return None + + cur = obj + # Traverse up to the parent of the last token safely + for part in parts[:-1]: + if cur is None: + return None + cur = _safe_get_loaded_attr(cur, part) + if cur is None: + return None + + last = parts[-1] + # Try safe fetch on the last hop first + val = _safe_get_loaded_attr(cur, last) + if val is not None: + return val + # Fall back to getattr for computed/hybrid attributes on an already-loaded object + try: + return getattr(cur, last, None) + except Exception: + return None + +def _serialize_leaf(obj: Any) -> Any: + """ + Leaf serialization for values we put into as_dict(): + - If object has as_dict(), call as_dict() with no args (caller controls field shapes). + - Else return value as-is (Flask/JSON encoder will handle datetimes, etc., via app config). + """ + if obj is None: + return None + ad = getattr(obj, "as_dict", None) + if callable(ad): + try: + return ad(None) + except Exception: + return str(obj) + return obj + +def _serialize_collection(items: Iterable[Any], requested_tails: List[str]) -> List[Dict[str, Any]]: + """ + Turn a collection of ORM objects into list[dict] with exactly requested_tails, + where each tail can be dotted again (e.g., "author.label"). We do NOT lazy-load. 
+ """ + out: List[Dict[str, Any]] = [] + # Deduplicate while preserving order + uniq_tails = list(dict.fromkeys(requested_tails)) + for child in (items or []): + row: Dict[str, Any] = {} + for tail in uniq_tails: + row[tail] = _deep_get_loaded(child, tail) + # ensure id present if exists and not already requested + try: + if "id" not in row and hasattr(child, "id"): + row["id"] = getattr(child, "id") + except Exception: + pass + out.append(row) + return out + @declarative_mixin class CRUDMixin: id = Column(Integer, primary_key=True) created_at = Column(DateTime, default=func.now(), nullable=False) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) + def as_tree( + self, + *, + embed_depth: int = 0, + embed: Iterable[str] | None = None, + _seen: Set[Tuple[type, Any]] | None = None, + ) -> Dict[str, Any]: + """ + Recursive, NON-LAZY serializer. + - Always includes mapped columns. + - For relationships: only serializes those ALREADY LOADED. + - Recurses either up to embed_depth or for specific names in 'embed'. + - Keeps *_id columns alongside embedded objects. + - Cycle-safe via _seen. + """ + seen = _seen or set() + ik = _identity_key(self) + if ik in seen: + return {"id": getattr(self, "id", None)} + seen.add(ik) + + data = _serialize_simple_obj(self) + + # Determine which relationships to consider + try: + st = inspect(self) + mapper = st.mapper + embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names + for name, prop in mapper.relationships.items(): + # Only touch relationships that are already loaded; never lazy-load here. + rel_loaded = st.attrs.get(name) + if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE: + continue + + data[name] = _serialize_loaded_rel( + self, name, depth=embed_depth, seen=seen, embed=embed_set + ) + except Exception: + # If inspection fails, we just return columns. 
+ pass + + return data + def as_dict(self, fields: list[str] | None = None): """ Serialize the instance. - - If 'fields' (possibly dotted) is provided, emit exactly those keys. - - Else, if '__crudkit_projection__' is set on the instance, emit those keys. - - Else, fall back to all mapped columns on this class hierarchy. - Always includes 'id' when present unless explicitly excluded. - """ - if fields is None: - fields = getattr(self, "__crudkit_projection__", None) - if fields: - out = {} - if "id" not in fields and hasattr(self, "id"): - out["id"] = getattr(self, "id") - for f in fields: - cur = self - for part in f.split("."): - if cur is None: - break - cur = getattr(cur, part, None) - out[f] = cur + Behavior: + - If 'fields' (possibly dotted) is provided, emit exactly those keys. + * Bare tokens (e.g., "label", "owner") return the current loaded value. + * Dotted tokens for one-to-many (e.g., "updates.id","updates.timestamp") + produce a single "updates" key containing a list of dicts with the requested child keys. + * Dotted tokens for many-to-one/one-to-one (e.g., "owner.label") emit the scalar under "owner.label". + - Else, if '__crudkit_projection__' is set on the instance, use that. + - Else, fall back to all mapped columns on this class hierarchy. + + Always includes 'id' when present, even when the requested fields do not list it. 
+ """ + req = fields if fields is not None else getattr(self, "__crudkit_projection__", None) + + if req: + # Normalize and split into (scalars, groups of dotted by root) + req_list = [p for p in (str(x).strip() for x in req) if p] + scalars, groups = _split_field_tokens(req_list) + + out: Dict[str, Any] = {} + + # Always include id unless the caller explicitly listed fields containing id + if "id" not in req_list and hasattr(self, "id"): + try: + out["id"] = getattr(self, "id") + except Exception: + pass + + # Handle scalar tokens (may be columns, hybrids/properties, or relationships) + for name in scalars: + # Try loaded value first (never lazy-load) + val = _safe_get_loaded_attr(self, name) + + # Final-hop getattr for root scalars (hybrids/@property) so they can compute. + if val is None: + try: + val = getattr(self, name) + except Exception: + val = None + + # If it's a scalar ORM object (relationship), serialize its columns + try: + st = inspect(val) # will raise if not an ORM object + if getattr(st, "mapper", None) is not None: + out[name] = _serialize_simple_obj(val) + continue + except Exception: + pass + + # If it's a collection and no subfields were requested, emit a light list + if isinstance(val, (list, tuple)): + out[name] = [_serialize_leaf(v) for v in val] + else: + out[name] = val + + # Handle dotted groups: root -> [tails] + for root, tails in groups.items(): + root_val = _safe_get_loaded_attr(self, root) + if isinstance(root_val, (list, tuple)): + # one-to-many collection → list of dicts with the requested tails + out[root] = _serialize_collection(root_val, tails) + else: + # many-to-one or scalar dotted; place each full dotted path as key + for tail in tails: + dotted = f"{root}.{tail}" + out[dotted] = _deep_get_loaded(self, dotted) + + # ← This was the placeholder before. We return the dict we just built. 
return out - result = {} + # Fallback: all mapped columns on this class hierarchy + result: Dict[str, Any] = {} for cls in self.__class__.__mro__: if hasattr(cls, "__table__"): for column in cls.__table__.columns: name = column.name - result[name] = getattr(self, name) + try: + result[name] = getattr(self, name) + except Exception: + result[name] = None return result - class Version(Base): __tablename__ = "versions" diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 5f713ed..b4d7036 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection @@ -40,66 +41,43 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): T = TypeVar("T", bound=_CRUDModelProto) -def _hops_from_sort(params: dict | None) -> set[str]: - """Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'.""" - if not params: - return set() - raw = params.get("sort") - tokens: list[str] = [] - if isinstance(raw, str): - tokens = [t.strip() for t in raw.split(",") if t.strip()] - elif isinstance(raw, (list, tuple)): - for item in raw: - if isinstance(item, str): - tokens.extend([t.strip() for t in item.split(",") if t.strip()]) - hops: set[str] = set() - for tok in tokens: - tok = tok.lstrip("+-") - if "." 
in tok: - hops.add(tok.split(".", 1)[0]) - return hops +def _unwrap_ob(ob): + """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).""" + col = getattr(ob, "element", None) + if col is None: + col = ob + is_desc = False + dir_attr = getattr(ob, "_direction", None) + if dir_attr is not None: + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elif isinstance(ob, UnaryExpression): + op = getattr(ob, "operator", None) + is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + return col, bool(is_desc) -def _belongs_to_alias(col: Any, alias: Any) -> bool: - # Try to detect if a column/expression ultimately comes from this alias. - # Works for most ORM columns; complex expressions may need more. - t = getattr(col, "table", None) - selectable = getattr(alias, "selectable", None) - return t is not None and selectable is not None and t is selectable +def _order_identity(col: ColumnElement): + """ + Build a stable identity for a column suitable for deduping. + We ignore direction here. Duplicates are duplicates regardless of ASC/DESC. 
+ """ + table = getattr(col, "table", None) + table_key = getattr(table, "key", None) or id(table) + col_key = getattr(col, "key", None) or getattr(col, "name", None) + return (table_key, col_key) -def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]: - hops: set[str] = set() - paths: set[tuple[str, ...]] = set() - # Sort columns - for ob in order_by or []: - col = getattr(ob, "element", ob) # unwrap UnaryExpression - for _path, rel_attr, target_alias in join_paths: - if _belongs_to_alias(col, target_alias): - hops.add(rel_attr.key) - # Filter columns (best-effort) - # Walk simple binary expressions - def _extract_cols(expr: Any) -> Iterable[Any]: - if isinstance(expr, ColumnElement): - yield expr - for ch in getattr(expr, "get_children", lambda: [])(): - yield from _extract_cols(ch) - elif hasattr(expr, "clauses"): - for ch in expr.clauses: - yield from _extract_cols(ch) - - for flt in filters or []: - for col in _extract_cols(flt): - for _path, rel_attr, target_alias in join_paths: - if _belongs_to_alias(col, target_alias): - hops.add(rel_attr.key) - return hops - -def _paths_from_fields(req_fields: list[str]) -> set[str]: - out: set[str] = set() - for f in req_fields: - if "." 
in f: - parent = f.split(".", 1)[0] - if parent: - out.add(parent) +def _dedupe_order_by(order_by): + """Remove duplicate ORDER BY entries (by column identity, ignoring direction).""" + if not order_by: + return [] + seen = set() + out = [] + for ob in order_by: + col, _ = _unwrap_ob(ob) + ident = _order_identity(col) + if ident in seen: + continue + seen.add(ident) + out.append(ob) return out def _is_truthy(val): @@ -133,24 +111,25 @@ class CRUDService(Generic[T]): self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') - # Derive engine WITHOUT leaking a session/connection - bind = getattr(session_factory, "bind", None) - if bind is None: - tmp_sess = session_factory() - try: - bind = tmp_sess.get_bind() - finally: - try: - tmp_sess.close() - except Exception: - pass - - eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) - self.backend = backend or make_backend_info(eng) + self._backend: Optional[BackendInfo] = backend @property def session(self) -> Session: - return self._session_factory() + """Always return the Flask-scoped Session if available; otherwise the provided factory.""" + try: + sess = current_app.extensions["crudkit"]["Session"] + return sess + except Exception: + return self._session_factory() + + @property + def backend(self) -> BackendInfo: + """Resolve backend info lazily against the active session's engine.""" + if self._backend is None: + bind = self.session.get_bind() + eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) + self._backend = make_backend_info(eng) + return self._backend def get_query(self): if self.polymorphic: @@ -237,85 +216,69 @@ class CRUDService(Generic[T]): include_total: bool = True, ) -> "SeekWindow[T]": """ - Transport-agnostic keyset pagination that preserves all the goodies from `list()`: - - filters, includes, joins, field projection, eager loading, soft-delete - - deterministic ordering (user sort + PK tiebreakers) - - 
forward/backward seek via `key` and `backward` - Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. + Keyset pagination with relationship-safe filtering/sorting. + Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek. """ - self._debug_bind("seek_window") session = self.session query, root_alias = self.get_query() - # Normalize requested fields and compile projection (may skip later to avoid conflicts) + # Requested fields → projection + optional loaders fields = _normalize_fields_param(params) expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) spec = CRUDSpec(self.model, params or {}, root_alias) + # Parse all inputs so join_paths are populated filters = spec.parse_filters() order_by = spec.parse_sort() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() + spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) - # Field parsing for root load_only fallback - root_fields, rel_field_names, root_field_names = spec.parse_fields() - - # Soft delete filter + # Soft delete query = self._apply_not_deleted(query, root_alias, params) - # Apply filters first - if filters: - query = query.filter(*filters) - - # Includes + join paths (dotted fields etc.) 
- spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias) - - # Relationship names required by ORDER BY / WHERE - sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths) - # Also include relationships mentioned directly in the sort spec - sql_hops |= _hops_from_sort(params) - - # First-hop relationship names implied by dotted projection fields - proj_hops: set[str] = _paths_from_fields(fields) - - # Root column projection + # Root column projection (load_only) only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # Relationship handling per path (avoid loader strategy conflicts) - used_contains_eager = False - joined_names: set[str] = set() + # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor") + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - for _path, relationship_attr, target_alias in join_paths: - rel_attr = cast(InstrumentedAttribute, relationship_attr) - name = relationship_attr.key - if name in sql_hops: - # Needed for WHERE/ORDER BY: join + hydrate from that join + used_contains_eager = False + for base_alias, rel_attr, target_alias in join_paths: + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: + # Use selectinload so deeper hops can chain cleanly (and to avoid + # contains_eager/loader conflicts on nested paths). 
+ opt = selectinload(rel_attr) + + # Narrow columns for collections if we know child scalar names + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + + query = query.options(opt) + else: + # Simple first-hop scalar rel with no deeper tails: safe to join + contains_eager query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True - joined_names.add(name) - elif name in proj_hops: - # Display-only: bulk-load efficiently, no join - query = query.options(selectinload(rel_attr)) - joined_names.add(name) - # Force-join any SQL-needed relationships that weren't in join_paths - missing_sql = sql_hops - joined_names - for name in missing_sql: - rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) - query = query.join(rel_attr, isouter=True) - query = query.options(contains_eager(rel_attr)) - used_contains_eager = True - joined_names.add(name) + # Filters AFTER joins → no cartesian products + if filters: + query = query.filter(*filters) - # Apply projection loader options only if they won't conflict with contains_eager - if proj_opts and not used_contains_eager: - query = query.options(*proj_opts) - - # Order + limit - order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper + # Order spec (with PK tie-breakers for stability) + order_spec = self._extract_order_spec(root_alias, order_by) limit, _ = spec.parse_pagination() if limit is None: effective_limit = 50 @@ -324,13 +287,13 @@ class CRUDService(Generic[T]): else: effective_limit = limit - # Keyset predicate + # Seek predicate from cursor key (if any) if key: pred = self._key_predicate(order_spec, key, backward) 
if pred is not None: query = query.filter(pred) - # Apply ordering. For backward, invert SQL order then reverse in-memory for display. + # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory. if not backward: clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] query = query.order_by(*clauses) @@ -344,9 +307,9 @@ class CRUDService(Generic[T]): query = query.limit(effective_limit) items = list(reversed(query.all())) - # Tag projection so your renderer knows what fields were requested + # Projection meta tag for renderers if fields: - proj = list(dict.fromkeys(fields)) # dedupe, preserve order + proj = list(dict.fromkeys(fields)) if "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") else: @@ -369,12 +332,9 @@ class CRUDService(Generic[T]): except Exception: pass - # Boundary keys for cursor encoding in the API layer - # When ORDER BY includes related columns (e.g., owner.first_name), - # pluck values from the related object we hydrated with contains_eager/selectinload. 
+ # Cursor key pluck: support related columns we hydrated via contains_eager def _pluck_key_from_obj(obj: Any) -> list[Any]: vals: list[Any] = [] - # Build a quick map: selectable -> relationship name alias_to_rel: dict[Any, str] = {} for _p, relationship_attr, target_alias in join_paths: sel = getattr(target_alias, "selectable", None) @@ -382,20 +342,17 @@ class CRUDService(Generic[T]): alias_to_rel[sel] = relationship_attr.key for col in order_spec.cols: - key = getattr(col, "key", None) or getattr(col, "name", None) - # Try root attribute first - if key and hasattr(obj, key): - vals.append(getattr(obj, key)) + keyname = getattr(col, "key", None) or getattr(col, "name", None) + if keyname and hasattr(obj, keyname): + vals.append(getattr(obj, keyname)) continue - # Try relationship hop by matching the column's table/selectable table = getattr(col, "table", None) relname = alias_to_rel.get(table) - if relname and key: + if relname and keyname: relobj = getattr(obj, relname, None) - if relobj is not None and hasattr(relobj, key): - vals.append(getattr(relobj, key)) + if relobj is not None and hasattr(relobj, keyname): + vals.append(getattr(relobj, keyname)) continue - # Give up: unsupported expression for cursor purposes raise ValueError("unpluckable") return vals @@ -403,33 +360,26 @@ class CRUDService(Generic[T]): first_key = _pluck_key_from_obj(items[0]) if items else None last_key = _pluck_key_from_obj(items[-1]) if items else None except Exception: - # If we can't derive cursor keys (e.g., ORDER BY expression/aggregate), - # disable cursors for this response rather than exploding. 
first_key = None last_key = None - # Optional total that’s safe under JOINs (COUNT DISTINCT ids) + # Count DISTINCT ids with mirrored joins total = None if include_total: base = session.query(getattr(root_alias, "id")) base = self._apply_not_deleted(base, root_alias, params) + # same joins as above for correctness + for base_alias, rel_attr, target_alias in join_paths: + # do not join collections for COUNT mirror + if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): + base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) if filters: base = base.filter(*filters) - # Mirror join structure for any SQL-needed relationships - for _path, relationship_attr, target_alias in join_paths: - if relationship_attr.key in sql_hops: - rel_attr = cast(InstrumentedAttribute, relationship_attr) - base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) - # Also mirror any forced joins - for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}): - rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) - base = base.join(rel_attr, isouter=True) - total = session.query(func.count()).select_from( base.order_by(None).distinct().subquery() ).scalar() or 0 - window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50) + window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50) if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) @@ -458,81 +408,93 @@ class CRUDService(Generic[T]): """ Ensure deterministic ordering by appending PK columns as tiebreakers. If no order is provided, fall back to default primary-key order. + Never duplicates columns in ORDER BY (SQL Server requires uniqueness). 
""" order_by = list(given_order_by or []) if not order_by: - return self._default_order_by(root_alias) + return _dedupe_order_by(self._default_order_by(root_alias)) + + order_by = _dedupe_order_by(order_by) mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) - pk_cols = [] - for col in mapper.primary_key: - try: - pk_cols.append(getattr(root_alias, col.key)) - except AttributeError: - pk_cols.append(col) + present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by} - return [*order_by, *pk_cols] + for pk in mapper.primary_key: + try: + pk_col = getattr(root_alias, pk.key) + except AttributeError: + pk_col = pk + if _order_identity(pk_col) not in present: + order_by.append(pk_col.asc()) + present.add(_order_identity(pk_col)) + + return order_by def get(self, id: int, params=None) -> T | None: - """Fetch a single row by id with conflict-free eager loading and clean projection.""" - self._debug_bind("get") + """ + Fetch a single row by id with conflict-free eager loading and clean projection. + Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so + related-column filters never create cartesian products. 
+ """ query, root_alias = self.get_query() # Defaults so we can build a projection even if params is None + req_fields: list[str] = _normalize_fields_param(params) root_fields: list[Any] = [] root_field_names: dict[str, str] = {} rel_field_names: dict[tuple[str, ...], list[str]] = {} - req_fields: list[str] = _normalize_fields_param(params) - # Soft-delete guard + # Soft-delete guard first query = self._apply_not_deleted(query, root_alias, params) spec = CRUDSpec(self.model, params or {}, root_alias) - # Optional extra filters (in addition to id); keep parity with list() + # Parse everything so CRUDSpec records any join paths it needed to resolve filters = spec.parse_filters() - if filters: - query = query.filter(*filters) - - # Always filter by id - query = query.filter(getattr(root_alias, "id") == id) - - # Includes + join paths we may need + # no ORDER BY for get() + if params: + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() spec.parse_includes() + join_paths = tuple(spec.get_join_paths()) - # Field parsing to enable root load_only - if params: - root_fields, rel_field_names, root_field_names = spec.parse_fields() - - # Decide which relationship paths are needed for SQL vs display-only - # For get(), there is no ORDER BY; only filters might force SQL use. 
- sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) - proj_hops = _paths_from_fields(req_fields) - - # Root column projection + # Root-column projection (load_only) only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # Relationship handling per path: avoid loader strategy conflicts + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + used_contains_eager = False - for _path, relationship_attr, target_alias in join_paths: - rel_attr = cast(InstrumentedAttribute, relationship_attr) - name = relationship_attr.key - if name in sql_hops: - # Needed in WHERE: join + hydrate from the join + for base_alias, rel_attr, target_alias in join_paths: + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: + opt = selectinload(rel_attr) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) + else: query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True - elif name in proj_hops: - # Display-only: bulk-load efficiently - query = query.options(selectinload(rel_attr)) - else: - pass + + # Apply filters (joins are in place → no cartesian products) + if filters: + query = query.filter(*filters) + + # And the id filter + query = query.filter(getattr(root_alias, "id") == id) # Projection loader options compiled from requested fields. 
- # Skip if we used contains_eager to avoid strategy conflicts. + # Skip if we used contains_eager to avoid loader-strategy conflicts. expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) if proj_opts and not used_contains_eager: query = query.options(*proj_opts) @@ -541,7 +503,7 @@ class CRUDService(Generic[T]): # Emit exactly what the client requested (plus id), or a reasonable fallback if req_fields: - proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order + proj = list(dict.fromkeys(req_fields)) if "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") else: @@ -569,17 +531,20 @@ class CRUDService(Generic[T]): return obj or None def list(self, params=None) -> list[T]: - """Offset/limit listing with smart relationship loading and clean projection.""" - self._debug_bind("list") + """ + Offset/limit listing with relationship-safe filtering. + We always JOIN every CRUDSpec-discovered path before applying filters/sorts. 
+ """ query, root_alias = self.get_query() # Defaults so we can reference them later even if params is None + req_fields: list[str] = _normalize_fields_param(params) root_fields: list[Any] = [] root_field_names: dict[str, str] = {} rel_field_names: dict[tuple[str, ...], list[str]] = {} - req_fields: list[str] = _normalize_fields_param(params) if params: + # Soft delete query = self._apply_not_deleted(query, root_alias, params) spec = CRUDSpec(self.model, params or {}, root_alias) @@ -587,84 +552,75 @@ class CRUDService(Generic[T]): order_by = spec.parse_sort() limit, offset = spec.parse_pagination() - # Includes + join paths we might need + # Includes / fields (populates join_paths) + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() spec.parse_includes() join_paths = tuple(spec.get_join_paths()) - # Field parsing for load_only on root columns - root_fields, rel_field_names, root_field_names = spec.parse_fields() - - if filters: - query = query.filter(*filters) - - # Determine which relationship paths are needed for SQL vs display-only - sql_hops = _paths_needed_for_sql(order_by, filters, join_paths) - sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist - proj_hops = _paths_from_fields(req_fields) - - # Root column projection + # Root column projection (load_only) only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] if only_cols: query = query.options(Load(root_alias).load_only(*only_cols)) - # Relationship handling per path - used_contains_eager = False - joined_names: set[str] = set() + nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - for _path, relationship_attr, target_alias in join_paths: - rel_attr = cast(InstrumentedAttribute, relationship_attr) - name = relationship_attr.key - if name in sql_hops: - # Needed for WHERE/ORDER BY: join + hydrate from the join + used_contains_eager = False + for _base_alias, rel_attr, target_alias in 
join_paths: + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + is_nested_firsthop = rel_attr.key in nested_first_hops + + if is_collection or is_nested_firsthop: + opt = selectinload(rel_attr) + if is_collection: + child_names = (collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) + query = query.options(opt) + else: query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.options(contains_eager(rel_attr, alias=target_alias)) used_contains_eager = True - joined_names.add(name) - elif name in proj_hops: - # Display-only: no join, bulk-load efficiently - query = query.options(selectinload(rel_attr)) - joined_names.add(name) - # Force-join any SQL-needed relationships that weren't in join_paths - missing_sql = sql_hops - joined_names - for name in missing_sql: - rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name)) - query = query.join(rel_attr, isouter=True) - query = query.options(contains_eager(rel_attr)) - used_contains_eager = True - joined_names.add(name) + # Filters AFTER joins → no cartesian products + if filters: + query = query.filter(*filters) - # MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) + # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers paginating = (limit is not None) or (offset is not None and offset != 0) if paginating and not order_by and self.backend.requires_order_by_for_offset: order_by = self._default_order_by(root_alias) + order_by = _dedupe_order_by(order_by) + if order_by: query = query.order_by(*order_by) - # Only apply offset/limit when not None and not zero + # Offset/limit if offset is not None and offset != 0: query = query.offset(offset) if limit is not None 
and limit > 0: query = query.limit(limit) - # Projection loader options compiled from requested fields. - # Skip if we used contains_eager to avoid loader-strategy conflicts. + # Projection loaders only if we didn’t use contains_eager expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) if proj_opts and not used_contains_eager: query = query.options(*proj_opts) else: - # No params means no filters/sorts/limits; still honor projection loaders if any + # No params; still honor projection loaders if any expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) if proj_opts: query = query.options(*proj_opts) rows = query.all() - # Emit exactly what the client requested (plus id), or a reasonable fallback + # Build projection meta for renderers if req_fields: - proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order + proj = list(dict.fromkeys(req_fields)) if "id" not in proj and hasattr(self.model, "id"): proj.insert(0, "id") else: @@ -692,7 +648,6 @@ class CRUDService(Generic[T]): return rows - def create(self, data: dict, actor=None) -> T: session = self.session obj = self.model(**data) diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 9c0e53b..d5c2480 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Set, Dict, Optional +from typing import List, Tuple, Set, Dict, Optional, Iterable from sqlalchemy import asc, desc from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased, selectinload @@ -20,10 +20,14 @@ class CRUDSpec: self.params = params self.root_alias = root_alias self.eager_paths: Set[Tuple[str, ...]] = set() + # (parent_alias. 
relationship_attr, alias_for_target) self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.alias_map: Dict[Tuple[str, ...], object] = {} self._root_fields: List[InstrumentedAttribute] = [] - self._rel_field_names: Dict[Tuple[str, ...], object] = {} + # dotted non-collection fields (MANYTOONE etc) + self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {} + # dotted collection fields (ONETOMANY) + self._collection_field_names: Dict[str, List[str]] = {} self.include_paths: Set[Tuple[str, ...]] = set() def _resolve_column(self, path: str): @@ -117,11 +121,12 @@ class CRUDSpec: Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD - Root fields become InstrumentedAttributes bound to root_alias. - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. - Returns (root_fields, rel_field_names). + - Collection (uselist=True) relationships record child names by relationship key. + Returns (root_fields, rel_field_names, root_field_names, collection_field_names_by_rel). 
""" raw = self.params.get('fields') if not raw: - return [], {}, {} + return [], {}, {}, {} if isinstance(raw, list): tokens = [] @@ -133,14 +138,36 @@ class CRUDSpec: root_fields: List[InstrumentedAttribute] = [] root_field_names: list[str] = [] rel_field_names: Dict[Tuple[str, ...], List[str]] = {} + collection_field_names: Dict[str, List[str]] = {} for token in tokens: col, join_path = self._resolve_column(token) if not col: continue if join_path: - rel_field_names.setdefault(join_path, []).append(col.key) - self.eager_paths.add(join_path) + # rel_field_names.setdefault(join_path, []).append(col.key) + # self.eager_paths.add(join_path) + try: + cur_cls = self.model + names = list(join_path) + last_name = names[-1] + for nm in names: + rel_attr = getattr(cur_cls, nm) + cur_cls = rel_attr.property.mapper.class_ + is_collection = bool(getattr(getattr(self.model, last_name), "property", None) and getattr(getattr(self.model, last_name).property, "uselist", False)) + except Exception: + # Fallback: inspect the InstrumentedAttribute we recorded on join_paths + is_collection = False + for _pa, rel_attr, _al in self.join_paths: + if rel_attr.key == (join_path[-1] if join_path else ""): + is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False)) + break + + if is_collection: + collection_field_names.setdefault(join_path[-1], []).append(col.key) + else: + rel_field_names.setdefault(join_path, []).append(col.key) + self.eager_paths.add(join_path) else: root_fields.append(col) root_field_names.append(getattr(col, "key", token)) @@ -153,7 +180,11 @@ class CRUDSpec: self._root_fields = root_fields self._rel_field_names = rel_field_names - return root_fields, rel_field_names, root_field_names + # return root_fields, rel_field_names, root_field_names + for r, names in collection_field_names.items(): + seen3 = set() + collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))] + return root_field_names, rel_field_names, 
root_field_names, collection_field_names def get_eager_loads(self, root_alias, *, fields_map=None): loads = [] diff --git a/crudkit/integrations/flask.py b/crudkit/integrations/flask.py index 433f6a5..feb262b 100644 --- a/crudkit/integrations/flask.py +++ b/crudkit/integrations/flask.py @@ -26,32 +26,6 @@ def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[ try: bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine pool = bound_engine.pool - - from sqlalchemy import event - - @event.listens_for(pool, "checkout") - def _on_checkout(dbapi_conn, conn_record, conn_proxy): - sz = pool.size() - chk = pool.checkedout() - try: - conns_in_pool = pool.checkedin() - except Exception: - conns_in_pool = "?" - print(f"POOL CHECKOUT: Pool size: {sz} Connections in pool: {conns_in_pool} " - f"Current Overflow: {pool.overflow()} Current Checked out connections: {chk} " - f"engine id= {id(bound_engine)}") - - @event.listens_for(pool, "checkin") - def _on_checkin(dbapi_conn, conn_record): - sz = pool.size() - chk = pool.checkedout() - try: - conns_in_pool = pool.checkedin() - except Exception: - conns_in_pool = "?" 
- print(f"POOL CHECKIN: Pool size: {sz} Connections in pool: {conns_in_pool} " - f"Current Overflow: {pool.overflow()} Current Checked out connections: {chk} " - f"engine id= {id(bound_engine)}") except Exception as e: print(f"[crudkit.init_app] Failed to attach pool listeners: {e}") diff --git a/crudkit/ui/fragments.py b/crudkit/ui/fragments.py index 6c62c45..e60bd1a 100644 --- a/crudkit/ui/fragments.py +++ b/crudkit/ui/fragments.py @@ -1153,15 +1153,15 @@ def render_form( field["wrap"] = _sanitize_attrs(field["wrap"]) fields.append(field) - if submit_attrs: + if submit_attrs: submit_attrs = _sanitize_attrs(submit_attrs) - common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session} - for f in fields: - if f.get("type") == "template": - base = dict(common_ctx) - base.update(f.get("template_ctx") or {}) - f["template_ctx"] = base + common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session} + for f in fields: + if f.get("type") == "template": + base = dict(common_ctx) + base.update(f.get("template_ctx") or {}) + f["template_ctx"] = base for f in fields: # existing FK label resolution