Compare commits

..

No commits in common. "2a9fb389d7f67e5f7adff6369f924f2c8a5f6c3e" and "515eb27fe07de9711390df3b679d01403451ebf4" have entirely different histories.

4 changed files with 370 additions and 550 deletions

View file

@ -1,135 +1,21 @@
# crudkit/api/_cursor.py import base64, json
from typing import Any
from __future__ import annotations def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
import base64
import dataclasses
import datetime as _dt
import decimal as _dec
import hmac
import json
import typing as _t
from hashlib import sha256
Any = _t.Any
@dataclasses.dataclass(frozen=True)
class Cursor:
values: list[Any]
desc_flags: list[bool]
backward: bool
version: int = 1
def _json_default(o: Any) -> Any:
# Keep it boring and predictable.
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
return o.isoformat()
if isinstance(o, _dec.Decimal):
return str(o)
# UUIDs, Enums, etc.
if hasattr(o, "__str__"):
return str(o)
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
def _b64url_nopad_encode(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
def _b64url_nopad_decode(s: str) -> bytes:
# Restore padding for strict decode
pad = "=" * (-len(s) % 4)
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
def encode_cursor(
values: list[Any] | None,
desc_flags: list[bool],
backward: bool,
*,
secret: bytes | None = None,
) -> str | None:
"""
Create an opaque, optionally signed cursor token.
- values: keyset values for the last/first row
- desc_flags: per-key descending flags (same length/order as sort keys)
- backward: whether the window direction is backward
- secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig')
"""
if not values: if not values:
return None return None
payload = {"v": values, "d": desc_flags, "b": backward}
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
payload = { def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
"ver": 1,
"v": values,
"d": desc_flags,
"b": bool(backward),
}
body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8")
token = _b64url_nopad_encode(body)
if secret:
sig = hmac.new(secret, body, sha256).digest()
token = f"{token}.{_b64url_nopad_encode(sig)}"
return token
def decode_cursor(
token: str | None,
*,
secret: bytes | None = None,
) -> tuple[list[Any] | None, list[bool] | None, bool]:
"""
Parse a cursor token. Returns (values, desc_flags, backward).
- Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case.
- If secret is provided, verifies HMAC and rejects tampered tokens.
- On any parse failure, returns (None, None, False).
"""
if not token: if not token:
return None, None, False return None, False
try: try:
# Split payload.sig if signed obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
if "." in token:
body_b64, sig_b64 = token.split(".", 1)
body = _b64url_nopad_decode(body_b64)
if secret is None:
# Caller didnt ask for verification; still parse but dont trust.
pass
else:
expected = hmac.new(secret, body, sha256).digest()
actual = _b64url_nopad_decode(sig_b64)
if not hmac.compare_digest(expected, actual):
return None, None, False
else:
body = _b64url_nopad_decode(token)
obj = json.loads(body.decode("utf-8"))
# Versioning. If we ever change fields, branch here.
ver = int(obj.get("ver", 0))
if ver not in (0, 1):
return None, None, False
vals = obj.get("v") vals = obj.get("v")
backward = bool(obj.get("b", False)) backward = bool(obj.get("b", False))
if isinstance(vals, list):
# desc_flags may be absent in legacy payloads (ver 0) return vals, backward
desc = obj.get("d")
if not isinstance(vals, list):
return None, None, False
if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
# Ignore weird 'd' types rather than crashing
desc = None
return vals, desc, backward
except Exception: except Exception:
# Be tolerant on decode: treat as no-cursor. pass
return None, None, False return None, False

View file

@ -1,40 +1,23 @@
from __future__ import annotations from flask import Blueprint, jsonify, request
from flask import Blueprint, jsonify, request, abort
from urllib.parse import urlencode
from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy from crudkit.core.service import _is_truthy
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
return _is_truthy(d.get(key, "1" if default else "0"))
def _safe_int(value: str | None, default: int) -> int:
try:
return int(value) if value is not None else default
except Exception:
return default
def _link_with_params(base_url: str, **params) -> str:
# Filter out None, encode safely
q = {k: v for k, v in params.items() if v is not None}
return f"{base_url}?{urlencode(q)}"
def generate_crud_blueprint(model, service): def generate_crud_blueprint(model, service):
bp = Blueprint(model.__name__.lower(), __name__) bp = Blueprint(model.__name__.lower(), __name__)
@bp.get("/") @bp.get('/')
def list_items(): def list_items():
# Work from a copy so we don't mutate request.args
args = request.args.to_dict(flat=True) args = request.args.to_dict(flat=True)
# legacy detection
legacy_offset = "offset" in args or "page" in args legacy_offset = "offset" in args or "page" in args
limit = _safe_int(args.get("limit"), 50) # sane limit default
try:
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit args["limit"] = limit
if legacy_offset: if legacy_offset:
@ -42,23 +25,17 @@ def generate_crud_blueprint(model, service):
items = service.list(args) items = service.list(args)
return jsonify([obj.as_dict() for obj in items]) return jsonify([obj.as_dict() for obj in items])
# New behavior: keyset pagination with cursors # New behavior: keyset seek with cursors
cursor_token = args.get("cursor") key, backward = decode_cursor(args.get("cursor"))
key, desc_from_cursor, backward = decode_cursor(cursor_token)
window = service.seek_window( window = service.seek_window(
args, args,
key=key, key=key,
backward=backward, backward=backward,
include_total=_bool_param(args, "include_total", True), include_total=_is_truthy(args.get("include_total", "1")),
) )
# Prefer the order actually used by the window; fall back to desc_from_cursor if needed. desc_flags = list(window.order.desc)
try:
desc_flags = list(window.order.desc)
except Exception:
desc_flags = desc_from_cursor or []
body = { body = {
"items": [obj.as_dict() for obj in window.items], "items": [obj.as_dict() for obj in window.items],
"limit": window.limit, "limit": window.limit,
@ -68,60 +45,46 @@ def generate_crud_blueprint(model, service):
} }
resp = jsonify(body) resp = jsonify(body)
# Optional Link header
# Preserve users other query params like include_total, filters, sorts, etc. links = []
base_url = request.base_url
base_params = {k: v for k, v in args.items() if k not in {"cursor"}}
link_parts = []
if body["next_cursor"]: if body["next_cursor"]:
link_parts.append( links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"'
)
if body["prev_cursor"]: if body["prev_cursor"]:
link_parts.append( links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"' if links:
) resp.headers["Link"] = ", ".join(links)
if link_parts:
resp.headers["Link"] = ", ".join(link_parts)
return resp return resp
@bp.get("/<int:id>") @bp.get('/<int:id>')
def get_item(id): def get_item(id):
item = service.get(id, request.args)
try: try:
item = service.get(id, request.args)
if item is None:
abort(404)
return jsonify(item.as_dict()) return jsonify(item.as_dict())
except Exception as e: except Exception as e:
# Could be validation, auth, or just you forgetting an index again return jsonify({"status": "error", "error": str(e)})
return jsonify({"status": "error", "error": str(e)}), 400
@bp.post("/") @bp.post('/')
def create_item(): def create_item():
payload = request.get_json(silent=True) or {} obj = service.create(request.json)
try: try:
obj = service.create(payload)
return jsonify(obj.as_dict()), 201
except Exception as e:
return jsonify({"status": "error", "error": str(e)}), 400
@bp.patch("/<int:id>")
def update_item(id):
payload = request.get_json(silent=True) or {}
try:
obj = service.update(id, payload)
return jsonify(obj.as_dict()) return jsonify(obj.as_dict())
except Exception as e: except Exception as e:
# 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional. return jsonify({"status": "error", "error": str(e)})
return jsonify({"status": "error", "error": str(e)}), 400
@bp.delete("/<int:id>") @bp.patch('/<int:id>')
def delete_item(id): def update_item(id):
obj = service.update(id, request.json)
try: try:
service.delete(id) return jsonify(obj.as_dict())
# 204 means "no content" so don't send any.
return ("", 204)
except Exception as e: except Exception as e:
return jsonify({"status": "error", "error": str(e)}), 400 return jsonify({"status": "error", "error": str(e)})
@bp.delete('/<int:id>')
def delete_item(id):
service.delete(id)
try:
return jsonify({"status": "success"}), 204
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
return bp return bp

View file

@ -1,22 +1,92 @@
from __future__ import annotations from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from collections.abc import Iterable
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text, UnaryExpression
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import UnaryExpression
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection
import logging def _expand_requires(model_cls, fields):
log = logging.getLogger("crudkit.service") out, seen = [], set()
def add(f):
if f not in seen:
seen.add(f); out.append(f)
for f in fields:
add(f)
parts = f.split(".")
cur_cls = model_cls
prefix = []
for p in parts[:-1]:
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
if not rel:
cur_cls = None
break
cur_cls = rel.mapper.class_
prefix.append(p)
if cur_cls is None:
continue
leaf = parts[-1]
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
if not deps:
continue
pre = ".".join(prefix)
for dep in deps:
add(f"{pre + '.' if pre else ''}{dep}")
return out
def _is_rel(model_cls, name: str) -> bool:
try:
prop = model_cls.__mapper__.relationships.get(name)
return isinstance(prop, RelationshipProperty)
except Exception:
return False
def _is_instrumented_column(attr) -> bool:
try:
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
except Exception:
return False
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
"""
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
and only fetch the related PK. This is enough for preselecting <select> inputs
without projecting the FK column on the root model.
"""
opts: list[Load] = []
if not fields:
return opts
mapper = class_mapper(model_cls)
for name in fields:
prop = mapper.relationships.get(name)
if not isinstance(prop, RelationshipProperty):
continue
if prop.direction.name != "MANYTOONE":
continue
rel_attr = getattr(root_alias, name)
target_cls = prop.mapper.class_
# load_only PK if present; else just selectinload
id_attr = getattr(target_cls, "id", None)
if id_attr is not None:
opts.append(selectinload(rel_attr).load_only(id_attr))
else:
opts.append(selectinload(rel_attr))
return opts
@runtime_checkable @runtime_checkable
class _HasID(Protocol): class _HasID(Protocol):
@ -40,65 +110,9 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
T = TypeVar("T", bound=_CRUDModelProto) T = TypeVar("T", bound=_CRUDModelProto)
def _belongs_to_alias(col: Any, alias: Any) -> bool:
# Try to detect if a column/expression ultimately comes from this alias.
# Works for most ORM columns; complex expressions may need more.
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[tuple[str, ...]]:
paths: set[tuple[str, ...]] = set()
# Sort columns
for ob in order_by or []:
col = getattr(ob, "element", ob) # unwrap UnaryExpression
for path, _rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple(path))
# Filter columns (best-effort)
# Walk simple binary expressions
def _extract_cols(expr: Any) -> Iterable[Any]:
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _extract_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _extract_cols(ch)
for flt in filters or []:
for col in _extract_cols(flt):
for path, _rel_attr, target_alias in join_paths:
if _belongs_to_alias(col, target_alias):
paths.add(tuple[path])
return paths
def _paths_from_fields(req_fields: list[str]) -> set[tuple[str, ...]]:
out: set[tuple[str, ...]] = set()
for f in req_fields:
if "." in f:
parts = tuple(f.split(".")[:-1])
if parts:
out.add(parts)
return out
def _is_truthy(val): def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on') return str(val).lower() in ('1', 'true', 'yes', 'on')
def _normalize_fields_param(params: dict | None) -> list[str]:
if not params:
return []
raw = params.get("fields")
if isinstance(raw, (list, tuple)):
out: list[str] = []
for item in raw:
if isinstance(item, str):
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
return out
if isinstance(raw, str):
return [p for p in (s.strip() for s in raw.split(",")) if p]
return []
class CRUDService(Generic[T]): class CRUDService(Generic[T]):
def __init__( def __init__(
self, self,
@ -113,7 +127,7 @@ class CRUDService(Generic[T]):
self.polymorphic = polymorphic self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted') self.supports_soft_delete = hasattr(model, 'is_deleted')
# Cache backend info once. If not provided, derive from session bind. # Cache backend info once. If not provided, derive from session bind.
bind = session_factory().get_bind() bind = self.session.get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind) eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self.backend = backend or make_backend_info(eng) self.backend = backend or make_backend_info(eng)
@ -127,10 +141,58 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly return self.session.query(poly), poly
return self.session.query(self.model), self.model return self.session.query(self.model), self.model
def _apply_not_deleted(self, query, root_alias, params) -> Any: def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")): """
return query.filter(getattr(root_alias, "is_deleted") == False) For each dotted path like ("location"), -> ["label"], look up the target
return query model's __crudkit_field_requires__ for the terminal field and produce
selectinload options prefixed with the relationship path, e.g.:
Room.__crudkit_field_requires__['label'] = ['room_function']
=> selectinload(root.location).selectinload(Room.room_function)
"""
opts: List[Any] = []
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for path, names in (rel_field_names or {}).items():
if not path:
continue
current_mapper = root_mapper
rel_props: List[RelationshipProperty] = []
valid = True
for step in path:
rel = current_mapper.relationships.get(step)
if not isinstance(rel, RelationshipProperty):
valid = False
break
rel_props.append(rel)
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
if not valid or not rel_props:
continue
first = rel_props[0]
base_loader = selectinload(getattr(root_alias, first.key))
for i in range(1, len(rel_props)):
prev_target_cls = rel_props[i - 1].mapper.class_
hop_attr = getattr(prev_target_cls, rel_props[i].key)
base_loader = base_loader.selectinload(hop_attr)
target_cls = rel_props[-1].mapper.class_
requires = getattr(target_cls, "__crudkit_field_requires__", None)
if not isinstance(requires, dict):
continue
for field_name in names:
needed: Iterable[str] = requires.get(field_name, []) or []
for rel_need in needed:
rel_prop2 = target_cls.__mapper__.relationships.get(rel_need)
if not isinstance(rel_prop2, RelationshipProperty):
continue
dep_attr = getattr(target_cls, rel_prop2.key)
opts.append(base_loader.selectinload(dep_attr))
return opts
def _extract_order_spec(self, root_alias, given_order_by): def _extract_order_spec(self, root_alias, given_order_by):
""" """
@ -138,6 +200,8 @@ class CRUDService(Generic[T]):
Normalize order_by into (cols, desc_flags). Supports plain columns and Normalize order_by into (cols, desc_flags). Supports plain columns and
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses. col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
""" """
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
given = self._stable_order_by(root_alias, given_order_by) given = self._stable_order_by(root_alias, given_order_by)
@ -145,7 +209,7 @@ class CRUDService(Generic[T]):
for ob in given: for ob in given:
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc() # Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
elem = getattr(ob, "element", None) elem = getattr(ob, "element", None)
col = elem if elem is not None else ob col = elem if elem is not None else ob # don't use "or" with SA expressions
# Detect direction in SA 2.x # Detect direction in SA 2.x
is_desc = False is_desc = False
@ -159,33 +223,31 @@ class CRUDService(Generic[T]):
cols.append(col) cols.append(col)
desc_flags.append(bool(is_desc)) desc_flags.append(bool(is_desc))
from crudkit.core.types import OrderSpec
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags)) return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool): def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
"""
Build lexicographic predicate for keyset seek.
For backward traversal, import comparisons.
"""
if not key_vals: if not key_vals:
return None return None
conds = [] conds = []
for i, col in enumerate(spec.cols): for i, col in enumerate(spec.cols):
# If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
# sent_col = func.coalesce(col, literal("-∞"))
sent_col = col
ties = [spec.cols[j] == key_vals[j] for j in range(i)] ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i] is_desc = spec.desc[i]
if not backward: if not backward:
op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i]) op = col < key_vals[i] if is_desc else col > key_vals[i]
else: else:
op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i]) op = col > key_vals[i] if is_desc else col < key_vals[i]
conds.append(and_(*ties, op)) conds.append(and_(*ties, op))
return or_(*conds) return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]: def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
out = [] out = []
for c in spec.cols: for c in spec.cols:
# Only simple mapped columns supported for key pluck
key = getattr(c, "key", None) or getattr(c, "name", None) key = getattr(c, "key", None) or getattr(c, "name", None)
if key is None or not hasattr(obj, key):
raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
out.append(getattr(obj, key)) out.append(getattr(obj, key))
return out return out
@ -204,103 +266,62 @@ class CRUDService(Generic[T]):
- forward/backward seek via `key` and `backward` - forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total. Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
""" """
session = self.session fields = list(params.get("fields", []))
if fields:
fields = _expand_requires(self.model, fields)
params = {**params, "fields": fields}
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Normalize requested fields and compile projection (may skip later to avoid conflicts)
fields = _normalize_fields_param(params)
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
# Field parsing for root load_only fallback
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter seen_rel_roots = set()
query = self._apply_not_deleted(query, root_alias, params) for path, names in (rel_field_names or {}).items():
if not path:
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
# Apply filters first # Soft delete filter
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False)
# Parse filters first
if filters: if filters:
query = query.filter(*filters) query = query.filter(*filters)
# Includes + join paths (dotted fields etc.) # Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias) for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Decide which relationship *names* are required for SQL (filters/sort) vs display-only # Fields/projection: load_only for root columns, eager loads for relationships
def _belongs_to_alias(col: Any, alias: Any) -> bool:
t = getattr(col, "table", None)
selectable = getattr(alias, "selectable", None)
return t is not None and selectable is not None and t is selectable
# 1) which relationship aliases are referenced by sort/filter
sql_hops: set[str] = set()
for path, relationship_attr, target_alias in join_paths:
# If any ORDER BY column comes from this alias, mark it
for ob in (order_by or []):
col = getattr(ob, "element", ob) # unwrap UnaryExpression
if _belongs_to_alias(col, target_alias):
sql_hops.add(relationship_attr.key)
break
# If any filter expr touches this alias, mark it (best effort)
if relationship_attr.key not in sql_hops:
def _walk_cols(expr: Any):
# Primitive walker for ColumnElement trees
from sqlalchemy.sql.elements import ColumnElement
if isinstance(expr, ColumnElement):
yield expr
for ch in getattr(expr, "get_children", lambda: [])():
yield from _walk_cols(ch)
elif hasattr(expr, "clauses"):
for ch in expr.clauses:
yield from _walk_cols(ch)
for flt in (filters or []):
if any(_belongs_to_alias(c, target_alias) for c in _walk_cols(flt)):
sql_hops.add(relationship_attr.key)
break
# 2) first-hop relationship names implied by dotted projection fields
proj_hops: set[str] = {f.split(".", 1)[0] for f in fields if "." in f}
# Root column projection
from sqlalchemy.orm import Load # local import to match your style
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
# query = query.options(eager)
# Relationship handling per path (avoid loader strategy conflicts) for opt in self._resolve_required_includes(root_alias, rel_field_names):
used_contains_eager = False query = query.options(opt)
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
name = relationship_attr.key
if name in sql_hops:
# Needed for WHERE/ORDER BY: join + hydrate from that join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif name in proj_hops:
# Display-only: bulk-load efficiently, no join
query = query.options(selectinload(rel_attr))
else:
# Not needed
pass
# Apply projection loader options only if they won't conflict with contains_eager
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
# Order + limit # Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination() limit, _ = spec.parse_pagination()
if limit is None: if not limit or limit <= 0:
effective_limit = 50 limit = 50 # sensible default
elif limit == 0:
effective_limit = None # unlimited
else:
effective_limit = limit
# Keyset predicate # Keyset predicate
if key: if key:
@ -310,36 +331,30 @@ class CRUDService(Generic[T]):
# Apply ordering. For backward, invert SQL order then reverse in-memory for display. # Apply ordering. For backward, invert SQL order then reverse in-memory for display.
if not backward: if not backward:
clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] clauses = []
query = query.order_by(*clauses) for col, is_desc in zip(order_spec.cols, order_spec.desc):
if effective_limit is not None: clauses.append(col.desc() if is_desc else col.asc())
query = query.limit(effective_limit) query = query.order_by(*clauses).limit(limit)
items = query.all() items = query.all()
else: else:
inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)] inv_clauses = []
query = query.order_by(*inv_clauses) for col, is_desc in zip(order_spec.cols, order_spec.desc):
if effective_limit is not None: inv_clauses.append(col.asc() if is_desc else col.desc())
query = query.limit(effective_limit) query = query.order_by(*inv_clauses).limit(limit)
items = list(reversed(query.all())) items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested # Tag projection so your renderer knows what fields were requested
if fields: proj = []
proj = list(dict.fromkeys(fields)) # dedupe, preserve order if root_field_names:
if "id" not in proj and hasattr(self.model, "id"): proj.extend(root_field_names)
proj.insert(0, "id") if root_fields:
else: proj.extend(c.key for c in root_fields)
proj = [] for path, names in (rel_field_names or {}).items():
if root_field_names: prefix = ".".join(path)
proj.extend(root_field_names) for n in names:
if root_fields: proj.append(f"{prefix}.{n}")
proj.extend(c.key for c in root_fields if hasattr(c, "key")) if proj and "id" not in proj and hasattr(self.model, "id"):
for path, names in (rel_field_names or {}).items(): proj.insert(0, "id")
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj: if proj:
for obj in items: for obj in items:
try: try:
@ -349,32 +364,27 @@ class CRUDService(Generic[T]):
# Boundary keys for cursor encoding in the API layer # Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None first_key = self._pluck_key(items[0], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None last_key = self._pluck_key(items[-1], order_spec) if items else None
# Optional total thats safe under JOINs (COUNT DISTINCT ids) # Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None total = None
if include_total: if include_total:
base = session.query(getattr(root_alias, "id")) base = self.session.query(getattr(root_alias, "id"))
base = self._apply_not_deleted(base, root_alias, params) if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
base = base.filter(getattr(root_alias, "is_deleted") == False)
if filters: if filters:
base = base.filter(*filters) base = base.filter(*filters)
# Mirror join structure for any SQL-needed relationships # replicate the same joins used above
for path, relationship_attr, target_alias in join_paths: for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
if relationship_attr.key in sql_hops: rel_attr = cast(InstrumentedAttribute, relationship_attr)
rel_attr = cast(InstrumentedAttribute, relationship_attr) target = cast(Any, target_alias)
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) base = base.join(target, rel_attr.of_type(target), isouter=True)
total = session.query(func.count()).select_from( total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
base.order_by(None).distinct().subquery()
).scalar() or 0
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50)
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow( return SeekWindow(
items=items, items=items,
limit=window_limit_for_body, limit=limit,
first_key=first_key, first_key=first_key,
last_key=last_key, last_key=last_key,
order=order_spec, order=order_spec,
@ -406,160 +416,148 @@ class CRUDService(Generic[T]):
for col in mapper.primary_key: for col in mapper.primary_key:
try: try:
pk_cols.append(getattr(root_alias, col.key)) pk_cols.append(getattr(root_alias, col.key))
except AttributeError: except ArithmeticError:
pk_cols.append(col) pk_cols.append(col)
return [*order_by, *pk_cols] return [*order_by, *pk_cols]
def get(self, id: int, params=None) -> T | None: def get(self, id: int, params=None) -> T | None:
"""Fetch a single row by id with conflict-free eager loading and clean projection."""
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Defaults so we can build a projection even if params is None include_deleted = False
root_fields: list[Any] = [] root_fields = []
root_field_names: dict[str, str] = {} root_field_names = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {} rel_field_names = {}
req_fields: list[str] = _normalize_fields_param(params)
# Soft-delete guard
query = self._apply_not_deleted(query, root_alias, params)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
if params:
# Optional extra filters (in addition to id); keep parity with list() if self.supports_soft_delete:
filters = spec.parse_filters() include_deleted = _is_truthy(params.get('include_deleted'))
if filters: if self.supports_soft_delete and not include_deleted:
query = query.filter(*filters) query = query.filter(getattr(root_alias, "is_deleted") == False)
# Always filter by id
query = query.filter(getattr(root_alias, "id") == id) query = query.filter(getattr(root_alias, "id") == id)
# Includes + join paths we may need
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
# Field parsing to enable root load_only for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params: if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Decide which relationship paths are needed for SQL vs display-only if rel_field_names:
# For get(), there is no ORDER BY; only filters might force SQL use. seen_rel_roots = set()
sql_paths = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths) for path, names in rel_field_names.items():
proj_paths = _paths_from_fields(req_fields) if not path:
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
fields = (params or {}).get("fields") if isinstance(params, dict) else None
if fields:
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path: avoid loader strategy conflicts # for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
used_contains_eager = False # query = query.options(eager)
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed in WHERE: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
pass
# Projection loader options compiled from requested fields. if params:
# Skip if we used contains_eager to avoid strategy conflicts. fields = params.get("fields") or []
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) for opt in _loader_options_for_fields(root_alias, self.model, fields):
if proj_opts and not used_contains_eager: query = query.options(opt)
query = query.options(*proj_opts)
obj = query.first() obj = query.first()
# Emit exactly what the client requested (plus id), or a reasonable fallback proj = []
if req_fields: if root_field_names:
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order proj.extend(root_field_names)
if "id" not in proj and hasattr(self.model, "id"): if root_fields:
proj.insert(0, "id") proj.extend(c.key for c in root_fields)
else: for path, names in (rel_field_names or {}).items():
proj = [] prefix = ".".join(path)
if root_field_names: for n in names:
proj.extend(root_field_names) proj.append(f"{prefix}.{n}")
if root_fields:
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj and obj is not None: if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
try: try:
setattr(obj, "__crudkit_projection__", tuple(proj)) setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception: except Exception:
pass pass
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
"""Offset/limit listing with smart relationship loading and clean projection."""
query, root_alias = self.get_query() query, root_alias = self.get_query()
# Defaults so we can reference them later even if params is None root_fields = []
root_fields: list[Any] = [] root_field_names = {}
root_field_names: dict[str, str] = {} rel_field_names = {}
rel_field_names: dict[tuple[str, ...], list[str]] = {}
req_fields: list[str] = _normalize_fields_param(params)
if params: if params:
query = self._apply_not_deleted(query, root_alias, params) if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
spec = CRUDSpec(self.model, params or {}, root_alias) spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters() filters = spec.parse_filters()
order_by = spec.parse_sort() order_by = spec.parse_sort()
limit, offset = spec.parse_pagination() limit, offset = spec.parse_pagination()
# Includes + join paths we might need
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
# Field parsing for load_only on root columns for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
root_fields, rel_field_names, root_field_names = spec.parse_fields() rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if filters: if params:
query = query.filter(*filters) root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Determine which relationship paths are needed for SQL vs display-only if rel_field_names:
sql_paths = _paths_needed_for_sql(order_by, filters, join_paths) seen_rel_roots = set()
proj_paths = _paths_from_fields(req_fields) for path, names in rel_field_names.items():
if not path:
continue
rel_name = path[0]
if rel_name in seen_rel_roots:
continue
if _is_rel(self.model, rel_name):
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
seen_rel_roots.add(rel_name)
# Root column projection
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols: if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols)) query = query.options(Load(root_alias).load_only(*only_cols))
# Relationship handling per path # for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
used_contains_eager = False # query = query.options(eager)
for path, relationship_attr, target_alias in join_paths:
rel_attr = cast(InstrumentedAttribute, relationship_attr)
ptuple = tuple(path)
if ptuple in sql_paths:
# Needed for WHERE/ORDER BY: join + hydrate from the join
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
query = query.options(contains_eager(rel_attr, alias=target_alias))
used_contains_eager = True
elif ptuple in proj_paths:
# Display-only: no join, bulk-load efficiently
query = query.options(selectinload(rel_attr))
else:
# Not needed at all; do nothing
pass
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset) if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
if filters:
query = query.filter(*filters)
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
paginating = (limit is not None) or (offset is not None and offset != 0) paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset: if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias) order_by = self._default_order_by(root_alias)
@ -567,43 +565,26 @@ class CRUDService(Generic[T]):
if order_by: if order_by:
query = query.order_by(*order_by) query = query.order_by(*order_by)
# Only apply offset/limit when not None and not zero # Only apply offset/limit when not None.
if offset is not None and offset != 0: if offset is not None and offset != 0:
query = query.offset(offset) query = query.offset(offset)
if limit is not None and limit > 0: if limit is not None and limit > 0:
query = query.limit(limit) query = query.limit(limit)
# Projection loader options compiled from requested fields.
# Skip if we used contains_eager to avoid loader-strategy conflicts.
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts and not used_contains_eager:
query = query.options(*proj_opts)
else:
# No params means no filters/sorts/limits; still honor projection loaders if any
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
if proj_opts:
query = query.options(*proj_opts)
rows = query.all() rows = query.all()
# Emit exactly what the client requested (plus id), or a reasonable fallback proj = []
if req_fields: if root_field_names:
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order proj.extend(root_field_names)
if "id" not in proj and hasattr(self.model, "id"): if root_fields:
proj.insert(0, "id") proj.extend(c.key for c in root_fields)
else: for path, names in (rel_field_names or {}).items():
proj = [] prefix = ".".join(path)
if root_field_names: for n in names:
proj.extend(root_field_names) proj.append(f"{prefix}.{n}")
if root_fields:
proj.extend(c.key for c in root_fields if hasattr(c, "key")) if proj and "id" not in proj and hasattr(self.model, "id"):
for path, names in (rel_field_names or {}).items(): proj.insert(0, "id")
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj: if proj:
for obj in rows: for obj in rows:
@ -612,51 +593,41 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return rows return rows
def create(self, data: dict, actor=None) -> T: def create(self, data: dict, actor=None) -> T:
session = self.session
obj = self.model(**data) obj = self.model(**data)
session.add(obj) self.session.add(obj)
session.commit() self.session.commit()
self._log_version("create", obj, actor) self._log_version("create", obj, actor)
return obj return obj
def update(self, id: int, data: dict, actor=None) -> T: def update(self, id: int, data: dict, actor=None) -> T:
session = self.session
obj = self.get(id) obj = self.get(id)
if not obj: if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.") raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns} valid_fields = {c.name for c in self.model.__table__.columns}
unknown = set(data) - valid_fields
if unknown:
raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
for k, v in data.items(): for k, v in data.items():
if k in valid_fields: if k in valid_fields:
setattr(obj, k, v) setattr(obj, k, v)
session.commit() self.session.commit()
self._log_version("update", obj, actor) self._log_version("update", obj, actor)
return obj return obj
def delete(self, id: int, hard: bool = False, actor = None): def delete(self, id: int, hard: bool = False, actor = False):
session = self.session obj = self.session.get(self.model, id)
obj = session.get(self.model, id)
if not obj: if not obj:
return None return None
if hard or not self.supports_soft_delete: if hard or not self.supports_soft_delete:
session.delete(obj) self.session.delete(obj)
else: else:
soft = cast(_SoftDeletable, obj) soft = cast(_SoftDeletable, obj)
soft.is_deleted = True soft.is_deleted = True
session.commit() self.session.commit()
self._log_version("delete", obj, actor) self._log_version("delete", obj, actor)
return obj return obj
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None): def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}):
session = self.session
try: try:
data = obj.as_dict() data = obj.as_dict()
except Exception: except Exception:
@ -669,5 +640,5 @@ class CRUDService(Generic[T]):
actor=str(actor) if actor else None, actor=str(actor) if actor else None,
meta=metadata meta=metadata
) )
session.add(version) self.session.add(version)
session.commit() self.session.commit()

View file

@ -91,7 +91,7 @@ def init_listing_routes(app):
] ]
limit = int(request.args.get("limit", 15)) limit = int(request.args.get("limit", 15))
cursor = request.args.get("cursor") cursor = request.args.get("cursor")
key, _desc, backward = decode_cursor(cursor) key, backward = decode_cursor(cursor)
service = crudkit.crud.get_service(cls) service = crudkit.crud.get_service(cls)
window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True) window = service.seek_window(spec | {"limit": limit}, key=key, backward=backward, include_total=True)