Compare commits
3 commits
15ae0caf27
...
ce7d092be4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce7d092be4 | ||
|
|
01a0031cf4 | ||
|
|
e829de9792 |
2 changed files with 142 additions and 66 deletions
|
|
@ -2,29 +2,45 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Blueprint, jsonify, request, abort
|
||||
from flask import Blueprint, jsonify, request, abort, current_app, url_for
|
||||
from hashlib import md5
|
||||
from urllib.parse import urlencode
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from crudkit.api._cursor import encode_cursor, decode_cursor
|
||||
from crudkit.core.service import _is_truthy
|
||||
|
||||
MAX_JSON = 1_000_000
|
||||
|
||||
def _etag_for(obj) -> str:
|
||||
v = getattr(obj, "updated_at", None) or obj.id
|
||||
return md5(str(v).encode()).hexdigest()
|
||||
|
||||
def _json_payload() -> dict:
|
||||
if request.content_length and request.content_length > MAX_JSON:
|
||||
abort(413)
|
||||
if not request.is_json:
|
||||
abort(415)
|
||||
payload = request.get_json(silent=False)
|
||||
if not isinstance(payload, dict):
|
||||
abort(400)
|
||||
return payload
|
||||
|
||||
def _args_flat() -> dict[str, str]:
|
||||
return request.args.to_dict(flat=True) # type: ignore[arg-type]
|
||||
|
||||
def _json_error(e: Exception, status: int = 400):
|
||||
if isinstance(e, HTTPException):
|
||||
status = e.code or status
|
||||
msg = e.description
|
||||
else:
|
||||
msg = str(e)
|
||||
if current_app.debug:
|
||||
return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status
|
||||
return jsonify({"status": "error", "error": msg}), status
|
||||
|
||||
def _bool_param(d: dict[str, str], key: str, default: bool) -> bool:
|
||||
return _is_truthy(d.get(key, "1" if default else "0"))
|
||||
|
||||
|
||||
def _safe_int(value: str | None, default: int) -> int:
|
||||
try:
|
||||
return int(value) if value is not None else default
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _link_with_params(base_url: str, **params) -> str:
|
||||
q = {k: v for k, v in params.items() if v is not None}
|
||||
return f"{base_url}?{urlencode(q)}"
|
||||
|
||||
|
||||
def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True):
|
||||
"""
|
||||
REST:
|
||||
|
|
@ -50,65 +66,75 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r
|
|||
|
||||
bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}")
|
||||
|
||||
@bp.errorhandler(Exception)
|
||||
def _handle_any(e: Exception):
|
||||
return _json_error(e)
|
||||
|
||||
@bp.errorhandler(404)
|
||||
def _not_found(_e):
|
||||
return jsonify({"status": "error", "error": "not found"}), 404
|
||||
|
||||
# ---------- REST ----------
|
||||
if rest:
|
||||
@bp.get("/")
|
||||
def rest_list():
|
||||
args = request.args.to_dict(flat=True)
|
||||
args = _args_flat()
|
||||
# support cursor pagination transparently; fall back to limit/offset
|
||||
try:
|
||||
items = service.list(args)
|
||||
return jsonify([o.as_dict() for o in items])
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.get("/<int:obj_id>")
|
||||
def rest_get(obj_id: int):
|
||||
try:
|
||||
item = service.get(obj_id, request.args)
|
||||
if item is None:
|
||||
abort(404)
|
||||
return jsonify(item.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
item = service.get(obj_id, request.args)
|
||||
if item is None:
|
||||
abort(404)
|
||||
etag = _etag_for(item)
|
||||
if request.if_none_match and (etag in request.if_none_match):
|
||||
return "", 304
|
||||
resp = jsonify(item.as_dict())
|
||||
resp.set_etag(etag)
|
||||
return resp
|
||||
|
||||
@bp.post("/")
|
||||
def rest_create():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
payload = _json_payload()
|
||||
try:
|
||||
obj = service.create(payload)
|
||||
resp = jsonify(obj.as_dict())
|
||||
resp.status_code = 201
|
||||
resp.headers["Location"] = f"{request.base_url.rstrip('/')}/{obj.id}"
|
||||
resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False)
|
||||
return resp
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.patch("/<int:obj_id>")
|
||||
def rest_update(obj_id: int):
|
||||
payload = request.get_json(silent=True) or {}
|
||||
payload = _json_payload()
|
||||
try:
|
||||
obj = service.update(obj_id, payload)
|
||||
return jsonify(obj.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.delete("/<int:obj_id>")
|
||||
def rest_delete(obj_id: int):
|
||||
hard = (request.args.get("hard") in ("1", "true", "yes"))
|
||||
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
|
||||
try:
|
||||
obj = service.delete(obj_id, hard=hard)
|
||||
if obj is None:
|
||||
abort(404)
|
||||
return ("", 204)
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
# ---------- RPC (your existing routes) ----------
|
||||
if rpc:
|
||||
# your original functions verbatim, shortened here for sanity
|
||||
@bp.get("/get")
|
||||
def rpc_get():
|
||||
print("⚠️ WARNING: Deprecated RPC call used: /get")
|
||||
id_ = int(request.args.get("id", 0))
|
||||
if not id_:
|
||||
return jsonify({"status": "error", "error": "missing required param: id"}), 400
|
||||
|
|
@ -118,48 +144,52 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r
|
|||
abort(404)
|
||||
return jsonify(item.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.get("/list")
|
||||
def rpc_list():
|
||||
args = request.args.to_dict(flat=True)
|
||||
print("⚠️ WARNING: Deprecated RPC call used: /list")
|
||||
args = _args_flat()
|
||||
try:
|
||||
items = service.list(args)
|
||||
return jsonify([obj.as_dict() for obj in items])
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.post("/create")
|
||||
def rpc_create():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
print("⚠️ WARNING: Deprecated RPC call used: /create")
|
||||
payload = _json_payload()
|
||||
try:
|
||||
obj = service.create(payload)
|
||||
return jsonify(obj.as_dict()), 201
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.patch("/update")
|
||||
def rpc_update():
|
||||
print("⚠️ WARNING: Deprecated RPC call used: /update")
|
||||
id_ = int(request.args.get("id", 0))
|
||||
if not id_:
|
||||
return jsonify({"status": "error", "error": "missing required param: id"}), 400
|
||||
payload = request.get_json(silent=True) or {}
|
||||
payload = _json_payload()
|
||||
try:
|
||||
obj = service.update(id_, payload)
|
||||
return jsonify(obj.as_dict())
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
@bp.delete("/delete")
|
||||
def rpc_delete():
|
||||
print("⚠️ WARNING: Deprecated RPC call used: /delete")
|
||||
id_ = int(request.args.get("id", 0))
|
||||
if not id_:
|
||||
return jsonify({"status": "error", "error": "missing required param: id"}), 400
|
||||
hard = (request.args.get("hard") in ("1", "true", "yes"))
|
||||
hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type]
|
||||
try:
|
||||
obj = service.delete(id_, hard=hard)
|
||||
return ("", 204) if obj is not None else abort(404)
|
||||
except Exception as e:
|
||||
return jsonify({"status": "error", "error": str(e)}), 400
|
||||
return _json_error(e)
|
||||
|
||||
return bp
|
||||
|
|
|
|||
|
|
@ -1,18 +1,66 @@
|
|||
from typing import Any, Dict, Iterable, List, Tuple, Set
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast
|
||||
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect
|
||||
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty
|
||||
from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper
|
||||
from sqlalchemy.orm.state import InstanceState
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
def _safe_get_loaded_attr(obj, name):
|
||||
@lru_cache(maxsize=512)
|
||||
def _column_names_for_model(cls: type) -> tuple[str, ...]:
|
||||
try:
|
||||
mapper = inspect(cls)
|
||||
return tuple(prop.key for prop in mapper.column_attrs)
|
||||
except Exception:
|
||||
names: list[str] = []
|
||||
for c in cls.__mro__:
|
||||
if hasattr(c, "__table__"):
|
||||
names.extend(col.name for col in c.__table__.columns)
|
||||
return tuple(dict.fromkeys(names))
|
||||
|
||||
def _sa_state(obj: Any) -> Optional[InstanceState[Any]]:
|
||||
"""Safely get SQLAlchemy InstanceState (or None)."""
|
||||
try:
|
||||
st = inspect(obj)
|
||||
attr = st.attrs.get(name)
|
||||
return cast(Optional[InstanceState[Any]], st)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _sa_mapper(obj: Any) -> Optional[Mapper]:
|
||||
"""Safely get Mapper for a maooed instance (or None)."""
|
||||
try:
|
||||
st = inspect(obj)
|
||||
mapper = getattr(st, "mapper", None)
|
||||
return cast(Optional[Mapper], mapper)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _safe_get_loaded_attr(obj, name):
|
||||
st = _sa_state(obj)
|
||||
if st is None:
|
||||
return None
|
||||
try:
|
||||
st_dict = getattr(st, "dict", {})
|
||||
if name in st_dict:
|
||||
return st_dict[name]
|
||||
|
||||
attrs = getattr(st, "attrs", None)
|
||||
attr = None
|
||||
if attrs is not None:
|
||||
try:
|
||||
attr = attrs[name]
|
||||
except Exception:
|
||||
try:
|
||||
get = getattr(attrs, "get", None)
|
||||
if callable(get):
|
||||
attr = get(name)
|
||||
except Exception:
|
||||
attr = None
|
||||
|
||||
if attr is not None:
|
||||
val = attr.loaded_value
|
||||
return None if val is NO_VALUE else val
|
||||
if name in st.dict:
|
||||
return st.dict.get(name)
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -33,14 +81,11 @@ def _is_collection_rel(prop: RelationshipProperty) -> bool:
|
|||
def _serialize_simple_obj(obj) -> Dict[str, Any]:
|
||||
"""Columns only (no relationships)."""
|
||||
out: Dict[str, Any] = {}
|
||||
for cls in obj.__class__.__mro__:
|
||||
if hasattr(cls, "__table__"):
|
||||
for col in cls.__table__.columns:
|
||||
name = col.name
|
||||
try:
|
||||
out[name] = getattr(obj, name)
|
||||
except Exception:
|
||||
out[name] = None
|
||||
for name in _column_names_for_model(type(obj)):
|
||||
try:
|
||||
out[name] = getattr(obj, name)
|
||||
except Exception:
|
||||
out[name] = None
|
||||
return out
|
||||
|
||||
def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any:
|
||||
|
|
@ -204,12 +249,16 @@ class CRUDMixin:
|
|||
|
||||
# Determine which relationships to consider
|
||||
try:
|
||||
st = inspect(self)
|
||||
mapper = st.mapper
|
||||
embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names
|
||||
mapper = _sa_mapper(self)
|
||||
embed_set = set(str(x).split(".", 1)[0] for x in (embed or []))
|
||||
if mapper is None:
|
||||
return data
|
||||
st = _sa_state(self)
|
||||
if st is None:
|
||||
return data
|
||||
for name, prop in mapper.relationships.items():
|
||||
# Only touch relationships that are already loaded; never lazy-load here.
|
||||
rel_loaded = st.attrs.get(name)
|
||||
rel_loaded = getattr(st, "attrs", {}).get(name)
|
||||
if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE:
|
||||
continue
|
||||
|
||||
|
|
@ -266,13 +315,10 @@ class CRUDMixin:
|
|||
val = None
|
||||
|
||||
# If it's a scalar ORM object (relationship), serialize its columns
|
||||
try:
|
||||
st = inspect(val) # will raise if not an ORM object
|
||||
if getattr(st, "mapper", None) is not None:
|
||||
out[name] = _serialize_simple_obj(val)
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
mapper = _sa_mapper(val)
|
||||
if mapper is not None:
|
||||
out[name] = _serialize_simple_obj(val)
|
||||
continue
|
||||
|
||||
# If it's a collection and no subfields were requested, emit a light list
|
||||
if isinstance(val, (list, tuple)):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue