From e829de97924f6ed10ae1e50e6812f7012c642363 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Mon, 20 Oct 2025 11:03:03 -0500 Subject: [PATCH 1/3] Optimization and refactoring pass. --- crudkit/api/flask_api.py | 99 ++++++++++++++++++++++++++++------------ crudkit/core/base.py | 85 +++++++++++++++++++++++----------- 2 files changed, 130 insertions(+), 54 deletions(-) diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index 4e310d3..62e78a4 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -2,12 +2,41 @@ from __future__ import annotations -from flask import Blueprint, jsonify, request, abort +from flask import Blueprint, jsonify, request, abort, current_app, url_for +from hashlib import md5 from urllib.parse import urlencode +from werkzeug.exceptions import HTTPException -from crudkit.api._cursor import encode_cursor, decode_cursor from crudkit.core.service import _is_truthy +MAX_JSON = 1_000_000 + +def _etag_for(obj) -> str: + v = getattr(obj, "updated_at", None) or obj.id + return md5(str(v).encode()).hexdigest() + +def _json_payload() -> dict: + if request.content_length and request.content_length > MAX_JSON: + abort(413) + if not request.is_json: + abort(415) + payload = request.get_json(silent=False) + if not isinstance(payload, dict): + abort(400) + return payload + +def _args_flat() -> dict[str, str]: + return request.args.to_dict(flat=True) # type: ignore[arg-type] + +def _json_error(e: Exception, status: int = 400): + if isinstance(e, HTTPException): + status = e.code or status + msg = e.description + else: + msg = str(e) + if current_app.debug: + return jsonify({"status": "error", "error": msg, "type": e.__class__.__name__}), status + return jsonify({"status": "error", "error": msg}), status def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: return _is_truthy(d.get(key, "1" if default else "0")) @@ -50,65 +79,75 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r bp = Blueprint(plural, __name__, url_prefix=f"/api/{plural}") + @bp.errorhandler(Exception) + def _handle_any(e: Exception): + return _json_error(e) + + @bp.errorhandler(404) + def _not_found(_e): + return jsonify({"status": "error", "error": "not found"}), 404 + # ---------- REST ---------- if rest: @bp.get("/") def rest_list(): - args = request.args.to_dict(flat=True) + args = _args_flat() # support cursor pagination transparently; fall back to limit/offset try: items = service.list(args) return jsonify([o.as_dict() for o in items]) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.get("/") def rest_get(obj_id: int): - try: - item = service.get(obj_id, request.args) - if item is None: - abort(404) - return jsonify(item.as_dict()) - except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + item = service.get(obj_id, request.args) + if item is None: + abort(404) + etag = _etag_for(item) + if request.if_none_match and (etag in request.if_none_match): + return "", 304 + resp = jsonify(item.as_dict()) + resp.set_etag(etag) + return resp @bp.post("/") def rest_create(): - payload = request.get_json(silent=True) or {} + payload = _json_payload() try: obj = service.create(payload) resp = jsonify(obj.as_dict()) resp.status_code = 201 - resp.headers["Location"] = f"{request.base_url.rstrip('/')}/{obj.id}" + resp.headers["Location"] = url_for(f"{plural}.rest_get", obj_id=obj.id, _external=False) return resp except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.patch("/") def rest_update(obj_id: int): - payload = request.get_json(silent=True) or {} + payload = _json_payload() try: obj = service.update(obj_id, payload) return jsonify(obj.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.delete("/") def rest_delete(obj_id: int): - hard = (request.args.get("hard") in ("1", "true", "yes")) + hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] try: obj = service.delete(obj_id, hard=hard) if obj is None: abort(404) return ("", 204) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) # ---------- RPC (your existing routes) ---------- if rpc: - # your original functions verbatim, shortened here for sanity @bp.get("/get") def rpc_get(): + print("⚠️ WARNING: Deprecated RPC call used: /get") id_ = int(request.args.get("id", 0)) if not id_: return jsonify({"status": "error", "error": "missing required param: id"}), 400 @@ -118,48 +157,52 @@ def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, r abort(404) return jsonify(item.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.get("/list") def rpc_list(): - args = request.args.to_dict(flat=True) + print("⚠️ WARNING: Deprecated RPC call used: /list") + args = _args_flat() try: items = service.list(args) return jsonify([obj.as_dict() for obj in items]) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.post("/create") def rpc_create(): - payload = request.get_json(silent=True) or {} + print("⚠️ WARNING: Deprecated RPC call used: /create") + payload = _json_payload() try: obj = service.create(payload) return jsonify(obj.as_dict()), 201 except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.patch("/update") def rpc_update(): + print("⚠️ WARNING: Deprecated RPC call used: /update") id_ = int(request.args.get("id", 0)) if not id_: return jsonify({"status": "error", "error": "missing required param: id"}), 400 - payload = request.get_json(silent=True) or {} + payload = _json_payload() try: obj = service.update(id_, payload) return jsonify(obj.as_dict()) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) @bp.delete("/delete") def rpc_delete(): + print("⚠️ WARNING: Deprecated RPC call used: /delete") id_ = int(request.args.get("id", 0)) if not id_: return jsonify({"status": "error", "error": "missing required param: id"}), 400 - hard = (request.args.get("hard") in ("1", "true", "yes")) + hard = _bool_param(_args_flat(), "hard", False) # type: ignore[arg-type] try: obj = service.delete(id_, hard=hard) return ("", 204) if obj is not None else abort(404) except Exception as e: - return jsonify({"status": "error", "error": str(e)}), 400 + return _json_error(e) return bp diff --git a/crudkit/core/base.py b/crudkit/core/base.py index c42b90e..e612f13 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -1,18 +1,53 @@ -from typing import Any, Dict, Iterable, List, Tuple, Set +from functools import lru_cache +from typing import Any, Dict, Iterable, List, Optional, Tuple, Set, cast from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func, inspect -from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty +from sqlalchemy.orm import declarative_mixin, declarative_base, NO_VALUE, RelationshipProperty, Mapper +from sqlalchemy.orm.state import InstanceState Base = declarative_base() -def _safe_get_loaded_attr(obj, name): +@lru_cache(maxsize=512) +def _column_names_for_model(cls: type) -> tuple[str, ...]: + try: + mapper = inspect(cls) + return tuple(prop.key for prop in mapper.column_attrs) + except Exception: + names: list[str] = [] + for c in cls.__mro__: + if hasattr(c, "__table__"): + names.extend(col.name for col in c.__table__.columns) + return tuple(dict.fromkeys(names)) + +def _sa_state(obj: Any) -> Optional[InstanceState[Any]]: + """Safely get SQLAlchemy InstanceState (or None).""" try: st = inspect(obj) - attr = st.attrs.get(name) - if attr is not None: + return cast(Optional[InstanceState[Any]], st) + except Exception: + return None + +def _sa_mapper(obj: Any) -> Optional[Mapper]: + """Safely get Mapper for a maooed instance (or None).""" + try: + st = inspect(obj) + mapper = getattr(st, "mapper", None) + return cast(Optional[Mapper], mapper) + except Exception: + return None + +def _safe_get_loaded_attr(obj, name): + st = _sa_state(obj) + if st is None: + return None + try: + attrs = getattr(st, "attrs", {}).get(name) + if attrs is not None and name in attrs: + attr = attrs[name] val = attr.loaded_value return None if val is NO_VALUE else val - if name in st.dict: - return st.dict.get(name) + st_dict = getattr(st, "dict", {}) + if name in st_dict: + return st_dict.get(name) return None except Exception: return None @@ -33,14 +68,11 @@ def _is_collection_rel(prop: RelationshipProperty) -> bool: def _serialize_simple_obj(obj) -> Dict[str, Any]: """Columns only (no relationships).""" out: Dict[str, Any] = {} - for cls in obj.__class__.__mro__: - if hasattr(cls, "__table__"): - for col in cls.__table__.columns: - name = col.name - try: - out[name] = getattr(obj, name) - except Exception: - out[name] = None + for name in _column_names_for_model(type(obj)): + try: + out[name] = getattr(obj, name) + except Exception: + out[name] = None return out def _serialize_loaded_rel(obj, name, *, depth: int, seen: Set[Tuple[type, Any]], embed: Set[str]) -> Any: @@ -204,12 +236,16 @@ class CRUDMixin: # Determine which relationships to consider try: - st = inspect(self) - mapper = st.mapper - embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) # top-level names + mapper = _sa_mapper(self) + embed_set = set(str(x).split(".", 1)[0] for x in (embed or [])) + if mapper is None: + return data + st = _sa_state(self) + if st is None: + return data for name, prop in mapper.relationships.items(): # Only touch relationships that are already loaded; never lazy-load here. - rel_loaded = st.attrs.get(name) + rel_loaded = getattr(st, "attrs", {}).get(name) if rel_loaded is None or rel_loaded.loaded_value is NO_VALUE: continue @@ -266,13 +302,10 @@ class CRUDMixin: val = None # If it's a scalar ORM object (relationship), serialize its columns - try: - st = inspect(val) # will raise if not an ORM object - if getattr(st, "mapper", None) is not None: - out[name] = _serialize_simple_obj(val) - continue - except Exception: - pass + mapper = _sa_mapper(val) + if mapper is not None: + out[name] = _serialize_simple_obj(val) + continue # If it's a collection and no subfields were requested, emit a light list if isinstance(val, (list, tuple)): From 01a0031cf45a28487299183ac5e7f3a4a34bb966 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Mon, 20 Oct 2025 13:53:27 -0500 Subject: [PATCH 2/3] Fix a regression added by some refactor. --- crudkit/core/base.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/crudkit/core/base.py b/crudkit/core/base.py index e612f13..d73a04f 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -40,14 +40,27 @@ def _safe_get_loaded_attr(obj, name): if st is None: return None try: - attrs = getattr(st, "attrs", {}).get(name) - if attrs is not None and name in attrs: - attr = attrs[name] - val = attr.loaded_value - return None if val is NO_VALUE else val st_dict = getattr(st, "dict", {}) if name in st_dict: - return st_dict.get(name) + return st_dict[name] + + attrs = getattr(st, "attrs", None) + attr = None + if attrs is not None: + try: + attr = attrs[name] + except Exception: + try: + get = getattr(attrs, "get", None) + if callable(get): + attr = get(name) + except Exception: + attr = None + + if attr is not None: + val = attr.loaded_value + return None if val is NO_VALUE else val + return None except Exception: return None From ce7d092be44ff1e9f391a235d328d461cf4555bc Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Mon, 20 Oct 2025 13:54:30 -0500 Subject: [PATCH 3/3] Removed unused functions in the Flask API file. --- crudkit/api/flask_api.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index 62e78a4..39e7c49 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -41,19 +41,6 @@ def _json_error(e: Exception, status: int = 400): def _bool_param(d: dict[str, str], key: str, default: bool) -> bool: return _is_truthy(d.get(key, "1" if default else "0")) - -def _safe_int(value: str | None, default: int) -> int: - try: - return int(value) if value is not None else default - except Exception: - return default - - -def _link_with_params(base_url: str, **params) -> str: - q = {k: v for k, v in params.items() if v is not None} - return f"{base_url}?{urlencode(q)}" - - def generate_crud_blueprint(model, service, *, base_prefix: str | None = None, rest: bool = True, rpc: bool = True): """ REST: