Compare commits

..

No commits in common. "ce7d092be44ff1e9f391a235d328d461cf4555bc" and "15ae0caf2796fcf6c3fad183dace00772b973796" have entirely different histories.

2 changed files with 66 additions and 142 deletions

View file

@ -2,45 +2,29 @@
from __future__ import annotations
from flask import Blueprint, jsonify, request, abort, current_app, url_for
from hashlib import md5
from flask import Blueprint, jsonify, request, abort
from urllib.parse import urlencode
from werkzeug.exceptions import HTTPException
from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy
MAX_JSON = 1_000_000
def _etag_for(obj) -> str:
v = getattr(obj, "updated_at", None) or obj.id
return md5(str(v).encode()).hexdigest()
def _json_payload() -> dict:
if request.content_length and request.content_length > MAX_JSON:
abort(413)
if not request.is_json:
abort(415)
payload = request.get_json(silent=False)
if not isinstance(payload, dict):
abort(400)
return payload
def _args_flat() -> dict[str, str]:
return request.args.to_dict(flat=True) # type: ignore[arg-type]
def _json_error(e: Exception, status: int = 400):
if isinstance(e, HTTPException):
status = e.code or status
msg = e.description
else:
msg = str(e)
if current_app.debug:
return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status
return jsonify({"status": "error", "error": msg}), status
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
return _is_truthy(d.get(key, "1" if default else "0"))
def _safe_int(value: str | None, default: int) -> int:
try:
return int(value) if value is not None else default
except Exception:
return default
def _link_with_params(base_url: str, **params) -> str:
q = {k: v for k, v in params.items() if v is not None}
return f"{base_url}?{urlencode(q)}"
def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True):
"""
REST:
@ -66,75 +50,65 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r
bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}")
@bp.errorhandler(Exception)
def _handle_any(e: Exception):
return _json_error(e)
@bp.errorhandler(404)
def _not_found(_e):
return jsonify({"status": "error", "error": "not found"}), 404
# ---------- REST ----------
if rest:
@bp.get("/")
def rest_list():
args = _args_flat()
args = request.args.to_dict(flat=True)
# support cursor pagination transparently; fall back to limit/offset
try:
items = service.list(args)
return jsonify([o.as_dict() for o in items])
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.get("/<int:obj_id>")
def rest_get(obj_id: int):
try:
item = service.get(obj_id, request.args)
if item is None:
abort(404)
etag = _etag_for(item)
if request.if_none_match and (etag in request.if_none_match):
return "", 304
resp = jsonify(item.as_dict())
resp.set_etag(etag)
return resp
return jsonify(item.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)}), 400
@bp.post("/")
def rest_create():
payload = _json_payload()
payload = request.get_json(silent=True) or {}
try:
obj = service.create(payload)
resp = jsonify(obj.as_dict())
resp.status_code = 201
resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False)
resp.headers["Location"] = f"{request.base_url.rstrip('/')}/{obj.id}"
return resp
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.patch("/<int:obj_id>")
def rest_update(obj_id: int):
payload = _json_payload()
payload = request.get_json(silent=True) or {}
try:
obj = service.update(obj_id, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.delete("/<int:obj_id>")
def rest_delete(obj_id: int):
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
hard = (request.args.get("hard") in ("1", "true", "yes"))
try:
obj = service.delete(obj_id, hard=hard)
if obj is None:
abort(404)
return ("", 204)
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
# ---------- RPC (your existing routes) ----------
if rpc:
# your original functions verbatim, shortened here for sanity
@bp.get("/get")
def rpc_get():
print("⚠️ WARNING: Deprecated RPC call used: /get")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
@ -144,52 +118,48 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r
abort(404)
return jsonify(item.as_dict())
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.get("/list")
def rpc_list():
print("⚠️ WARNING: Deprecated RPC call used: /list")
args = _args_flat()
args = request.args.to_dict(flat=True)
try:
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.post("/create")
def rpc_create():
print("⚠️ WARNING: Deprecated RPC call used: /create")
payload = _json_payload()
payload = request.get_json(silent=True) or {}
try:
obj = service.create(payload)
return jsonify(obj.as_dict()), 201
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.patch("/update")
def rpc_update():
print("⚠️ WARNING: Deprecated RPC call used: /update")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
payload = _json_payload()
payload = request.get_json(silent=True) or {}
try:
obj = service.update(id_, payload)
return jsonify(obj.as_dict())
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
@bp.delete("/delete")
def rpc_delete():
print("⚠️ WARNING: Deprecated RPC call used: /delete")
id_ = int(request.args.get("id", 0))
if not id_:
return jsonify({"status": "error", "error": "missing required param: id"}), 400
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
hard = (request.args.get("hard") in ("1", "true", "yes"))
try:
obj = service.delete(id_, hard=hard)
return ("", 204) if obj is not None else abort(404)
except Exception as e:
return _json_error(e)
return jsonify({"status": "error", "error": str(e)}), 400
return bp

View file

@ -1,66 +1,18 @@
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
from typing import Any, Dict, Iterable, List, Tuple, Set
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
from sqlalchemy.orm.state import InstanceState
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty
Base = declarative_base()
@lru_cache(maxsize=512)
def _column_names_for_model(cls: type) -> tuple[str, ...]:
try:
mapper = inspect(cls)
return tuple(prop.key for prop in mapper.column_attrs)
except Exception:
names: list[str] = []
for c in cls.__mro__:
if hasattr(c, "__table__"):
names.extend(col.name for col in c.__table__.columns)
return tuple(dict.fromkeys(names))
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
"""Safely get SQLAlchemy InstanceState (or None)."""
try:
st = inspect(obj)
return cast(Optional[InstanceState[Any]], st)
except Exception:
return None
def _sa_mapper(obj: Any) -> Optional[Mapper]:
"""Safely get Mapper for a maooed instance (or None)."""
try:
st = inspect(obj)
mapper = getattr(st, "mapper", None)
return cast(Optional[Mapper], mapper)
except Exception:
return None
def _safe_get_loaded_attr(obj, name):
st = _sa_state(obj)
if st is None:
return None
try:
st_dict = getattr(st, "dict", {})
if name in st_dict:
return st_dict[name]
attrs = getattr(st, "attrs", None)
attr = None
if attrs is not None:
try:
attr = attrs[name]
except Exception:
try:
get = getattr(attrs, "get", None)
if callable(get):
attr = get(name)
except Exception:
attr = None
st = inspect(obj)
attr = st.attrs.get(name)
if attr is not None:
val = attr.loaded_value
return None if val is NO_VALUE else val
if name in st.dict:
return st.dict.get(name)
return None
except Exception:
return None
@ -81,7 +33,10 @@ def _is_collection_rel(prop: RelationshipProperty) -> bool:
def _serialize_simple_obj(obj) -> Dict[str, Any]:
"""Columns only (no relationships)."""
out: Dict[str, Any] = {}
for name in _column_names_for_model(type(obj)):
for cls in obj.__class__.__mro__:
if hasattr(cls, "__table__"):
for col in cls.__table__.columns:
name = col.name
try:
out[name] = getattr(obj, name)
except Exception:
@ -249,16 +204,12 @@ class CRUDMixin:
# Determine which relationships to consider
try:
mapper = _sa_mapper(self)
embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
if mapper is None:
return data
st = _sa_state(self)
if st is None:
return data
st = inspect(self)
mapper = st.mapper
embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names
for name, prop in mapper.relationships.items():
# Only touch relationships that are already loaded; never lazy-load here.
rel_loaded = getattr(st, "attrs", {}).get(name)
rel_loaded = st.attrs.get(name)
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
continue
@ -315,10 +266,13 @@ class CRUDMixin:
val = None
# If it's a scalar ORM object (relationship), serialize its columns
mapper = _sa_mapper(val)
if mapper is not None:
try:
st = inspect(val) # will raise if not an ORM object
if getattr(st, "mapper", None) is not None:
out[name] = _serialize_simple_obj(val)
continue
except Exception:
pass
# If it's a collection and no subfields were requested, emit a light list
if isinstance(val, (list, tuple)):