Downstream changes and decoupled engine instance problem. Fwee.
This commit is contained in:
parent
8be6f917c7
commit
d34654834b
4 changed files with 668 additions and 374 deletions
|
|
@ -1,21 +1,135 @@
|
||||||
import base64, json
|
# crudkit/api/_cursor.py
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import dataclasses
|
||||||
|
import datetime as _dt
|
||||||
|
import decimal as _dec
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import typing as _t
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
|
Any = _t.Any
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class Cursor:
|
||||||
|
values: list[Any]
|
||||||
|
desc_flags: list[bool]
|
||||||
|
backward: bool
|
||||||
|
version: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _json_default(o: Any) -> Any:
|
||||||
|
# Keep it boring and predictable.
|
||||||
|
if isinstance(o, (_dt.datetime, _dt.date, _dt.time)):
|
||||||
|
# ISO is good enough; assume UTC-aware datetimes are already normalized upstream.
|
||||||
|
return o.isoformat()
|
||||||
|
if isinstance(o, _dec.Decimal):
|
||||||
|
return str(o)
|
||||||
|
# UUIDs, Enums, etc.
|
||||||
|
if hasattr(o, "__str__"):
|
||||||
|
return str(o)
|
||||||
|
raise TypeError(f"Unsupported cursor value type: {type(o)!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _b64url_nopad_encode(b: bytes) -> str:
|
||||||
|
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def _b64url_nopad_decode(s: str) -> bytes:
|
||||||
|
# Restore padding for strict decode
|
||||||
|
pad = "=" * (-len(s) % 4)
|
||||||
|
return base64.urlsafe_b64decode((s + pad).encode("ascii"))
|
||||||
|
|
||||||
|
|
||||||
|
def encode_cursor(
|
||||||
|
values: list[Any] | None,
|
||||||
|
desc_flags: list[bool],
|
||||||
|
backward: bool,
|
||||||
|
*,
|
||||||
|
secret: bytes | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""
|
||||||
|
Create an opaque, optionally signed cursor token.
|
||||||
|
|
||||||
|
- values: keyset values for the last/first row
|
||||||
|
- desc_flags: per-key descending flags (same length/order as sort keys)
|
||||||
|
- backward: whether the window direction is backward
|
||||||
|
- secret: if set, an HMAC-SHA256 signature is appended (JWT-ish 'payload.sig')
|
||||||
|
"""
|
||||||
if not values:
|
if not values:
|
||||||
return None
|
return None
|
||||||
payload = {"v": values, "d": desc_flags, "b": backward}
|
|
||||||
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
|
|
||||||
|
|
||||||
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
|
payload = {
|
||||||
|
"ver": 1,
|
||||||
|
"v": values,
|
||||||
|
"d": desc_flags,
|
||||||
|
"b": bool(backward),
|
||||||
|
}
|
||||||
|
|
||||||
|
body = json.dumps(payload, default=_json_default, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||||
|
token = _b64url_nopad_encode(body)
|
||||||
|
|
||||||
|
if secret:
|
||||||
|
sig = hmac.new(secret, body, sha256).digest()
|
||||||
|
token = f"{token}.{_b64url_nopad_encode(sig)}"
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def decode_cursor(
|
||||||
|
token: str | None,
|
||||||
|
*,
|
||||||
|
secret: bytes | None = None,
|
||||||
|
) -> tuple[list[Any] | None, list[bool] | None, bool]:
|
||||||
|
"""
|
||||||
|
Parse a cursor token. Returns (values, desc_flags, backward).
|
||||||
|
|
||||||
|
- Accepts legacy tokens lacking 'd' and returns desc_flags=None in that case.
|
||||||
|
- If secret is provided, verifies HMAC and rejects tampered tokens.
|
||||||
|
- On any parse failure, returns (None, None, False).
|
||||||
|
"""
|
||||||
if not token:
|
if not token:
|
||||||
return None, False
|
return None, None, False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
|
# Split payload.sig if signed
|
||||||
|
if "." in token:
|
||||||
|
body_b64, sig_b64 = token.split(".", 1)
|
||||||
|
body = _b64url_nopad_decode(body_b64)
|
||||||
|
if secret is None:
|
||||||
|
# Caller didn’t ask for verification; still parse but don’t trust.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
expected = hmac.new(secret, body, sha256).digest()
|
||||||
|
actual = _b64url_nopad_decode(sig_b64)
|
||||||
|
if not hmac.compare_digest(expected, actual):
|
||||||
|
return None, None, False
|
||||||
|
else:
|
||||||
|
body = _b64url_nopad_decode(token)
|
||||||
|
|
||||||
|
obj = json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
|
# Versioning. If we ever change fields, branch here.
|
||||||
|
ver = int(obj.get("ver", 0))
|
||||||
|
if ver not in (0, 1):
|
||||||
|
return None, None, False
|
||||||
|
|
||||||
vals = obj.get("v")
|
vals = obj.get("v")
|
||||||
backward = bool(obj.get("b", False))
|
backward = bool(obj.get("b", False))
|
||||||
if isinstance(vals, list):
|
|
||||||
return vals, backward
|
# desc_flags may be absent in legacy payloads (ver 0)
|
||||||
|
desc = obj.get("d")
|
||||||
|
if not isinstance(vals, list):
|
||||||
|
return None, None, False
|
||||||
|
if desc is not None and not (isinstance(desc, list) and all(isinstance(x, bool) for x in desc)):
|
||||||
|
# Ignore weird 'd' types rather than crashing
|
||||||
|
desc = None
|
||||||
|
|
||||||
|
return vals, desc, backward
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
# Be tolerant on decode: treat as no-cursor.
|
||||||
return None, False
|
return None, None, False
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,40 @@
|
||||||
from flask import Blueprint, jsonify, request
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, request, abort
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from crudkit.api._cursor import encode_cursor, decode_cursor
|
from crudkit.api._cursor import encode_cursor, decode_cursor
|
||||||
from crudkit.core.service import _is_truthy
|
from crudkit.core.service import _is_truthy
|
||||||
|
|
||||||
|
|
||||||
|
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
|
||||||
|
return _is_truthy(d.get(key, "1" if default else "0"))
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_int(value: str | None, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(value) if value is not None else default
|
||||||
|
except Exception:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _link_with_params(base_url: str, **params) -> str:
|
||||||
|
# Filter out None, encode safely
|
||||||
|
q = {k: v for k, v in params.items() if v is not None}
|
||||||
|
return f"{base_url}?{urlencode(q)}"
|
||||||
|
|
||||||
|
|
||||||
def generate_crud_blueprint(model, service):
|
def generate_crud_blueprint(model, service):
|
||||||
bp = Blueprint(model.__name__.lower(), __name__)
|
bp = Blueprint(model.__name__.lower(), __name__)
|
||||||
|
|
||||||
@bp.get('/')
|
@bp.get("/")
|
||||||
def list_items():
|
def list_items():
|
||||||
|
# Work from a copy so we don't mutate request.args
|
||||||
args = request.args.to_dict(flat=True)
|
args = request.args.to_dict(flat=True)
|
||||||
|
|
||||||
# legacy detection
|
|
||||||
legacy_offset = "offset" in args or "page" in args
|
legacy_offset = "offset" in args or "page" in args
|
||||||
|
|
||||||
# sane limit default
|
limit = _safe_int(args.get("limit"), 50)
|
||||||
try:
|
|
||||||
limit = int(args.get("limit", 50))
|
|
||||||
except Exception:
|
|
||||||
limit = 50
|
|
||||||
args["limit"] = limit
|
args["limit"] = limit
|
||||||
|
|
||||||
if legacy_offset:
|
if legacy_offset:
|
||||||
|
|
@ -25,17 +42,23 @@ def generate_crud_blueprint(model, service):
|
||||||
items = service.list(args)
|
items = service.list(args)
|
||||||
return jsonify([obj.as_dict() for obj in items])
|
return jsonify([obj.as_dict() for obj in items])
|
||||||
|
|
||||||
# New behavior: keyset seek with cursors
|
# New behavior: keyset pagination with cursors
|
||||||
key, backward = decode_cursor(args.get("cursor"))
|
cursor_token = args.get("cursor")
|
||||||
|
key, desc_from_cursor, backward = decode_cursor(cursor_token)
|
||||||
|
|
||||||
window = service.seek_window(
|
window = service.seek_window(
|
||||||
args,
|
args,
|
||||||
key=key,
|
key=key,
|
||||||
backward=backward,
|
backward=backward,
|
||||||
include_total=_is_truthy(args.get("include_total", "1")),
|
include_total=_bool_param(args, "include_total", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prefer the order actually used by the window; fall back to desc_from_cursor if needed.
|
||||||
|
try:
|
||||||
desc_flags = list(window.order.desc)
|
desc_flags = list(window.order.desc)
|
||||||
|
except Exception:
|
||||||
|
desc_flags = desc_from_cursor or []
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"items": [obj.as_dict() for obj in window.items],
|
"items": [obj.as_dict() for obj in window.items],
|
||||||
"limit": window.limit,
|
"limit": window.limit,
|
||||||
|
|
@ -45,46 +68,60 @@ def generate_crud_blueprint(model, service):
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = jsonify(body)
|
resp = jsonify(body)
|
||||||
# Optional Link header
|
|
||||||
links = []
|
# Preserve user’s other query params like include_total, filters, sorts, etc.
|
||||||
|
base_url = request.base_url
|
||||||
|
base_params = {k: v for k, v in args.items() if k not in {"cursor"}}
|
||||||
|
link_parts = []
|
||||||
if body["next_cursor"]:
|
if body["next_cursor"]:
|
||||||
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
|
link_parts.append(
|
||||||
|
f'<{_link_with_params(base_url, **base_params, cursor=body["next_cursor"])}>; rel="next"'
|
||||||
|
)
|
||||||
if body["prev_cursor"]:
|
if body["prev_cursor"]:
|
||||||
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
|
link_parts.append(
|
||||||
if links:
|
f'<{_link_with_params(base_url, **base_params, cursor=body["prev_cursor"])}>; rel="prev"'
|
||||||
resp.headers["Link"] = ", ".join(links)
|
)
|
||||||
|
if link_parts:
|
||||||
|
resp.headers["Link"] = ", ".join(link_parts)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@bp.get('/<int:id>')
|
@bp.get("/<int:id>")
|
||||||
def get_item(id):
|
def get_item(id):
|
||||||
item = service.get(id, request.args)
|
|
||||||
try:
|
try:
|
||||||
|
item = service.get(id, request.args)
|
||||||
|
if item is None:
|
||||||
|
abort(404)
|
||||||
return jsonify(item.as_dict())
|
return jsonify(item.as_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"status": "error", "error": str(e)})
|
# Could be validation, auth, or just you forgetting an index again
|
||||||
|
return jsonify({"status": "error", "error": str(e)}), 400
|
||||||
|
|
||||||
@bp.post('/')
|
@bp.post("/")
|
||||||
def create_item():
|
def create_item():
|
||||||
obj = service.create(request.json)
|
payload = request.get_json(silent=True) or {}
|
||||||
try:
|
try:
|
||||||
return jsonify(obj.as_dict())
|
obj = service.create(payload)
|
||||||
|
return jsonify(obj.as_dict()), 201
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"status": "error", "error": str(e)})
|
return jsonify({"status": "error", "error": str(e)}), 400
|
||||||
|
|
||||||
@bp.patch('/<int:id>')
|
@bp.patch("/<int:id>")
|
||||||
def update_item(id):
|
def update_item(id):
|
||||||
obj = service.update(id, request.json)
|
payload = request.get_json(silent=True) or {}
|
||||||
try:
|
try:
|
||||||
|
obj = service.update(id, payload)
|
||||||
return jsonify(obj.as_dict())
|
return jsonify(obj.as_dict())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"status": "error", "error": str(e)})
|
# 404 if not found, 400 if validation. Your service can throw specific exceptions if you ever feel like being professional.
|
||||||
|
return jsonify({"status": "error", "error": str(e)}), 400
|
||||||
|
|
||||||
@bp.delete('/<int:id>')
|
@bp.delete("/<int:id>")
|
||||||
def delete_item(id):
|
def delete_item(id):
|
||||||
service.delete(id)
|
|
||||||
try:
|
try:
|
||||||
return jsonify({"status": "success"}), 204
|
service.delete(id)
|
||||||
|
# 204 means "no content" so don't send any.
|
||||||
|
return ("", 204)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"status": "error", "error": str(e)})
|
return jsonify({"status": "error", "error": str(e)}), 400
|
||||||
|
|
||||||
return bp
|
return bp
|
||||||
|
|
|
||||||
|
|
@ -1,92 +1,22 @@
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||||||
from sqlalchemy import and_, func, inspect, or_, text
|
from sqlalchemy import and_, func, inspect, or_, text
|
||||||
from sqlalchemy.engine import Engine, Connection
|
from sqlalchemy.engine import Engine, Connection
|
||||||
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper, ColumnProperty
|
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
|
||||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
from sqlalchemy.orm.base import NO_VALUE
|
|
||||||
from sqlalchemy.orm.util import AliasedClass
|
|
||||||
from sqlalchemy.sql import operators
|
from sqlalchemy.sql import operators
|
||||||
from sqlalchemy.sql.elements import UnaryExpression
|
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||||||
|
|
||||||
from crudkit.core.base import Version
|
from crudkit.core.base import Version
|
||||||
from crudkit.core.spec import CRUDSpec
|
from crudkit.core.spec import CRUDSpec
|
||||||
from crudkit.core.types import OrderSpec, SeekWindow
|
from crudkit.core.types import OrderSpec, SeekWindow
|
||||||
from crudkit.backend import BackendInfo, make_backend_info
|
from crudkit.backend import BackendInfo, make_backend_info
|
||||||
|
from crudkit.projection import compile_projection
|
||||||
|
|
||||||
def _expand_requires(model_cls, fields):
|
import logging
|
||||||
out, seen = [], set()
|
log = logging.getLogger("crudkit.service")
|
||||||
def add(f):
|
|
||||||
if f not in seen:
|
|
||||||
seen.add(f); out.append(f)
|
|
||||||
|
|
||||||
for f in fields:
|
|
||||||
add(f)
|
|
||||||
parts = f.split(".")
|
|
||||||
cur_cls = model_cls
|
|
||||||
prefix = []
|
|
||||||
|
|
||||||
for p in parts[:-1]:
|
|
||||||
rel = getattr(cur_cls.__mapper__.relationships, 'get', lambda _: None)(p)
|
|
||||||
if not rel:
|
|
||||||
cur_cls = None
|
|
||||||
break
|
|
||||||
cur_cls = rel.mapper.class_
|
|
||||||
prefix.append(p)
|
|
||||||
|
|
||||||
if cur_cls is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
leaf = parts[-1]
|
|
||||||
deps = (getattr(cur_cls, "__crudkit_field_requires__", {}) or {}).get(leaf)
|
|
||||||
if not deps:
|
|
||||||
continue
|
|
||||||
|
|
||||||
pre = ".".join(prefix)
|
|
||||||
for dep in deps:
|
|
||||||
add(f"{pre + '.' if pre else ''}{dep}")
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _is_rel(model_cls, name: str) -> bool:
|
|
||||||
try:
|
|
||||||
prop = model_cls.__mapper__.relationships.get(name)
|
|
||||||
return isinstance(prop, RelationshipProperty)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_instrumented_column(attr) -> bool:
|
|
||||||
try:
|
|
||||||
return hasattr(attr, "property") and isinstance(attr.property, ColumnProperty)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
|
|
||||||
"""
|
|
||||||
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
|
|
||||||
and only fetch the related PK. This is enough for preselecting <select> inputs
|
|
||||||
without projecting the FK column on the root model.
|
|
||||||
"""
|
|
||||||
opts: list[Load] = []
|
|
||||||
if not fields:
|
|
||||||
return opts
|
|
||||||
|
|
||||||
mapper = class_mapper(model_cls)
|
|
||||||
for name in fields:
|
|
||||||
prop = mapper.relationships.get(name)
|
|
||||||
if not isinstance(prop, RelationshipProperty):
|
|
||||||
continue
|
|
||||||
if prop.direction.name != "MANYTOONE":
|
|
||||||
continue
|
|
||||||
|
|
||||||
rel_attr = getattr(root_alias, name)
|
|
||||||
target_cls = prop.mapper.class_
|
|
||||||
# load_only PK if present; else just selectinload
|
|
||||||
id_attr = getattr(target_cls, "id", None)
|
|
||||||
if id_attr is not None:
|
|
||||||
opts.append(selectinload(rel_attr).load_only(id_attr))
|
|
||||||
else:
|
|
||||||
opts.append(selectinload(rel_attr))
|
|
||||||
|
|
||||||
return opts
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class _HasID(Protocol):
|
class _HasID(Protocol):
|
||||||
|
|
@ -110,9 +40,85 @@ class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
|
||||||
|
|
||||||
T = TypeVar("T", bound=_CRUDModelProto)
|
T = TypeVar("T", bound=_CRUDModelProto)
|
||||||
|
|
||||||
|
def _hops_from_sort(params: dict | None) -> set[str]:
|
||||||
|
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'."""
|
||||||
|
if not params:
|
||||||
|
return set()
|
||||||
|
raw = params.get("sort")
|
||||||
|
tokens: list[str] = []
|
||||||
|
if isinstance(raw, str):
|
||||||
|
tokens = [t.strip() for t in raw.split(",") if t.strip()]
|
||||||
|
elif isinstance(raw, (list, tuple)):
|
||||||
|
for item in raw:
|
||||||
|
if isinstance(item, str):
|
||||||
|
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||||
|
hops: set[str] = set()
|
||||||
|
for tok in tokens:
|
||||||
|
tok = tok.lstrip("+-")
|
||||||
|
if "." in tok:
|
||||||
|
hops.add(tok.split(".", 1)[0])
|
||||||
|
return hops
|
||||||
|
|
||||||
|
def _belongs_to_alias(col: Any, alias: Any) -> bool:
|
||||||
|
# Try to detect if a column/expression ultimately comes from this alias.
|
||||||
|
# Works for most ORM columns; complex expressions may need more.
|
||||||
|
t = getattr(col, "table", None)
|
||||||
|
selectable = getattr(alias, "selectable", None)
|
||||||
|
return t is not None and selectable is not None and t is selectable
|
||||||
|
|
||||||
|
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
|
||||||
|
hops: set[str] = set()
|
||||||
|
paths: set[tuple[str, ...]] = set()
|
||||||
|
# Sort columns
|
||||||
|
for ob in order_by or []:
|
||||||
|
col = getattr(ob, "element", ob) # unwrap UnaryExpression
|
||||||
|
for _path, rel_attr, target_alias in join_paths:
|
||||||
|
if _belongs_to_alias(col, target_alias):
|
||||||
|
hops.add(rel_attr.key)
|
||||||
|
# Filter columns (best-effort)
|
||||||
|
# Walk simple binary expressions
|
||||||
|
def _extract_cols(expr: Any) -> Iterable[Any]:
|
||||||
|
if isinstance(expr, ColumnElement):
|
||||||
|
yield expr
|
||||||
|
for ch in getattr(expr, "get_children", lambda: [])():
|
||||||
|
yield from _extract_cols(ch)
|
||||||
|
elif hasattr(expr, "clauses"):
|
||||||
|
for ch in expr.clauses:
|
||||||
|
yield from _extract_cols(ch)
|
||||||
|
|
||||||
|
for flt in filters or []:
|
||||||
|
for col in _extract_cols(flt):
|
||||||
|
for _path, rel_attr, target_alias in join_paths:
|
||||||
|
if _belongs_to_alias(col, target_alias):
|
||||||
|
hops.add(rel_attr.key)
|
||||||
|
return hops
|
||||||
|
|
||||||
|
def _paths_from_fields(req_fields: list[str]) -> set[str]:
|
||||||
|
out: set[str] = set()
|
||||||
|
for f in req_fields:
|
||||||
|
if "." in f:
|
||||||
|
parent = f.split(".", 1)[0]
|
||||||
|
if parent:
|
||||||
|
out.add(parent)
|
||||||
|
return out
|
||||||
|
|
||||||
def _is_truthy(val):
|
def _is_truthy(val):
|
||||||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||||||
|
|
||||||
|
def _normalize_fields_param(params: dict | None) -> list[str]:
|
||||||
|
if not params:
|
||||||
|
return []
|
||||||
|
raw = params.get("fields")
|
||||||
|
if isinstance(raw, (list, tuple)):
|
||||||
|
out: list[str] = []
|
||||||
|
for item in raw:
|
||||||
|
if isinstance(item, str):
|
||||||
|
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
|
||||||
|
return out
|
||||||
|
if isinstance(raw, str):
|
||||||
|
return [p for p in (s.strip() for s in raw.split(",")) if p]
|
||||||
|
return []
|
||||||
|
|
||||||
class CRUDService(Generic[T]):
|
class CRUDService(Generic[T]):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -126,8 +132,19 @@ class CRUDService(Generic[T]):
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
self.polymorphic = polymorphic
|
self.polymorphic = polymorphic
|
||||||
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
||||||
# Cache backend info once. If not provided, derive from session bind.
|
|
||||||
bind = self.session.get_bind()
|
# Derive engine WITHOUT leaking a session/connection
|
||||||
|
bind = getattr(session_factory, "bind", None)
|
||||||
|
if bind is None:
|
||||||
|
tmp_sess = session_factory()
|
||||||
|
try:
|
||||||
|
bind = tmp_sess.get_bind()
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
tmp_sess.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
|
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
|
||||||
self.backend = backend or make_backend_info(eng)
|
self.backend = backend or make_backend_info(eng)
|
||||||
|
|
||||||
|
|
@ -141,58 +158,18 @@ class CRUDService(Generic[T]):
|
||||||
return self.session.query(poly), poly
|
return self.session.query(poly), poly
|
||||||
return self.session.query(self.model), self.model
|
return self.session.query(self.model), self.model
|
||||||
|
|
||||||
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
|
def _debug_bind(self, where: str):
|
||||||
"""
|
try:
|
||||||
For each dotted path like ("location"), -> ["label"], look up the target
|
bind = self.session.get_bind()
|
||||||
model's __crudkit_field_requires__ for the terminal field and produce
|
eng = getattr(bind, "engine", bind)
|
||||||
selectinload options prefixed with the relationship path, e.g.:
|
print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}")
|
||||||
Room.__crudkit_field_requires__['label'] = ['room_function']
|
except Exception as e:
|
||||||
=> selectinload(root.location).selectinload(Room.room_function)
|
print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}")
|
||||||
"""
|
|
||||||
opts: List[Any] = []
|
|
||||||
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
|
||||||
|
|
||||||
for path, names in (rel_field_names or {}).items():
|
def _apply_not_deleted(self, query, root_alias, params) -> Any:
|
||||||
if not path:
|
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
|
||||||
continue
|
return query.filter(getattr(root_alias, "is_deleted") == False)
|
||||||
|
return query
|
||||||
current_mapper = root_mapper
|
|
||||||
rel_props: List[RelationshipProperty] = []
|
|
||||||
|
|
||||||
valid = True
|
|
||||||
for step in path:
|
|
||||||
rel = current_mapper.relationships.get(step)
|
|
||||||
if not isinstance(rel, RelationshipProperty):
|
|
||||||
valid = False
|
|
||||||
break
|
|
||||||
rel_props.append(rel)
|
|
||||||
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
|
|
||||||
if not valid or not rel_props:
|
|
||||||
continue
|
|
||||||
|
|
||||||
first = rel_props[0]
|
|
||||||
base_loader = selectinload(getattr(root_alias, first.key))
|
|
||||||
for i in range(1, len(rel_props)):
|
|
||||||
prev_target_cls = rel_props[i - 1].mapper.class_
|
|
||||||
hop_attr = getattr(prev_target_cls, rel_props[i].key)
|
|
||||||
base_loader = base_loader.selectinload(hop_attr)
|
|
||||||
|
|
||||||
target_cls = rel_props[-1].mapper.class_
|
|
||||||
|
|
||||||
requires = getattr(target_cls, "__crudkit_field_requires__", None)
|
|
||||||
if not isinstance(requires, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
for field_name in names:
|
|
||||||
needed: Iterable[str] = requires.get(field_name, []) or []
|
|
||||||
for rel_need in needed:
|
|
||||||
rel_prop2 = target_cls.__mapper__.relationships.get(rel_need)
|
|
||||||
if not isinstance(rel_prop2, RelationshipProperty):
|
|
||||||
continue
|
|
||||||
dep_attr = getattr(target_cls, rel_prop2.key)
|
|
||||||
opts.append(base_loader.selectinload(dep_attr))
|
|
||||||
|
|
||||||
return opts
|
|
||||||
|
|
||||||
def _extract_order_spec(self, root_alias, given_order_by):
|
def _extract_order_spec(self, root_alias, given_order_by):
|
||||||
"""
|
"""
|
||||||
|
|
@ -200,8 +177,6 @@ class CRUDService(Generic[T]):
|
||||||
Normalize order_by into (cols, desc_flags). Supports plain columns and
|
Normalize order_by into (cols, desc_flags). Supports plain columns and
|
||||||
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
|
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
|
||||||
"""
|
"""
|
||||||
from sqlalchemy.sql import operators
|
|
||||||
from sqlalchemy.sql.elements import UnaryExpression
|
|
||||||
|
|
||||||
given = self._stable_order_by(root_alias, given_order_by)
|
given = self._stable_order_by(root_alias, given_order_by)
|
||||||
|
|
||||||
|
|
@ -209,7 +184,7 @@ class CRUDService(Generic[T]):
|
||||||
for ob in given:
|
for ob in given:
|
||||||
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
|
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
|
||||||
elem = getattr(ob, "element", None)
|
elem = getattr(ob, "element", None)
|
||||||
col = elem if elem is not None else ob # don't use "or" with SA expressions
|
col = elem if elem is not None else ob
|
||||||
|
|
||||||
# Detect direction in SA 2.x
|
# Detect direction in SA 2.x
|
||||||
is_desc = False
|
is_desc = False
|
||||||
|
|
@ -223,31 +198,33 @@ class CRUDService(Generic[T]):
|
||||||
cols.append(col)
|
cols.append(col)
|
||||||
desc_flags.append(bool(is_desc))
|
desc_flags.append(bool(is_desc))
|
||||||
|
|
||||||
from crudkit.core.types import OrderSpec
|
|
||||||
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
||||||
|
|
||||||
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
|
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
|
||||||
"""
|
|
||||||
Build lexicographic predicate for keyset seek.
|
|
||||||
For backward traversal, import comparisons.
|
|
||||||
"""
|
|
||||||
if not key_vals:
|
if not key_vals:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
conds = []
|
conds = []
|
||||||
for i, col in enumerate(spec.cols):
|
for i, col in enumerate(spec.cols):
|
||||||
|
# If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
|
||||||
|
# sent_col = func.coalesce(col, literal("-∞"))
|
||||||
|
sent_col = col
|
||||||
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
|
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
|
||||||
is_desc = spec.desc[i]
|
is_desc = spec.desc[i]
|
||||||
if not backward:
|
if not backward:
|
||||||
op = col < key_vals[i] if is_desc else col > key_vals[i]
|
op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i])
|
||||||
else:
|
else:
|
||||||
op = col > key_vals[i] if is_desc else col < key_vals[i]
|
op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i])
|
||||||
conds.append(and_(*ties, op))
|
conds.append(and_(*ties, op))
|
||||||
return or_(*conds)
|
return or_(*conds)
|
||||||
|
|
||||||
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
|
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
|
||||||
out = []
|
out = []
|
||||||
for c in spec.cols:
|
for c in spec.cols:
|
||||||
|
# Only simple mapped columns supported for key pluck
|
||||||
key = getattr(c, "key", None) or getattr(c, "name", None)
|
key = getattr(c, "key", None) or getattr(c, "name", None)
|
||||||
|
if key is None or not hasattr(obj, key):
|
||||||
|
raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
|
||||||
out.append(getattr(obj, key))
|
out.append(getattr(obj, key))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
@ -266,62 +243,86 @@ class CRUDService(Generic[T]):
|
||||||
- forward/backward seek via `key` and `backward`
|
- forward/backward seek via `key` and `backward`
|
||||||
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
||||||
"""
|
"""
|
||||||
fields = list(params.get("fields", []))
|
self._debug_bind("seek_window")
|
||||||
if fields:
|
session = self.session
|
||||||
fields = _expand_requires(self.model, fields)
|
|
||||||
params = {**params, "fields": fields}
|
|
||||||
query, root_alias = self.get_query()
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
|
# Normalize requested fields and compile projection (may skip later to avoid conflicts)
|
||||||
|
fields = _normalize_fields_param(params)
|
||||||
|
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
|
||||||
|
|
||||||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||||
|
|
||||||
filters = spec.parse_filters()
|
filters = spec.parse_filters()
|
||||||
order_by = spec.parse_sort()
|
order_by = spec.parse_sort()
|
||||||
|
|
||||||
|
# Field parsing for root load_only fallback
|
||||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||||
|
|
||||||
seen_rel_roots = set()
|
|
||||||
for path, names in (rel_field_names or {}).items():
|
|
||||||
if not path:
|
|
||||||
continue
|
|
||||||
rel_name = path[0]
|
|
||||||
if rel_name in seen_rel_roots:
|
|
||||||
continue
|
|
||||||
if _is_rel(self.model, rel_name):
|
|
||||||
rel_attr = getattr(root_alias, rel_name, None)
|
|
||||||
if rel_attr is not None:
|
|
||||||
query = query.options(selectinload(rel_attr))
|
|
||||||
seen_rel_roots.add(rel_name)
|
|
||||||
|
|
||||||
# Soft delete filter
|
# Soft delete filter
|
||||||
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
query = self._apply_not_deleted(query, root_alias, params)
|
||||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
|
||||||
|
|
||||||
# Parse filters first
|
# Apply filters first
|
||||||
if filters:
|
if filters:
|
||||||
query = query.filter(*filters)
|
query = query.filter(*filters)
|
||||||
|
|
||||||
# Includes + joins (so relationship fields like brand.name, location.label work)
|
# Includes + join paths (dotted fields etc.)
|
||||||
spec.parse_includes()
|
spec.parse_includes()
|
||||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
|
||||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
|
||||||
target = cast(Any, target_alias)
|
|
||||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
|
||||||
|
|
||||||
# Fields/projection: load_only for root columns, eager loads for relationships
|
# Relationship names required by ORDER BY / WHERE
|
||||||
|
sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
|
||||||
|
# Also include relationships mentioned directly in the sort spec
|
||||||
|
sql_hops |= _hops_from_sort(params)
|
||||||
|
|
||||||
|
# First-hop relationship names implied by dotted projection fields
|
||||||
|
proj_hops: set[str] = _paths_from_fields(fields)
|
||||||
|
|
||||||
|
# Root column projection
|
||||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||||
if only_cols:
|
if only_cols:
|
||||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
|
||||||
# query = query.options(eager)
|
|
||||||
|
|
||||||
for opt in self._resolve_required_includes(root_alias, rel_field_names):
|
# Relationship handling per path (avoid loader strategy conflicts)
|
||||||
query = query.options(opt)
|
used_contains_eager = False
|
||||||
|
joined_names: set[str] = set()
|
||||||
|
|
||||||
|
for _path, relationship_attr, target_alias in join_paths:
|
||||||
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
|
name = relationship_attr.key
|
||||||
|
if name in sql_hops:
|
||||||
|
# Needed for WHERE/ORDER BY: join + hydrate from that join
|
||||||
|
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
|
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||||
|
used_contains_eager = True
|
||||||
|
joined_names.add(name)
|
||||||
|
elif name in proj_hops:
|
||||||
|
# Display-only: bulk-load efficiently, no join
|
||||||
|
query = query.options(selectinload(rel_attr))
|
||||||
|
joined_names.add(name)
|
||||||
|
|
||||||
|
# Force-join any SQL-needed relationships that weren't in join_paths
|
||||||
|
missing_sql = sql_hops - joined_names
|
||||||
|
for name in missing_sql:
|
||||||
|
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||||
|
query = query.join(rel_attr, isouter=True)
|
||||||
|
query = query.options(contains_eager(rel_attr))
|
||||||
|
used_contains_eager = True
|
||||||
|
joined_names.add(name)
|
||||||
|
|
||||||
|
# Apply projection loader options only if they won't conflict with contains_eager
|
||||||
|
if proj_opts and not used_contains_eager:
|
||||||
|
query = query.options(*proj_opts)
|
||||||
|
|
||||||
# Order + limit
|
# Order + limit
|
||||||
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
|
order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper
|
||||||
limit, _ = spec.parse_pagination()
|
limit, _ = spec.parse_pagination()
|
||||||
if not limit or limit <= 0:
|
if limit is None:
|
||||||
limit = 50 # sensible default
|
effective_limit = 50
|
||||||
|
elif limit == 0:
|
||||||
|
effective_limit = None # unlimited
|
||||||
|
else:
|
||||||
|
effective_limit = limit
|
||||||
|
|
||||||
# Keyset predicate
|
# Keyset predicate
|
||||||
if key:
|
if key:
|
||||||
|
|
@ -331,30 +332,36 @@ class CRUDService(Generic[T]):
|
||||||
|
|
||||||
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
|
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
|
||||||
if not backward:
|
if not backward:
|
||||||
clauses = []
|
clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
|
||||||
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
query = query.order_by(*clauses)
|
||||||
clauses.append(col.desc() if is_desc else col.asc())
|
if effective_limit is not None:
|
||||||
query = query.order_by(*clauses).limit(limit)
|
query = query.limit(effective_limit)
|
||||||
items = query.all()
|
items = query.all()
|
||||||
else:
|
else:
|
||||||
inv_clauses = []
|
inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
|
||||||
for col, is_desc in zip(order_spec.cols, order_spec.desc):
|
query = query.order_by(*inv_clauses)
|
||||||
inv_clauses.append(col.asc() if is_desc else col.desc())
|
if effective_limit is not None:
|
||||||
query = query.order_by(*inv_clauses).limit(limit)
|
query = query.limit(effective_limit)
|
||||||
items = list(reversed(query.all()))
|
items = list(reversed(query.all()))
|
||||||
|
|
||||||
# Tag projection so your renderer knows what fields were requested
|
# Tag projection so your renderer knows what fields were requested
|
||||||
|
if fields:
|
||||||
|
proj = list(dict.fromkeys(fields)) # dedupe, preserve order
|
||||||
|
if "id" not in proj and hasattr(self.model, "id"):
|
||||||
|
proj.insert(0, "id")
|
||||||
|
else:
|
||||||
proj = []
|
proj = []
|
||||||
if root_field_names:
|
if root_field_names:
|
||||||
proj.extend(root_field_names)
|
proj.extend(root_field_names)
|
||||||
if root_fields:
|
if root_fields:
|
||||||
proj.extend(c.key for c in root_fields)
|
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||||||
for path, names in (rel_field_names or {}).items():
|
for path, names in (rel_field_names or {}).items():
|
||||||
prefix = ".".join(path)
|
prefix = ".".join(path)
|
||||||
for n in names:
|
for n in names:
|
||||||
proj.append(f"{prefix}.{n}")
|
proj.append(f"{prefix}.{n}")
|
||||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||||
proj.insert(0, "id")
|
proj.insert(0, "id")
|
||||||
|
|
||||||
if proj:
|
if proj:
|
||||||
for obj in items:
|
for obj in items:
|
||||||
try:
|
try:
|
||||||
|
|
@ -363,28 +370,73 @@ class CRUDService(Generic[T]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Boundary keys for cursor encoding in the API layer
|
# Boundary keys for cursor encoding in the API layer
|
||||||
first_key = self._pluck_key(items[0], order_spec) if items else None
|
# When ORDER BY includes related columns (e.g., owner.first_name),
|
||||||
last_key = self._pluck_key(items[-1], order_spec) if items else None
|
# pluck values from the related object we hydrated with contains_eager/selectinload.
|
||||||
|
def _pluck_key_from_obj(obj: Any) -> list[Any]:
|
||||||
|
vals: list[Any] = []
|
||||||
|
# Build a quick map: selectable -> relationship name
|
||||||
|
alias_to_rel: dict[Any, str] = {}
|
||||||
|
for _p, relationship_attr, target_alias in join_paths:
|
||||||
|
sel = getattr(target_alias, "selectable", None)
|
||||||
|
if sel is not None:
|
||||||
|
alias_to_rel[sel] = relationship_attr.key
|
||||||
|
|
||||||
|
for col in order_spec.cols:
|
||||||
|
key = getattr(col, "key", None) or getattr(col, "name", None)
|
||||||
|
# Try root attribute first
|
||||||
|
if key and hasattr(obj, key):
|
||||||
|
vals.append(getattr(obj, key))
|
||||||
|
continue
|
||||||
|
# Try relationship hop by matching the column's table/selectable
|
||||||
|
table = getattr(col, "table", None)
|
||||||
|
relname = alias_to_rel.get(table)
|
||||||
|
if relname and key:
|
||||||
|
relobj = getattr(obj, relname, None)
|
||||||
|
if relobj is not None and hasattr(relobj, key):
|
||||||
|
vals.append(getattr(relobj, key))
|
||||||
|
continue
|
||||||
|
# Give up: unsupported expression for cursor purposes
|
||||||
|
raise ValueError("unpluckable")
|
||||||
|
return vals
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_key = _pluck_key_from_obj(items[0]) if items else None
|
||||||
|
last_key = _pluck_key_from_obj(items[-1]) if items else None
|
||||||
|
except Exception:
|
||||||
|
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
|
||||||
|
# disable cursors for this response rather than exploding.
|
||||||
|
first_key = None
|
||||||
|
last_key = None
|
||||||
|
|
||||||
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
||||||
total = None
|
total = None
|
||||||
if include_total:
|
if include_total:
|
||||||
base = self.session.query(getattr(root_alias, "id"))
|
base = session.query(getattr(root_alias, "id"))
|
||||||
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
|
base = self._apply_not_deleted(base, root_alias, params)
|
||||||
base = base.filter(getattr(root_alias, "is_deleted") == False)
|
|
||||||
if filters:
|
if filters:
|
||||||
base = base.filter(*filters)
|
base = base.filter(*filters)
|
||||||
# replicate the same joins used above
|
# Mirror join structure for any SQL-needed relationships
|
||||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
for _path, relationship_attr, target_alias in join_paths:
|
||||||
|
if relationship_attr.key in sql_hops:
|
||||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
target = cast(Any, target_alias)
|
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
base = base.join(target, rel_attr.of_type(target), isouter=True)
|
# Also mirror any forced joins
|
||||||
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
|
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
|
||||||
|
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||||
|
base = base.join(rel_attr, isouter=True)
|
||||||
|
|
||||||
|
total = session.query(func.count()).select_from(
|
||||||
|
base.order_by(None).distinct().subquery()
|
||||||
|
).scalar() or 0
|
||||||
|
|
||||||
|
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50)
|
||||||
|
|
||||||
|
if log.isEnabledFor(logging.DEBUG):
|
||||||
|
log.debug("QUERY: %s", str(query))
|
||||||
|
|
||||||
from crudkit.core.types import SeekWindow # avoid circulars at module top
|
|
||||||
return SeekWindow(
|
return SeekWindow(
|
||||||
items=items,
|
items=items,
|
||||||
limit=limit,
|
limit=window_limit_for_body,
|
||||||
first_key=first_key,
|
first_key=first_key,
|
||||||
last_key=last_key,
|
last_key=last_key,
|
||||||
order=order_spec,
|
order=order_spec,
|
||||||
|
|
@ -416,148 +468,173 @@ class CRUDService(Generic[T]):
|
||||||
for col in mapper.primary_key:
|
for col in mapper.primary_key:
|
||||||
try:
|
try:
|
||||||
pk_cols.append(getattr(root_alias, col.key))
|
pk_cols.append(getattr(root_alias, col.key))
|
||||||
except ArithmeticError:
|
except AttributeError:
|
||||||
pk_cols.append(col)
|
pk_cols.append(col)
|
||||||
|
|
||||||
return [*order_by, *pk_cols]
|
return [*order_by, *pk_cols]
|
||||||
|
|
||||||
def get(self, id: int, params=None) -> T | None:
|
def get(self, id: int, params=None) -> T | None:
|
||||||
|
"""Fetch a single row by id with conflict-free eager loading and clean projection."""
|
||||||
|
self._debug_bind("get")
|
||||||
query, root_alias = self.get_query()
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
include_deleted = False
|
# Defaults so we can build a projection even if params is None
|
||||||
root_fields = []
|
root_fields: list[Any] = []
|
||||||
root_field_names = {}
|
root_field_names: dict[str, str] = {}
|
||||||
rel_field_names = {}
|
rel_field_names: dict[tuple[str, ...], list[str]] = {}
|
||||||
|
req_fields: list[str] = _normalize_fields_param(params)
|
||||||
|
|
||||||
|
# Soft-delete guard
|
||||||
|
query = self._apply_not_deleted(query, root_alias, params)
|
||||||
|
|
||||||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||||
if params:
|
|
||||||
if self.supports_soft_delete:
|
# Optional extra filters (in addition to id); keep parity with list()
|
||||||
include_deleted = _is_truthy(params.get('include_deleted'))
|
filters = spec.parse_filters()
|
||||||
if self.supports_soft_delete and not include_deleted:
|
if filters:
|
||||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
query = query.filter(*filters)
|
||||||
|
|
||||||
|
# Always filter by id
|
||||||
query = query.filter(getattr(root_alias, "id") == id)
|
query = query.filter(getattr(root_alias, "id") == id)
|
||||||
|
|
||||||
|
# Includes + join paths we may need
|
||||||
spec.parse_includes()
|
spec.parse_includes()
|
||||||
|
join_paths = tuple(spec.get_join_paths())
|
||||||
|
|
||||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
# Field parsing to enable root load_only
|
||||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
|
||||||
target = cast(Any, target_alias)
|
|
||||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||||
|
|
||||||
if rel_field_names:
|
# Decide which relationship paths are needed for SQL vs display-only
|
||||||
seen_rel_roots = set()
|
# For get(), there is no ORDER BY; only filters might force SQL use.
|
||||||
for path, names in rel_field_names.items():
|
sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
|
||||||
if not path:
|
proj_hops = _paths_from_fields(req_fields)
|
||||||
continue
|
|
||||||
rel_name = path[0]
|
|
||||||
if rel_name in seen_rel_roots:
|
|
||||||
continue
|
|
||||||
if _is_rel(self.model, rel_name):
|
|
||||||
rel_attr = getattr(root_alias, rel_name, None)
|
|
||||||
if rel_attr is not None:
|
|
||||||
query = query.options(selectinload(rel_attr))
|
|
||||||
seen_rel_roots.add(rel_name)
|
|
||||||
|
|
||||||
fields = (params or {}).get("fields") if isinstance(params, dict) else None
|
|
||||||
if fields:
|
|
||||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
|
||||||
query = query.options(opt)
|
|
||||||
|
|
||||||
|
# Root column projection
|
||||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||||
if only_cols:
|
if only_cols:
|
||||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
|
|
||||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
# Relationship handling per path: avoid loader strategy conflicts
|
||||||
# query = query.options(eager)
|
used_contains_eager = False
|
||||||
|
for _path, relationship_attr, target_alias in join_paths:
|
||||||
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
|
name = relationship_attr.key
|
||||||
|
if name in sql_hops:
|
||||||
|
# Needed in WHERE: join + hydrate from the join
|
||||||
|
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
|
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||||
|
used_contains_eager = True
|
||||||
|
elif name in proj_hops:
|
||||||
|
# Display-only: bulk-load efficiently
|
||||||
|
query = query.options(selectinload(rel_attr))
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
if params:
|
# Projection loader options compiled from requested fields.
|
||||||
fields = params.get("fields") or []
|
# Skip if we used contains_eager to avoid strategy conflicts.
|
||||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||||||
query = query.options(opt)
|
if proj_opts and not used_contains_eager:
|
||||||
|
query = query.options(*proj_opts)
|
||||||
|
|
||||||
obj = query.first()
|
obj = query.first()
|
||||||
|
|
||||||
|
# Emit exactly what the client requested (plus id), or a reasonable fallback
|
||||||
|
if req_fields:
|
||||||
|
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order
|
||||||
|
if "id" not in proj and hasattr(self.model, "id"):
|
||||||
|
proj.insert(0, "id")
|
||||||
|
else:
|
||||||
proj = []
|
proj = []
|
||||||
if root_field_names:
|
if root_field_names:
|
||||||
proj.extend(root_field_names)
|
proj.extend(root_field_names)
|
||||||
if root_fields:
|
if root_fields:
|
||||||
proj.extend(c.key for c in root_fields)
|
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||||||
for path, names in (rel_field_names or {}).items():
|
for path, names in (rel_field_names or {}).items():
|
||||||
prefix = ".".join(path)
|
prefix = ".".join(path)
|
||||||
for n in names:
|
for n in names:
|
||||||
proj.append(f"{prefix}.{n}")
|
proj.append(f"{prefix}.{n}")
|
||||||
|
|
||||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||||
proj.insert(0, "id")
|
proj.insert(0, "id")
|
||||||
|
|
||||||
if proj:
|
if proj and obj is not None:
|
||||||
try:
|
try:
|
||||||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if log.isEnabledFor(logging.DEBUG):
|
||||||
|
log.debug("QUERY: %s", str(query))
|
||||||
|
|
||||||
return obj or None
|
return obj or None
|
||||||
|
|
||||||
def list(self, params=None) -> list[T]:
|
def list(self, params=None) -> list[T]:
|
||||||
|
"""Offset/limit listing with smart relationship loading and clean projection."""
|
||||||
|
self._debug_bind("list")
|
||||||
query, root_alias = self.get_query()
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
root_fields = []
|
# Defaults so we can reference them later even if params is None
|
||||||
root_field_names = {}
|
root_fields: list[Any] = []
|
||||||
rel_field_names = {}
|
root_field_names: dict[str, str] = {}
|
||||||
|
rel_field_names: dict[tuple[str, ...], list[str]] = {}
|
||||||
|
req_fields: list[str] = _normalize_fields_param(params)
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
if self.supports_soft_delete:
|
query = self._apply_not_deleted(query, root_alias, params)
|
||||||
include_deleted = _is_truthy(params.get('include_deleted'))
|
|
||||||
if not include_deleted:
|
|
||||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
|
||||||
|
|
||||||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||||
filters = spec.parse_filters()
|
filters = spec.parse_filters()
|
||||||
order_by = spec.parse_sort()
|
order_by = spec.parse_sort()
|
||||||
limit, offset = spec.parse_pagination()
|
limit, offset = spec.parse_pagination()
|
||||||
|
|
||||||
|
# Includes + join paths we might need
|
||||||
spec.parse_includes()
|
spec.parse_includes()
|
||||||
|
join_paths = tuple(spec.get_join_paths())
|
||||||
|
|
||||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
# Field parsing for load_only on root columns
|
||||||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
|
||||||
target = cast(Any, target_alias)
|
|
||||||
query = query.join(target, rel_attr.of_type(target), isouter=True)
|
|
||||||
|
|
||||||
if params:
|
|
||||||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||||||
|
|
||||||
if rel_field_names:
|
|
||||||
seen_rel_roots = set()
|
|
||||||
for path, names in rel_field_names.items():
|
|
||||||
if not path:
|
|
||||||
continue
|
|
||||||
rel_name = path[0]
|
|
||||||
if rel_name in seen_rel_roots:
|
|
||||||
continue
|
|
||||||
if _is_rel(self.model, rel_name):
|
|
||||||
rel_attr = getattr(root_alias, rel_name, None)
|
|
||||||
if rel_attr is not None:
|
|
||||||
query = query.options(selectinload(rel_attr))
|
|
||||||
seen_rel_roots.add(rel_name)
|
|
||||||
|
|
||||||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
|
||||||
if only_cols:
|
|
||||||
query = query.options(Load(root_alias).load_only(*only_cols))
|
|
||||||
|
|
||||||
# for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
|
||||||
# query = query.options(eager)
|
|
||||||
|
|
||||||
if params:
|
|
||||||
fields = params.get("fields") or []
|
|
||||||
for opt in _loader_options_for_fields(root_alias, self.model, fields):
|
|
||||||
query = query.options(opt)
|
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
query = query.filter(*filters)
|
query = query.filter(*filters)
|
||||||
|
|
||||||
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
|
# Determine which relationship paths are needed for SQL vs display-only
|
||||||
|
sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
|
||||||
|
sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
|
||||||
|
proj_hops = _paths_from_fields(req_fields)
|
||||||
|
|
||||||
|
# Root column projection
|
||||||
|
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||||||
|
if only_cols:
|
||||||
|
query = query.options(Load(root_alias).load_only(*only_cols))
|
||||||
|
|
||||||
|
# Relationship handling per path
|
||||||
|
used_contains_eager = False
|
||||||
|
joined_names: set[str] = set()
|
||||||
|
|
||||||
|
for _path, relationship_attr, target_alias in join_paths:
|
||||||
|
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||||||
|
name = relationship_attr.key
|
||||||
|
if name in sql_hops:
|
||||||
|
# Needed for WHERE/ORDER BY: join + hydrate from the join
|
||||||
|
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||||||
|
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||||||
|
used_contains_eager = True
|
||||||
|
joined_names.add(name)
|
||||||
|
elif name in proj_hops:
|
||||||
|
# Display-only: no join, bulk-load efficiently
|
||||||
|
query = query.options(selectinload(rel_attr))
|
||||||
|
joined_names.add(name)
|
||||||
|
|
||||||
|
# Force-join any SQL-needed relationships that weren't in join_paths
|
||||||
|
missing_sql = sql_hops - joined_names
|
||||||
|
for name in missing_sql:
|
||||||
|
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||||||
|
query = query.join(rel_attr, isouter=True)
|
||||||
|
query = query.options(contains_eager(rel_attr))
|
||||||
|
used_contains_eager = True
|
||||||
|
joined_names.add(name)
|
||||||
|
|
||||||
|
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
|
||||||
paginating = (limit is not None) or (offset is not None and offset != 0)
|
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||||||
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||||||
order_by = self._default_order_by(root_alias)
|
order_by = self._default_order_by(root_alias)
|
||||||
|
|
@ -565,24 +642,41 @@ class CRUDService(Generic[T]):
|
||||||
if order_by:
|
if order_by:
|
||||||
query = query.order_by(*order_by)
|
query = query.order_by(*order_by)
|
||||||
|
|
||||||
# Only apply offset/limit when not None.
|
# Only apply offset/limit when not None and not zero
|
||||||
if offset is not None and offset != 0:
|
if offset is not None and offset != 0:
|
||||||
query = query.offset(offset)
|
query = query.offset(offset)
|
||||||
if limit is not None and limit > 0:
|
if limit is not None and limit > 0:
|
||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
|
|
||||||
|
# Projection loader options compiled from requested fields.
|
||||||
|
# Skip if we used contains_eager to avoid loader-strategy conflicts.
|
||||||
|
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||||||
|
if proj_opts and not used_contains_eager:
|
||||||
|
query = query.options(*proj_opts)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# No params means no filters/sorts/limits; still honor projection loaders if any
|
||||||
|
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||||||
|
if proj_opts:
|
||||||
|
query = query.options(*proj_opts)
|
||||||
|
|
||||||
rows = query.all()
|
rows = query.all()
|
||||||
|
|
||||||
|
# Emit exactly what the client requested (plus id), or a reasonable fallback
|
||||||
|
if req_fields:
|
||||||
|
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order
|
||||||
|
if "id" not in proj and hasattr(self.model, "id"):
|
||||||
|
proj.insert(0, "id")
|
||||||
|
else:
|
||||||
proj = []
|
proj = []
|
||||||
if root_field_names:
|
if root_field_names:
|
||||||
proj.extend(root_field_names)
|
proj.extend(root_field_names)
|
||||||
if root_fields:
|
if root_fields:
|
||||||
proj.extend(c.key for c in root_fields)
|
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||||||
for path, names in (rel_field_names or {}).items():
|
for path, names in (rel_field_names or {}).items():
|
||||||
prefix = ".".join(path)
|
prefix = ".".join(path)
|
||||||
for n in names:
|
for n in names:
|
||||||
proj.append(f"{prefix}.{n}")
|
proj.append(f"{prefix}.{n}")
|
||||||
|
|
||||||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||||||
proj.insert(0, "id")
|
proj.insert(0, "id")
|
||||||
|
|
||||||
|
|
@ -593,41 +687,52 @@ class CRUDService(Generic[T]):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if log.isEnabledFor(logging.DEBUG):
|
||||||
|
log.debug("QUERY: %s", str(query))
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
|
||||||
def create(self, data: dict, actor=None) -> T:
|
def create(self, data: dict, actor=None) -> T:
|
||||||
|
session = self.session
|
||||||
obj = self.model(**data)
|
obj = self.model(**data)
|
||||||
self.session.add(obj)
|
session.add(obj)
|
||||||
self.session.commit()
|
session.commit()
|
||||||
self._log_version("create", obj, actor)
|
self._log_version("create", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def update(self, id: int, data: dict, actor=None) -> T:
|
def update(self, id: int, data: dict, actor=None) -> T:
|
||||||
obj = self.get(id)
|
session = self.session
|
||||||
|
obj = session.get(self.model, id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||||
|
unknown = set(data) - valid_fields
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if k in valid_fields:
|
if k in valid_fields:
|
||||||
setattr(obj, k, v)
|
setattr(obj, k, v)
|
||||||
self.session.commit()
|
session.commit()
|
||||||
self._log_version("update", obj, actor)
|
self._log_version("update", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def delete(self, id: int, hard: bool = False, actor = False):
|
def delete(self, id: int, hard: bool = False, actor = None):
|
||||||
obj = self.session.get(self.model, id)
|
session = self.session
|
||||||
|
obj = session.get(self.model, id)
|
||||||
if not obj:
|
if not obj:
|
||||||
return None
|
return None
|
||||||
if hard or not self.supports_soft_delete:
|
if hard or not self.supports_soft_delete:
|
||||||
self.session.delete(obj)
|
session.delete(obj)
|
||||||
else:
|
else:
|
||||||
soft = cast(_SoftDeletable, obj)
|
soft = cast(_SoftDeletable, obj)
|
||||||
soft.is_deleted = True
|
soft.is_deleted = True
|
||||||
self.session.commit()
|
session.commit()
|
||||||
self._log_version("delete", obj, actor)
|
self._log_version("delete", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}):
|
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
|
||||||
|
session = self.session
|
||||||
try:
|
try:
|
||||||
data = obj.as_dict()
|
data = obj.as_dict()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -640,5 +745,5 @@ class CRUDService(Generic[T]):
|
||||||
actor=str(actor) if actor else None,
|
actor=str(actor) if actor else None,
|
||||||
meta=metadata
|
meta=metadata
|
||||||
)
|
)
|
||||||
self.session.add(version)
|
session.add(version)
|
||||||
self.session.commit()
|
session.commit()
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,58 @@
|
||||||
|
# crudkit/integrations/flask.py
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from sqlalchemy.orm import scoped_session
|
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||||
from ..engines import CRUDKitRuntime
|
from ..engines import CRUDKitRuntime
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
|
||||||
def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None = None):
    """
    Initializes CRUDKit for a Flask app. Provides `app.extensions['crudkit']`
    with a runtime (engine + session_factory). Caller manages session lifecycle.

    Args:
        app: The Flask application to attach CRUDKit to.
        runtime: An existing runtime to reuse; when None a new one is built
            from *config*.
        config: Config class used only when *runtime* is None.

    Returns:
        The CRUDKitRuntime stored in ``app.extensions['crudkit']['runtime']``.
    """
    runtime = runtime or CRUDKitRuntime(config=config)
    app.extensions.setdefault("crudkit", {})
    app.extensions["crudkit"]["runtime"] = runtime

    # Build ONE sessionmaker bound to the ONE true engine object
    # so engine id == sessionmaker.bind id, always.
    engine = runtime.engine
    # NOTE(review): reaches into runtime._config (private) for session kwargs —
    # confirm there is no public accessor before changing.
    SessionFactory = runtime.session_factory or sessionmaker(bind=engine, **runtime._config.session_kwargs())
    app.extensions["crudkit"]["SessionFactory"] = SessionFactory
    app.extensions["crudkit"]["Session"] = scoped_session(SessionFactory)

    # Attach pool listeners to the *same* engine the SessionFactory is bound to.
    # Don’t guess. Don’t hope. Inspect.
    try:
        bound_engine = getattr(SessionFactory, "bind", None) or getattr(SessionFactory, "kw", {}).get("bind") or engine
        pool = bound_engine.pool

        from sqlalchemy import event

        def _report(tag: str) -> None:
            # Shared status dump for both listeners (previously duplicated
            # verbatim in each). Output is byte-identical to the old code.
            sz = pool.size()
            chk = pool.checkedout()
            try:
                # Some pool implementations may not expose checkedin().
                conns_in_pool = pool.checkedin()
            except Exception:
                conns_in_pool = "?"
            print(f"POOL {tag}: Pool size: {sz} Connections in pool: {conns_in_pool} "
                  f"Current Overflow: {pool.overflow()} Current Checked out connections: {chk} "
                  f"engine id= {id(bound_engine)}")

        @event.listens_for(pool, "checkout")
        def _on_checkout(dbapi_conn, conn_record, conn_proxy):
            _report("CHECKOUT")

        @event.listens_for(pool, "checkin")
        def _on_checkin(dbapi_conn, conn_record):
            _report("CHECKIN")
    except Exception as e:
        # Listener attachment is diagnostic-only; never let it break app init.
        print(f"[crudkit.init_app] Failed to attach pool listeners: {e}")

    return runtime
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue