diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index 3b061e2..0d6babe 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -1,3 +1,5 @@ +# crudkit/api/flask_api.py + from __future__ import annotations from flask import Blueprint, jsonify, request, abort @@ -19,85 +21,131 @@ def _safe_int(value: str | None, default: int) -> int: def _link_with_params(base_url: str, **params) -> str: - # Filter out None, encode safely q = {k: v for k, v in params.items() if v is not None} return f"{base_url}?{urlencode(q)}" -def generate_crud_blueprint(model, service): - bp = Blueprint(model.__name__.lower(), __name__) +def generate_crud_blueprint(model, service, *, base_prefix: str | None = None): + """ + RPC-ish blueprint that exposes CRUDService methods 1:1: - @bp.get("/") - def list_items(): - # Work from a copy so we don't mutate request.args - args = request.args.to_dict(flat=True) + GET /api//get?id=123&... -> service.get() + GET /api//list?... -> service.list() + GET /api//seek_window?... -> service.seek_window() + GET /api//page?page=2&per_page=50&... -> service.page() - legacy_offset = "offset" in args or "page" in args + POST /api//create -> service.create(payload) - limit = _safe_int(args.get("limit"), 50) - args["limit"] = limit + PATCH /api//update?id=123 -> service.update(id, payload) - if legacy_offset: - # Old behavior: honor limit/offset, same CRUDSpec goodies - items = service.list(args) - return jsonify([obj.as_dict() for obj in items]) + DELETE /api//delete?id=123[&hard=1] -> service.delete(id, hard) - # New behavior: keyset pagination with cursors - cursor_token = args.get("cursor") - key, desc_from_cursor, backward = decode_cursor(cursor_token) + Query params for filters/sorts/fields/includes all still pass straight through. + Cursor behavior for seek_window is preserved, with Link headers. + """ + name = (model.__name__ if base_prefix is None else base_prefix).lower() + bp = Blueprint(name, __name__, url_prefix=f"/api/{name}") - window = service.seek_window( - args, - key=key, - backward=backward, - include_total=_bool_param(args, "include_total", True), - ) + # -------- READS -------- - # Prefer the order actually used by the window; fall back to desc_from_cursor if needed. + @bp.get("/get") + def rpc_get(): + id_ = _safe_int(request.args.get("id"), 0) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 try: - desc_flags = list(window.order.desc) - except Exception: - desc_flags = desc_from_cursor or [] - - body = { - "items": [obj.as_dict() for obj in window.items], - "limit": window.limit, - "next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), - "prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), - "total": window.total, - } - - resp = jsonify(body) - - # Preserve user’s other query params like include_total, filters, sorts, etc. - base_url = request.base_url - base_params = {k: v for k, v in args.items() if k not in {"cursor"}} - link_parts = [] - if body["next_cursor"]: - link_parts.append( - f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"' - ) - if body["prev_cursor"]: - link_parts.append( - f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"' - ) - if link_parts: - resp.headers["Link"] = ", ".join(link_parts) - return resp - - @bp.get("/") - def get_item(id): - try: - item = service.get(id, request.args) + item = service.get(id_, request.args) if item is None: abort(404) return jsonify(item.as_dict()) except Exception as e: - # Could be validation, auth, or just you forgetting an index again return jsonify({"status": "error", "error": str(e)}), 400 - @bp.post("/") - def create_item(): + @bp.get("/list") + def rpc_list(): + # Keep legacy limit/offset behavior. Everything else passes through. + args = request.args.to_dict(flat=True) + # If the caller provides offset or page, honor normal list() pagination rules. + legacy_offset = ("offset" in args) or ("page" in args) + if not legacy_offset: + # We still allow limit to cap the result set if provided. + limit = _safe_int(args.get("limit"), 50) + args["limit"] = limit + try: + items = service.list(args) + return jsonify([obj.as_dict() for obj in items]) + except Exception as e: + return jsonify({"status": "error", "error": str(e)}), 400 + + @bp.get("/seek_window") + def rpc_seek_window(): + args = request.args.to_dict(flat=True) + + # Keep keyset & cursor mechanics intact + cursor_token = args.get("cursor") + key, desc_from_cursor, backward_from_cursor = decode_cursor(cursor_token) + + backward = _bool_param(args, "backward", backward_from_cursor if backward_from_cursor is not None else False) + include_total = _bool_param(args, "include_total", True) + + try: + window = service.seek_window( + args, + key=key, + backward=backward, + include_total=include_total, + ) + try: + desc_flags = list(window.order.desc) + except Exception: + desc_flags = desc_from_cursor or [] + + body = { + "items": [obj.as_dict() for obj in window.items], + "limit": window.limit, + "next_cursor": encode_cursor(window.last_key, desc_flags, backward=False), + "prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True), + "total": window.total, + } + resp = jsonify(body) + + # Build Link headers preserving all non-cursor args + base_url = request.base_url + base_params = {k: v for k, v in args.items() if k not in {"cursor"}} + link_parts = [] + if body["next_cursor"]: + link_parts.append(f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"') + if body["prev_cursor"]: + link_parts.append(f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"') + if link_parts: + resp.headers["Link"] = ", ".join(link_parts) + return resp + except Exception as e: + return jsonify({"status": "error", "error": str(e)}), 400 + + @bp.get("/page") + def rpc_page(): + args = request.args.to_dict(flat=True) + page = _safe_int(args.get("page"), 1) + per_page = _safe_int(args.get("per_page"), 50) + include_total = _bool_param(args, "include_total", True) + + try: + result = service.page(args, page=page, per_page=per_page, include_total=include_total) + # Already includes: items, page, per_page, total, pages, order + # Items come back as model instances; serialize to dicts + result = { + **result, + "items": [obj.as_dict() for obj in result["items"]], + } + return jsonify(result) + except Exception as e: + return jsonify({"status": "error", "error": str(e)}), 400 + + # -------- WRITES -------- + + @bp.post("/create") + def rpc_create(): payload = request.get_json(silent=True) or {} try: obj = service.create(payload) @@ -105,22 +153,29 @@ def generate_crud_blueprint(model, service): except Exception as e: return jsonify({"status": "error", "error": str(e)}), 400 - @bp.patch("/") - def update_item(id): + @bp.patch("/update") + def rpc_update(): + id_ = _safe_int(request.args.get("id"), 0) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 payload = request.get_json(silent=True) or {} try: - obj = service.update(id, payload) + obj = service.update(id_, payload) return jsonify(obj.as_dict()) except Exception as e: - # 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional. + # If you ever decide to throw custom exceptions, map them here like an adult. return jsonify({"status": "error", "error": str(e)}), 400 - @bp.delete("/") - def delete_item(id): + @bp.delete("/delete") + def rpc_delete(): + id_ = _safe_int(request.args.get("id"), 0) + if not id_: + return jsonify({"status": "error", "error": "missing required param: id"}), 400 + hard = _bool_param(request.args, "hard", False) try: - service.delete(id) - # 204 means "no content" so don't send any. - return ("", 204) + obj = service.delete(id_, hard=hard) + # 204 if actually deleted or soft-deleted; return body if you feel chatty + return ("", 204) if obj is not None else abort(404) except Exception as e: return jsonify({"status": "error", "error": str(e)}), 400 diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 60e3aef..db510cd 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,11 +1,12 @@ from __future__ import annotations from collections.abc import Iterable +from dataclasses import dataclass from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import operators from sqlalchemy.sql.elements import UnaryExpression, ColumnElement @@ -37,41 +38,86 @@ class _SoftDeletable(Protocol): is_deleted: bool class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol): - """Minimal surface that our CRUD service relies on. Soft-delete is optional.""" + """Surface expected by CRUDService.""" pass T = TypeVar("T", bound=_CRUDModelProto) +# ---------------------------- utilities ---------------------------- + +def _collect_tables_from_filters(filters) -> set: + seen = set() + stack = list(filters or []) + while stack: + node = stack.pop() + + tbl = getattr(node, "table", None) + if tbl is not None: + cur = tbl + while cur is not None and cur not in seen: + seen.add(cur) + cur = getattr(cur, "element", None) + + # follow only the common attributes; no generic visitor + left = getattr(node, "left", None) + if left is not None: + stack.append(left) + right = getattr(node, "right", None) + if right is not None: + stack.append(right) + elem = getattr(node, "element", None) + if elem is not None: + stack.append(elem) + clause = getattr(node, "clause", None) + if clause is not None: + stack.append(clause) + clauses = getattr(node, "clauses", None) + if clauses is not None: + try: + stack.extend(list(clauses)) + except TypeError: + pass + + return seen + +def _selectable_keys(sel) -> set[str]: + """ + Return a set of stable string keys for a selectable/alias and its base, + so we can match when when different alias objects are used. + """ + keys: set[str] = set() + cur = sel + while cur is not None: + k = getattr(cur, "key", None) or getattr(cur, "name", None) + if isinstance(k, str) and k: + keys.add(k) + cur = getattr(cur, "element", None) + return keys + def _unwrap_ob(ob): - """Return (col, is_desc) from an ORDER BY element (handles .asc()/.desc()).""" - col = getattr(ob, "element", None) - if col is None: - col = ob - is_desc = False - dir_attr = getattr(ob, "_direction", None) - if dir_attr is not None: - is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob + + d = getattr(ob, "_direction", None) + if d is not None: + is_desc = (d is operators.desc_op) or (getattr(d, "name", "").upper() == "DESC") elif isinstance(ob, UnaryExpression): op = getattr(ob, "operator", None) is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + else: + is_desc = False return col, bool(is_desc) def _order_identity(col: ColumnElement): - """ - Build a stable identity for a column suitable for deduping. - We ignore direction here. Duplicates are duplicates regardless of ASC/DESC. - """ table = getattr(col, "table", None) table_key = getattr(table, "key", None) or id(table) col_key = getattr(col, "key", None) or getattr(col, "name", None) return (table_key, col_key) def _dedupe_order_by(order_by): - """Remove duplicate ORDER BY entries (by column identity, ignoring direction).""" if not order_by: return [] - seen = set() - out = [] + seen, out = set(), [] for ob in order_by: col, _ = _unwrap_ob(ob) ident = _order_identity(col) @@ -98,6 +144,8 @@ def _normalize_fields_param(params: dict | None) -> list[str]: return [p for p in (s.strip() for s in raw.split(",")) if p] return [] +# ---------------------------- CRUD service ---------------------------- + class CRUDService(Generic[T]): def __init__( self, @@ -111,21 +159,19 @@ class CRUDService(Generic[T]): self._session_factory = session_factory self.polymorphic = polymorphic self.supports_soft_delete = hasattr(model, 'is_deleted') - self._backend: Optional[BackendInfo] = backend + # ---- infra + @property def session(self) -> Session: - """Always return the Flask-scoped Session if available; otherwise the provided factory.""" try: - sess = current_app.extensions["crudkit"]["Session"] - return sess + return current_app.extensions["crudkit"]["Session"] except Exception: return self._session_factory() @property def backend(self) -> BackendInfo: - """Resolve backend info lazily against the active session's engine.""" if self._backend is None: bind = self.session.get_bind() eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) @@ -138,265 +184,94 @@ class CRUDService(Generic[T]): return self.session.query(poly), poly return self.session.query(self.model), self.model - def _debug_bind(self, where: str): - try: - bind = self.session.get_bind() - eng = getattr(bind, "engine", bind) - print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}") - except Exception as e: - print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}") + # ---- common building blocks - def _apply_not_deleted(self, query, root_alias, params) -> Any: + def _apply_soft_delete_criteria_for_children(self, query, plan: "CRUDService._Plan", params): + # Skip if caller explicitly asked for deleted + if _is_truthy((params or {}).get("include_deleted")): + return query + + seen = set() + for _base_alias, rel_attr, _target_alias in plan.join_paths: + prop = getattr(rel_attr, "property", None) + if not prop: + continue + target_cls = getattr(prop.mapper, "class_", None) + if not target_cls or target_cls in seen: + continue + seen.add(target_cls) + # Only apply to models that support soft delete + if hasattr(target_cls, "is_deleted"): + query = query.options( + with_loader_criteria( + target_cls, + lambda cls: cls.is_deleted == False, + include_aliases=True + ) + ) + return query + + def _order_clauses(self, order_spec, invert: bool = False): + clauses = [] + for c, is_desc in zip(order_spec.cols, order_spec.desc): + d = not is_desc if invert else is_desc + clauses.append(c.desc() if d else c.asc()) + return clauses + + def _anchor_key_for_page(self, params, per_page: int, page: int): + """Return the keyset tuple for the last row of the previous page, or None for page 1.""" + if page <= 1: + return None + + query, root_alias = self.get_query() + query = self._apply_not_deleted(query, root_alias, params) + + plan = self._plan(params, root_alias) + # Make sure joins/filters match the real query + query = self._apply_firsthop_strategies(query, root_alias, plan) + if plan.filters: + query = query.filter(*plan.filters) + + order_spec = self._extract_order_spec(root_alias, plan.order_by) + + # Inner subquery must be ordered exactly like the real query + inner = query.order_by(*self._order_clauses(order_spec, invert=False)) + + # IMPORTANT: Build subquery that actually exposes the order-by columns + # under predictable names, then select FROM that and reference subq.c[...] + subq = inner.with_entities(*order_spec.cols).subquery() + + # Map the order columns to the subquery columns by key/name + cols_on_subq = [] + for col in order_spec.cols: + key = getattr(col, "key", None) or getattr(col, "name", None) + if not key: + # Fallback, but frankly your order cols should have names + raise ValueError("Order-by column is missing a key/name") + cols_on_subq.append(getattr(subq.c, key)) + + # Now the outer anchor query orders and offsets on the subquery columns + anchor_q = ( + self.session + .query(*cols_on_subq) + .select_from(subq) + .order_by(*[ + (c.desc() if is_desc else c.asc()) + for c, is_desc in zip(cols_on_subq, order_spec.desc) + ]) + ) + + offset = max(0, (page - 1) * per_page - 1) + row = anchor_q.offset(offset).limit(1).first() + if not row: + return None + return list(row) # tuple-like -> list for _key_predicate + + def _apply_not_deleted(self, query, root_alias, params): if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): return query.filter(getattr(root_alias, "is_deleted") == False) return query - def _extract_order_spec(self, root_alias, given_order_by): - """ - SQLAlchemy 2.x only: - Normalize order_by into (cols, desc_flags). Supports plain columns and - col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. - """ - - given = self._stable_order_by(root_alias, given_order_by) - - cols, desc_flags = [], [] - for ob in given: - # Unwrap column if this is a UnaryExpression produced by .asc()/.desc() - elem = getattr(ob, "element", None) - col = elem if elem is not None else ob - - # Detect direction in SA 2.x - is_desc = False - dir_attr = getattr(ob, "_direction", None) - if dir_attr is not None: - is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") - elif isinstance(ob, UnaryExpression): - op = getattr(ob, "operator", None) - is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") - - cols.append(col) - desc_flags.append(bool(is_desc)) - - return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) - - def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): - if not key_vals: - return None - - conds = [] - for i, col in enumerate(spec.cols): - # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel: - # sent_col = func.coalesce(col, literal("-∞")) - sent_col = col - ties = [spec.cols[j] == key_vals[j] for j in range(i)] - is_desc = spec.desc[i] - if not backward: - op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) - else: - op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) - conds.append(and_(*ties, op)) - return or_(*conds) - - def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: - out = [] - for c in spec.cols: - # Only simple mapped columns supported for key pluck - key = getattr(c, "key", None) or getattr(c, "name", None) - if key is None or not hasattr(obj, key): - raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.") - out.append(getattr(obj, key)) - return out - - def seek_window( - self, - params: dict | None = None, - *, - key: list[Any] | None = None, - backward: bool = False, - include_total: bool = True, - ) -> "SeekWindow[T]": - """ - Keyset pagination with relationship-safe filtering/sorting. - Always JOIN all CRUDSpec-discovered paths first; then apply filters, sort, seek. - """ - session = self.session - query, root_alias = self.get_query() - - # Requested fields → projection + optional loaders - fields = _normalize_fields_param(params) - expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], []) - - spec = CRUDSpec(self.model, params or {}, root_alias) - - # Parse all inputs so join_paths are populated - filters = spec.parse_filters() - order_by = spec.parse_sort() - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() - spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) - - # Soft delete - query = self._apply_not_deleted(query, root_alias, params) - - # Root column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) - - # Detect first hops that have deeper, nested tails requested (e.g. "contact.supervisor") - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - - # IMPORTANT: - # - Only attach loader options for first-hop relations from the root. - # - Always use selectinload here (avoid contains_eager joins). - # - Let compile_projections() supply deep chained options. - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: - # Deeper hops are handled by proj_opts below - continue - prop = getattr(rel_attr, "property", None) - is_collection = bool(getattr(prop, "uselist", False)) - is_nested_firsthop = rel_attr.key in nested_first_hops - - opt = selectinload(rel_attr) - # Optional narrowng for collections - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) - - # Filters AFTER joins → no cartesian products - if filters: - query = query.filter(*filters) - - # Order spec (with PK tie-breakers for stability) - order_spec = self._extract_order_spec(root_alias, order_by) - limit, _ = spec.parse_pagination() - if limit is None: - effective_limit = 50 - elif limit == 0: - effective_limit = None # unlimited - else: - effective_limit = limit - - # Seek predicate from cursor key (if any) - if key: - pred = self._key_predicate(order_spec, key, backward) - if pred is not None: - query = query.filter(pred) - - # Apply ORDER and LIMIT. Backward is SQL-inverted + reverse in-memory. - if not backward: - clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] - query = query.order_by(*clauses) - if effective_limit is not None: - query = query.limit(effective_limit) - items = query.all() - else: - inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] - query = query.order_by(*inv_clauses) - if effective_limit is not None: - query = query.limit(effective_limit) - items = list(reversed(query.all())) - - # Projection meta tag for renderers - if fields: - proj = list(dict.fromkeys(fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - - if proj: - for obj in items: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass - - # Cursor key pluck: support related columns we hydrated via contains_eager - def _pluck_key_from_obj(obj: Any) -> list[Any]: - vals: list[Any] = [] - alias_to_rel: dict[Any, str] = {} - for _p, relationship_attr, target_alias in join_paths: - sel = getattr(target_alias, "selectable", None) - if sel is not None: - alias_to_rel[sel] = relationship_attr.key - - for col in order_spec.cols: - keyname = getattr(col, "key", None) or getattr(col, "name", None) - if keyname and hasattr(obj, keyname): - vals.append(getattr(obj, keyname)) - continue - table = getattr(col, "table", None) - relname = alias_to_rel.get(table) - if relname and keyname: - relobj = getattr(obj, relname, None) - if relobj is not None and hasattr(relobj, keyname): - vals.append(getattr(relobj, keyname)) - continue - raise ValueError("unpluckable") - return vals - - try: - first_key = _pluck_key_from_obj(items[0]) if items else None - last_key = _pluck_key_from_obj(items[-1]) if items else None - except Exception: - first_key = None - last_key = None - - # Count DISTINCT ids with mirrored joins - - # Apply deep projection loader options (safe: we avoided contains_eager) - if proj_opts: - query = query.options(*proj_opts) - total = None - if include_total: - base = session.query(getattr(root_alias, "id")) - base = self._apply_not_deleted(base, root_alias, params) - # same joins as above for correctness - for base_alias, rel_attr, target_alias in join_paths: - # do not join collections for COUNT mirror - if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): - base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) - if filters: - base = base.filter(*filters) - total = session.query(func.count()).select_from( - base.order_by(None).distinct().subquery() - ).scalar() or 0 - - window_limit_for_body = 0 if effective_limit is None and (limit == 0) else (effective_limit or 50) - - if log.isEnabledFor(logging.DEBUG): - log.debug("QUERY: %s", str(query)) - - return SeekWindow( - items=items, - limit=window_limit_for_body, - first_key=first_key, - last_key=last_key, - order=order_spec, - total=total, - ) - - # Helper: default ORDER BY for MSSQL when paginating without explicit order def _default_order_by(self, root_alias): mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) cols = [] @@ -408,176 +283,112 @@ class CRUDService(Generic[T]): return cols or [text("1")] def _stable_order_by(self, root_alias, given_order_by): - """ - Ensure deterministic ordering by appending PK columns as tiebreakers. - If no order is provided, fall back to default primary-key order. - Never duplicates columns in ORDER BY (SQL Server requires uniqueness). - """ order_by = list(given_order_by or []) if not order_by: return _dedupe_order_by(self._default_order_by(root_alias)) - order_by = _dedupe_order_by(order_by) - mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) present = {_order_identity(_unwrap_ob(ob)[0]) for ob in order_by} - for pk in mapper.primary_key: try: pk_col = getattr(root_alias, pk.key) except AttributeError: pk_col = pk - if _order_identity(pk_col) not in present: + ident = _order_identity(pk_col) + if ident not in present: order_by.append(pk_col.asc()) - present.add(_order_identity(pk_col)) - + present.add(ident) return order_by - def get(self, id: int, params=None) -> T | None: - """ - Fetch a single row by id with conflict-free eager loading and clean projection. - Always JOIN any paths that CRUDSpec resolved for filters/fields/includes so - related-column filters never create cartesian products. - """ - query, root_alias = self.get_query() + def _extract_order_spec(self, root_alias, given_order_by): + given = self._stable_order_by(root_alias, given_order_by) + cols, desc_flags = [], [] + for ob in given: + elem = getattr(ob, "element", None) + col = elem if elem is not None else ob + is_desc = False + dir_attr = getattr(ob, "_direction", None) + if dir_attr is not None: + is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC") + elif isinstance(ob, UnaryExpression): + op = getattr(ob, "operator", None) + is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC") + cols.append(col) + desc_flags.append(bool(is_desc)) + return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) - # Defaults so we can build a projection even if params is None - req_fields: list[str] = _normalize_fields_param(params) - root_fields: list[Any] = [] - root_field_names: dict[str, str] = {} - rel_field_names: dict[tuple[str, ...], list[str]] = {} + def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): + if not key_vals: + return None + conds = [] + for i, col in enumerate(spec.cols): + ties = [spec.cols[j] == key_vals[j] for j in range(i)] + is_desc = spec.desc[i] + op = (col < key_vals[i]) if is_desc ^ backward else (col > key_vals[i]) + conds.append(and_(*ties, op)) + return or_(*conds) - # Soft-delete guard first - query = self._apply_not_deleted(query, root_alias, params) + # ---- planning and application + @dataclass(slots=True) + class _Plan: + spec: Any + filters: Any + order_by: Any + limit: Any + offset: Any + root_fields: Any + rel_field_names: Any + root_field_names: Any + collection_field_names: Any + join_paths: Any + filter_tables: Any + filter_table_keys: Any + req_fields: Any + proj_opts: Any + + def _plan(self, params, root_alias) -> _Plan: + req_fields = _normalize_fields_param(params) spec = CRUDSpec(self.model, params or {}, root_alias) - # Parse everything so CRUDSpec records any join paths it needed to resolve filters = spec.parse_filters() - # no ORDER BY for get() - if params: - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() + order_by = spec.parse_sort() + limit, offset = spec.parse_pagination() + root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() if params else ([], {}, {}, {}) spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) + filter_tables = _collect_tables_from_filters(filters) + _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - # Root-column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) + filter_tables = () + fkeys = set() - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } + return self._Plan( + spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset, + root_fields=root_fields, rel_field_names=rel_field_names, + root_field_names=root_field_names, collection_field_names=collection_field_names, + join_paths=join_paths, filter_tables=filter_tables, filter_table_keys=fkeys, + req_fields=req_fields, proj_opts=proj_opts + ) - # First-hop only; use selectinload (no contains_eager) - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: + def _apply_projection_load_only(self, query, root_alias, plan: _Plan): + only_cols = [c for c in plan.root_fields if isinstance(c, InstrumentedAttribute)] + return query.options(Load(root_alias).load_only(*only_cols)) if only_cols else query + + def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan): + nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 } + for base_alias, rel_attr, target_alias in plan.join_paths: + if base_alias is not root_alias: continue prop = getattr(rel_attr, "property", None) is_collection = bool(getattr(prop, "uselist", False)) - _is_nested_firsthop = rel_attr.key in nested_first_hops - - opt = selectinload(rel_attr) - if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) - query = query.options(opt) - - # Apply filters (joins are in place → no cartesian products) - if filters: - query = query.filter(*filters) - - # And the id filter - query = query.filter(getattr(root_alias, "id") == id) - - # Projection loader options compiled from requested fields. - # Skip if we used contains_eager to avoid loader-strategy conflicts. - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) - - obj = query.first() - - # Emit exactly what the client requested (plus id), or a reasonable fallback - if req_fields: - proj = list(dict.fromkeys(req_fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - - if proj and obj is not None: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass - - if log.isEnabledFor(logging.DEBUG): - log.debug("QUERY: %s", str(query)) - - return obj or None - - def list(self, params=None) -> list[T]: - """ - Offset/limit listing with relationship-safe filtering. - We always JOIN every CRUDSpec-discovered path before applying filters/sorts. - """ - query, root_alias = self.get_query() - - # Defaults so we can reference them later even if params is None - req_fields: list[str] = _normalize_fields_param(params) - root_fields: list[Any] = [] - root_field_names: dict[str, str] = {} - rel_field_names: dict[tuple[str, ...], list[str]] = {} - - if params: - # Soft delete - query = self._apply_not_deleted(query, root_alias, params) - - spec = CRUDSpec(self.model, params or {}, root_alias) - filters = spec.parse_filters() - order_by = spec.parse_sort() - limit, offset = spec.parse_pagination() - - # Includes / fields (populates join_paths) - root_fields, rel_field_names, root_field_names, collection_field_names = spec.parse_fields() - spec.parse_includes() - join_paths = tuple(spec.get_join_paths()) - - # Root column projection (load_only) - only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] - if only_cols: - query = query.options(Load(root_alias).load_only(*only_cols)) - - nested_first_hops = { path[0] for path in (rel_field_names or {}).keys() if len(path) > 1 } - - # First-hop only; use selectinload - for base_alias, rel_attr, target_alias in join_paths: - is_firsthop_from_root = (base_alias is root_alias) - if not is_firsthop_from_root: - continue - prop = getattr(rel_attr, "property", None) - is_collection = bool(getattr(prop, "uselist", False)) - _is_nested_firsthop = rel_attr.key in nested_first_hops + if not is_collection: + query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + else: opt = selectinload(rel_attr) if is_collection: - child_names = (collection_field_names or {}).get(rel_attr.key, []) + child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) if child_names: target_cls = prop.mapper.class_ cols = [getattr(target_cls, n, None) for n in child_names] @@ -585,79 +396,223 @@ class CRUDService(Generic[T]): if cols: opt = opt.load_only(*cols) query = query.options(opt) + return query - # Filters AFTER joins → no cartesian products - if filters: - query = query.filter(*filters) + def _apply_proj_opts(self, query, plan: _Plan): + return query.options(*plan.proj_opts) if plan.proj_opts else query - # MSSQL requires ORDER BY when OFFSET is used; ensure stable PK tie-breakers - paginating = (limit is not None) or (offset is not None and offset != 0) - if paginating and not order_by and self.backend.requires_order_by_for_offset: - order_by = self._default_order_by(root_alias) + def _projection_meta(self, plan: _Plan): + if plan.req_fields: + proj = list(dict.fromkeys(plan.req_fields)) + return ["id"] + proj if "id" not in proj and hasattr(self.model, "id") else proj - order_by = _dedupe_order_by(order_by) + proj: list[str] = [] + if plan.root_field_names: + proj.extend(plan.root_field_names) + if plan.root_fields: + proj.extend(c.key for c in plan.root_fields if hasattr(c, "key")) + for path, names in (plan.rel_field_names or {}).items(): + prefix = ".".join(path) + for n in names: + proj.append(f"{prefix}.{n}") + if proj and "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") + return proj - if order_by: - query = query.order_by(*order_by) + def _tag_projection(self, items, proj): + if not proj: + return + for obj in items if isinstance(items, list) else [items]: + try: + setattr(obj, "__crudkit_projection__", tuple(proj)) + except Exception: + pass - # Offset/limit - if offset is not None and offset != 0: - query = query.offset(offset) - if limit is not None and limit > 0: - query = query.limit(limit) + # ---- public read ops - # Projection loaders only if we didn’t use contains_eager - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) + def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True): + # Ensure seek_window uses `per_page` + params = dict(params or {}) + params["limit"] = per_page - else: - # No params; still honor projection loaders if any - expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - if proj_opts: - query = query.options(*proj_opts) + anchor_key = self._anchor_key_for_page(params, per_page, page) + win = self.seek_window(params, key=anchor_key, backward=False, include_total=include_total) + pages = None + if include_total and win.total is not None and per_page: + # class ceil(total / per_page) // per_page + pages = (win.total + per_page - 1) // per_page + + return { + "items": win.items, + "page": page, + "per_page": per_page, + "total": win.total, + "pages": pages, + "order": [str(c) for c in win.order.cols], + } + + def seek_window( + self, + params: dict | None = None, + *, + key: list[Any] | None = None, + backward: bool = False, + include_total: bool = True, + ) -> "SeekWindow[T]": + session = self.session + query, root_alias = self.get_query() + query = self._apply_not_deleted(query, root_alias, params) + + plan = self._plan(params, root_alias) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + query = self._apply_soft_delete_criteria_for_children(query, plan, params) + if plan.filters: + query = query.filter(*plan.filters) + + order_spec = self._extract_order_spec(root_alias, plan.order_by) + limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit) + + if key: + pred = self._key_predicate(order_spec, key, backward) + if pred is not None: + query = query.filter(pred) + + clauses = [(c.desc() if d else c.asc()) for c, d in zip(order_spec.cols, order_spec.desc)] + if backward: + clauses = [(c.asc() if d else c.desc()) for c, d in zip(order_spec.cols, order_spec.desc)] + query = query.order_by(*clauses) + if limit is not None: + query = query.limit(limit) + + query = self._apply_proj_opts(query, plan) rows = query.all() + items = list(reversed(rows)) if backward else rows - # Build projection meta for renderers - if req_fields: - proj = list(dict.fromkeys(req_fields)) - if "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") - else: - proj = [] - if root_field_names: - proj.extend(root_field_names) - if root_fields: - proj.extend(c.key for c in root_fields if hasattr(c, "key")) - for path, names in (rel_field_names or {}).items(): - prefix = ".".join(path) - for n in names: - proj.append(f"{prefix}.{n}") - if proj and "id" not in proj and hasattr(self.model, "id"): - proj.insert(0, "id") + proj = self._projection_meta(plan) + self._tag_projection(items, proj) - if proj: - for obj in rows: - try: - setattr(obj, "__crudkit_projection__", tuple(proj)) - except Exception: - pass + # cursor keys + def pluck(obj): + vals = [] + alias_to_rel = {} + for _p, rel_attr, target_alias in plan.join_paths: + sel = getattr(target_alias, "selectable", None) + if sel is not None: + alias_to_rel[sel] = rel_attr.key + + for col in order_spec.cols: + keyname = getattr(col, "key", None) or getattr(col, "name", None) + if keyname and hasattr(obj, keyname): + vals.append(getattr(obj, keyname)); continue + table = getattr(col, "table", None) + relname = alias_to_rel.get(table) + if relname and keyname: + relobj = getattr(obj, relname, None) + if relobj is not None and hasattr(relobj, keyname): + vals.append(getattr(relobj, keyname)); continue + raise ValueError("unpluckable") + return vals + + try: + first_key = pluck(items[0]) if items else None + last_key = pluck(items[-1]) if items else None + except Exception: + first_key = last_key = None + + total = None + if include_total: + base = session.query(getattr(root_alias, "id")) + base = self._apply_not_deleted(base, root_alias, params) + for _b, rel_attr, target_alias in plan.join_paths: + if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): + base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + if plan.filters: + base = base.filter(*plan.filters) + total = session.query(func.count()).select_from( + base.order_by(None).distinct().subquery() + ).scalar() or 0 if log.isEnabledFor(logging.DEBUG): log.debug("QUERY: %s", str(query)) + window_limit_for_body = 0 if limit is None and (plan.limit == 0) else (limit or 50) + return SeekWindow( + items=items, + limit=window_limit_for_body, + first_key=first_key, + last_key=last_key, + order=order_spec, + total=total, + ) + + def get(self, id: int, params=None) -> T | None: + query, root_alias = self.get_query() + query = self._apply_not_deleted(query, root_alias, params) + + plan = self._plan(params, root_alias) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + query = self._apply_soft_delete_criteria_for_children(query, plan, params) + if plan.filters: + query = query.filter(*plan.filters) + query = query.filter(getattr(root_alias, "id") == id) + query = self._apply_proj_opts(query, plan) + + obj = query.first() + proj = self._projection_meta(plan) + if obj: + self._tag_projection(obj, proj) + + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) + return obj or None + + def list(self, params=None) -> list[T]: + query, root_alias = self.get_query() + plan = self._plan(params, root_alias) + query = self._apply_not_deleted(query, root_alias, params) + query = self._apply_projection_load_only(query, root_alias, plan) + query = self._apply_firsthop_strategies(query, root_alias, plan) + query = self._apply_soft_delete_criteria_for_children(query, plan, params) + if plan.filters: + query = query.filter(*plan.filters) + + order_by = plan.order_by + paginating = (plan.limit is not None) or (plan.offset not in (None, 0)) + if paginating and not order_by and self.backend.requires_order_by_for_offset: + order_by = self._default_order_by(root_alias) + order_by = _dedupe_order_by(order_by) + if order_by: + query = query.order_by(*order_by) + + default_cap = getattr(current_app.config, "CRUDKIT_DEFAULT_LIST_LIMIT", 200) + if plan.offset: + query = query.offset(plan.offset) + if plan.limit and plan.limit > 0: + query = query.limit(plan.limit) + elif plan.limit is None and default_cap: + query = query.limit(default_cap) + + query = self._apply_proj_opts(query, plan) + rows = query.all() + + proj = self._projection_meta(plan) + self._tag_projection(rows, proj) + + if log.isEnabledFor(logging.DEBUG): + log.debug("QUERY: %s", str(query)) return rows + # ---- write ops + def create(self, data: dict, actor=None, *, commit: bool = True) -> T: session = self.session obj = self.model(**data) session.add(obj) - session.flush() - self._log_version("create", obj, actor, commit=commit) - if commit: session.commit() return obj @@ -669,59 +624,34 @@ class CRUDService(Generic[T]): raise ValueError(f"{self.model.__name__} with ID {id} not found.") before = obj.as_dict() - - # Normalize and restrict payload to real columns norm = normalize_payload(data, self.model) incoming = filter_to_columns(norm, self.model) - - # Build a synthetic "desired" state for top-level columns desired = {**before, **incoming} - # Compute intended change set (before vs intended) - proposed = deep_diff( - before, desired, - ignore_keys={"id", "created_at", "updated_at"}, - list_mode="index", - ) + proposed = deep_diff(before, desired, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index") patch = diff_to_patch(proposed) - - # Nothing to do if not patch: return obj - # Apply only what actually changes for k, v in patch.items(): setattr(obj, k, v) - # Optional: skip commit if ORM says no real change (paranoid check) - # Note: is_modified can lie if attrs are expired; use history for certainty. dirty = any(inspect(obj).attrs[k].history.has_changes() for k in patch.keys()) if not dirty: return obj - # Commit atomically if commit: session.commit() - # AFTER snapshot for audit after = obj.as_dict() - - # Actual diff (captures triggers/defaults, still ignoring noisy keys) - actual = deep_diff( - before, after, - ignore_keys={"id", "created_at", "updated_at"}, - list_mode="index", - ) - - # If truly nothing changed post-commit (rare), skip version spam + actual = deep_diff(before, after, ignore_keys={"id", "created_at", "updated_at"}, list_mode="index") if not (actual["added"] or actual["removed"] or actual["changed"]): return obj - # Log both what we *intended* and what *actually* happened self._log_version("update", obj, actor, metadata={"diff": actual, "patch": patch}, commit=commit) return obj - def delete(self, id: int, hard: bool = False, actor = None, *, commit: bool = True): + def delete(self, id: int, hard: bool = False, actor=None, *, commit: bool = True): session = self.session obj = session.get(self.model, id) if not obj: @@ -729,22 +659,21 @@ class CRUDService(Generic[T]): if hard or not self.supports_soft_delete: session.delete(obj) else: - soft = cast(_SoftDeletable, obj) - soft.is_deleted = True + cast(_SoftDeletable, obj).is_deleted = True if commit: session.commit() self._log_version("delete", obj, actor, commit=commit) return obj + # ---- audit + def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None, *, commit: bool = True): session = self.session try: - snapshot = {} try: snapshot = obj.as_dict() except Exception: snapshot = {"error": "serialize failed"} - version = Version( model_name=self.model.__name__, object_id=obj.id,