diff --git a/crudkit/__init__.py b/crudkit/__init__.py new file mode 100644 index 0000000..cf333de --- /dev/null +++ b/crudkit/__init__.py @@ -0,0 +1,8 @@ +from .mixins import CrudMixin +from .dsl import QuerySpec +from .eager import default_eager_policy +from .service import CrudService +from .serialize import serialize +from .blueprint import make_blueprint + +__all__ = ["CrudMixin", "QuerySpec", "default_eager_policy", "CrudService", "serialize", "make_blueprint"] \ No newline at end of file diff --git a/crudkit/blueprint.py b/crudkit/blueprint.py new file mode 100644 index 0000000..cae482e --- /dev/null +++ b/crudkit/blueprint.py @@ -0,0 +1,81 @@ +from flask import Blueprint, request, jsonify, abort +from sqlalchemy.orm import scoped_session +from .dsl import QuerySpec +from .service import CrudService +from .eager import default_eager_policy +from .serialize import serialize + +def make_blueprint(db_session_factory, registry): + bp = Blueprint("crud", __name__) + def session(): return scoped_session(db_session_factory)() + + @bp.get("//list") + def list_items(model): + Model = registry.get(model) or abort(404) + spec = QuerySpec( + filters=_parse_filters(request.args), + order_by=request.args.getlist("sort"), + page=request.args.get("page", type=int), + per_page=request.args.get("per_page", type=int), + expand=request.args.getlist("expand"), + fields=request.args.get("fields", type=lambda s: [x.strip() for x in s.split(",")] if s else None), + ) + s = session(); svc = CrudService(s, default_eager_policy) + rows, total = svc.list(Model, spec) + data = [serialize(r, fields=spec.fields, expand=spec.expand) for r in rows] + return jsonify({"data": data, "total": total}) + + @bp.post("/") + def create_item(model): + Model = registry.get(model) or abort(404) + payload = request.get_json() or {} + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.create(Model, payload) + s.commit() + return jsonify(serialize(obj)), 201 + + @bp.get("//") + def read_item(model, id): + Model = registry.get(model) or abort(404) + spec = QuerySpec(expand=request.args.getlist("expand"), + fields=request.args.get("fields", type=lambda s: s.split(","))) + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.get(Model, id, spec) or abort(404) + return jsonify(serialize(obj, fields=spec.fields, expand=spec.expand)) + + @bp.patch("//") + def update_item(model, id): + Model = registry.get(model) or abort(404) + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.get(Model, id, QuerySpec()) or abort(404) + payload = request.get_json() or {} + svc.update(obj, payload) + s.commit() + return jsonify(serialize(obj)) + + @bp.delete("//") + def delete_item(model, id): + Model = registry.get(model) or abort(404) + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.get(Model, id, QuerySpec()) or abort(404) + svc.soft_delete(obj) + s.commit() + return jsonify({"status": "deleted"}) + + @bp.post("///undelete") + def undelete_item(model, id): + Model = registry.get(model) or abort(404) + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.get(Model, id, QuerySpec()) or abort(404) + svc.undelete(obj) + s.commit() + return jsonify({"status": "restored"}) + return bp + +def _parse_filters(args): + out = {} + for k, v in args.items(): + if k in {"page", "per_page", "sort", "expand", "fields"}: + continue + out[k] = v + return out \ No newline at end of file diff --git a/crudkit/dsl.py b/crudkit/dsl.py new file mode 100644 index 0000000..a9ee931 --- /dev/null +++ b/crudkit/dsl.py @@ -0,0 +1,117 @@ +from dataclasses import dataclass, field +from typing import List, Dict, Any, Optional +from sqlalchemy import asc, desc, select, false +from sqlalchemy.inspection import inspect + +@dataclass +class QuerySpec: + filters: Dict[str, Any] = field(default_factory=dict) + order_by: List[str] = field(default_factory=list) + page: Optional[int] = None + per_page: Optional[int] = None + expand: List[str] = field(default_factory=list) + fields: Optional[List[str]] = None + +FILTER_OPS = { + "__eq": lambda c, v: c == v, + "__ne": lambda c, v: c != v, + "__lt": lambda c, v: c < v, + "__lte": lambda c, v: c <= v, + "__gt": lambda c, v: c > v, + "__gte": lambda c, v: c >= v, + "__ilike": lambda c, v: c.ilike(v), + "__in": lambda c, v: c.in_(v), + "__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None)) +} + +def _split_filter_key(raw_key: str): + for op in sorted(FILTER_OPS.keys(), key=len, reverse=True): + if raw_key.endswith(op): + return raw_key[: -len(op)], op + return raw_key, None + +def _ensure_wildcards(op_key, value): + if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value: + return f"%{value}%" + return value + +def _related_predicate(Model, path_parts, op_key, value): + """ + Build EXISTS subqueries for dotted filters: + - scalar rels -> attr.has(inner_predicate) + - collection -> attr.any(inner_predicate) + """ + head, *rest = path_parts + + # class-bound relationship attribute (InstrumentedAttribute) + attr = getattr(Model, head, None) + if attr is None: + return None + + # relationship metadata if you need uselist + target model + rel = inspect(Model).relationships.get(head) + if rel is None: + return None + Target = rel.mapper.class_ + + if not rest: + # filtering directly on a relationship without a leaf column isn't supported + return None + + if len(rest) == 1: + # final hop is a column on the related model + leaf = rest[0] + col = getattr(Target, leaf, None) + if col is None: + return None + pred = FILTER_OPS[op_key](col, value) if op_key else (col == value) + else: + # recurse deeper: owner.room.area.name__ilike=... + pred = _related_predicate(Target, rest, op_key, value) + if pred is None: + return None + + # wrap at this hop using the *attribute*, not the RelationshipProperty + return attr.any(pred) if rel.uselist else attr.has(pred) + +def build_query(Model, spec: QuerySpec, eager_policy=None): + stmt = select(Model) + + # filter out soft-deleted rows + deleted_attr = getattr(Model, "deleted", None) + if deleted_attr is not None: + stmt = stmt.where(deleted_attr == false()) + else: + is_deleted_attr = getattr(Model, "is_deleted", None) + if is_deleted_attr is not None: + stmt = stmt.where(is_deleted_attr == false()) + + # filters + for raw_key, val in spec.filters.items(): + path, op_key = _split_filter_key(raw_key) + val = _ensure_wildcards(op_key, val) + + if "." in path: + pred = _related_predicate(Model, path.split("."), op_key, val) + if pred is not None: + stmt = stmt.where(pred) + continue + + col = getattr(Model, path, None) + if col is None: + continue + stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val)) + + # order_by + for key in spec.order_by: + desc_ = key.startswith("-") + col = getattr(Model, key[1:] if desc_ else key) + stmt = stmt.order_by(desc(col) if desc_ else asc(col)) + + # eager loading + if eager_policy: + opts = eager_policy(Model, spec.expand) + if opts: + stmt = stmt.options(*opts) + + return stmt diff --git a/crudkit/eager.py b/crudkit/eager.py new file mode 100644 index 0000000..f32efc2 --- /dev/null +++ b/crudkit/eager.py @@ -0,0 +1,42 @@ +from typing import List +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import Load, joinedload, selectinload + +def default_eager_policy(Model, expand: List[str]) -> List[Load]: + """ + Heuristic: + - many-to-one / one-to-one: joinedload + - one-to-many / many-to-many: selectinload + Accepts dotted paths like "author.publisher". + """ + if not expand: + return [] + + opts: List[Load] = [] + + for path in expand: + parts = path.split(".") + current_model = Model + current_inspect = inspect(current_model) + + # first hop + rel = current_inspect.relationships.get(parts[0]) + if not rel: + continue # silently skip bad names + attr = getattr(current_model, parts[0]) + loader: Load = selectinload(attr) if rel.uselist else joinedload(attr) + current_model = rel.mapper.class_ + + # nested hops, if any + for name in parts[1:]: + current_inspect = inspect(current_model) + rel = current_inspect.relationships.get(name) + if not rel: + break + attr = getattr(current_model, name) + loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr) + current_model = rel.mapper.class_ + + opts.append(loader) + + return opts diff --git a/crudkit/html/__init__.py b/crudkit/html/__init__.py new file mode 100644 index 0000000..a94f018 --- /dev/null +++ b/crudkit/html/__init__.py @@ -0,0 +1,3 @@ +from .ui_fragments import make_fragments_blueprint + +__all__ = ["make_fragments_blueprint"] diff --git a/crudkit/html/templates/crudkit/_macros.html b/crudkit/html/templates/crudkit/_macros.html new file mode 100644 index 0000000..f713e81 --- /dev/null +++ b/crudkit/html/templates/crudkit/_macros.html @@ -0,0 +1,93 @@ +{% macro options(items, value_attr="id", label_path="name", getp=None) -%} +{%- for obj in items -%} + +{%- endfor -%} +{% endmacro %} + +{% macro lis(items, label_path="name", sublabel_path=None, getp=None) -%} +{%- for obj in items -%} +
  • +
    {{ getp(obj, label_path) }}
    + {%- if sublabel_path %} +
    {{ getp(obj, sublabel_path) }}
    + {%- endif %} +
  • +{%- else -%} +
  • No results.
  • +{%- endfor -%} +{% endmacro %} + +{% macro rows(items, fields, getp=None) -%} +{%- for obj in items -%} + + {%- for f in fields -%} + {{ getp(obj, f) }} + {%- endfor -%} + +{%- else -%} + + No results. + +{%- endfor -%} +{%- endmacro %} + +{% macro pager(model, page, pages, per_page, sort, filters) -%} + +{%- endmacro %} + +{% macro form(schema, action, method="POST", obj_id=None, hx=False, csrf_token=None) -%} +
    + {%- if csrf_token %}{% endif -%} + {%- if obj_id %}{% endif -%} + + + {%- for f in schema -%} +
    + {% set fid = 'f-' ~ f.name ~ '-' ~ (obj_id or 'new') %} + + {%- if f.type == "textarea" -%} + + {%- elif f.type == "select" -%} + + {%- elif f.type == "checkbox" -%} + + + {%- else -%} + + {%- endif -%} + {%- if f.help %}
    {{ f.help }}
    {% endif -%} +
    + {%- endfor -%} + +
    + +
    +
    +{%- endmacro %} \ No newline at end of file diff --git a/crudkit/html/templates/crudkit/form.html b/crudkit/html/templates/crudkit/form.html new file mode 100644 index 0000000..5f3bfb3 --- /dev/null +++ b/crudkit/html/templates/crudkit/form.html @@ -0,0 +1,3 @@ +{% import "_macros.html" as ui %} +{% set action = url_for('frags.save', model=model) %} +{{ ui.form(schema, action, method="POST", obj_id=obj.id if obj else None, hx=true) }} \ No newline at end of file diff --git a/crudkit/html/templates/crudkit/lis.html b/crudkit/html/templates/crudkit/lis.html new file mode 100644 index 0000000..e9b1813 --- /dev/null +++ b/crudkit/html/templates/crudkit/lis.html @@ -0,0 +1,2 @@ +{% import "_macros.html" as ui %} +{{ ui.lis(items, label_path=label_path, sublabel_path=sublabel_path, getp=getp) }} diff --git a/crudkit/html/templates/crudkit/options.html b/crudkit/html/templates/crudkit/options.html new file mode 100644 index 0000000..34d6a2b --- /dev/null +++ b/crudkit/html/templates/crudkit/options.html @@ -0,0 +1,3 @@ +{# Renders only rows #} +{% import "_macros.html" as ui %} +{{ ui.options(items, value_attr=value_attr, label_path=label_path, getp=getp) }} diff --git a/crudkit/html/templates/crudkit/row.html b/crudkit/html/templates/crudkit/row.html new file mode 100644 index 0000000..a3ac629 --- /dev/null +++ b/crudkit/html/templates/crudkit/row.html @@ -0,0 +1,2 @@ +{% import "_macros.html" as ui %} +{{ ui.rows([obj], fields, getp=getp) }} \ No newline at end of file diff --git a/crudkit/html/templates/crudkit/rows.html b/crudkit/html/templates/crudkit/rows.html new file mode 100644 index 0000000..7fafa3a --- /dev/null +++ b/crudkit/html/templates/crudkit/rows.html @@ -0,0 +1,3 @@ +{% import "_macros.html" as ui %} +{{ ui.rows(items, fields, getp=getp) }} +{{ ui.pager(model, page, pages, per_page, sort, filters) }} diff --git a/crudkit/html/type_map.py b/crudkit/html/type_map.py new file mode 100644 index 0000000..5e582c2 --- /dev/null +++ b/crudkit/html/type_map.py @@ -0,0 +1,137 @@ +from __future__ import annotations +from typing import Any, Dict, List, Optional, Tuple +from sqlalchemy import select +from sqlalchemy.inspection import inspect +from sqlalchemy.orm import Mapper, RelationshipProperty +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.sqltypes import ( + String, Text, Unicode, UnicodeText, + Integer, BigInteger, SmallInteger, Float, Numeric, Boolean, + Date, DateTime, Time, JSON, Enum +) + +CANDIDATE_LABELS = ("name", "title", "label", "display_name") + +def _guess_label_attr(model_cls) -> str: + for cand in CANDIDATE_LABELS: + if hasattr(model_cls, cand): + return cand + return "id" + +def _pretty(label: str) -> str: + return label.replace("_", " ").title() + +def _column_input_type(col: Column) -> str: + t = col.type + if isinstance(t, (String, Unicode)): + return "text" + if isinstance(t, (Text, UnicodeText, JSON)): + return "textarea" + if isinstance(t, (Integer, SmallInteger, BigInteger)): + return "number" + if isinstance(t, (Float, Numeric)): + return "number" + if isinstance(t, Boolean): + return "checkbox" + if isinstance(t, Date): + return "date" + if isinstance(t, DateTime): + return "datetime-local" + if isinstance(t, Time): + return "time" + if isinstance(t, Enum): + return "select" + return "text" + +def _enum_choices(col: Column) -> Optional[List[Tuple[str, str]]]: + t = col.type + if isinstance(t, Enum): + if t.enum_class: + return [(e.name, e.value) for e in t.enum_class] + if t.enums: + return [(v, v) for v in t.enums] + return None + +def build_form_schema(model_cls, session, obj=None, *, include=None, exclude=None, fk_limit=200): + mapper: Mapper = inspect(model_cls) + include = set(include or []) + exclude = set(exclude or {"id", "created_at", "updated_at", "deleted", "version"}) + fields = [] + + fields: List[Dict[str, Any]] = [] + + fk_map = {} + for rel in mapper.relationships: + for lc in rel.local_columns: + fk_map[lc.key] = rel + + for attr in mapper.column_attrs: + col = attr.columns[0] + name = col.key + if include and name not in include: + continue + if name in exclude: + continue + + field = { + "name": name, + "type": _column_input_type(col), + "required": not col.nullable, + "value": getattr(obj, name, None) if obj is not None else None, + "placeholder": "", + "help": "", + # default label from column name + "label": _pretty(name), + } + + enum_choices = _enum_choices(col) + if enum_choices: + field["type"] = "select" + field["choices"] = enum_choices + + if name in fk_map: + rel = fk_map[name] + target = rel.mapper.class_ + label_attr = _guess_label_attr(target) + rows = session.execute(select(target).limit(fk_limit)).scalars().all() + field["type"] = "select" + field["choices"] = [(getattr(r, "id"), getattr(r, label_attr)) for r in rows] + field["rel"] = {"target": target.__name__, "label_attr": label_attr} + field["label"] = _pretty(rel.key) + + if getattr(col.type, "length", None): + field["maxlength"] = col.type.length + + fields.append(field) + + for rel in mapper.relationships: + if not rel.uselist or rel.secondary is None: + continue # only true many-to-many + + if include and f"{rel.key}_ids" not in include: + continue + + target = rel.mapper.class_ + label_attr = _guess_label_attr(target) + choices = session.execute(select(target).limit(fk_limit)).scalars().all() + + current = [] + if obj is not None: + current = [getattr(x, "id") for x in getattr(obj, rel.key, []) or []] + + fields.append({ + "name": f"{rel.key}_ids", # e.g. "tags_ids" + "label": rel.key.replace("_"," ").title(), + "type": "select", + "multiple": True, + "required": False, + "choices": [(getattr(r,"id"), getattr(r,label_attr)) for r in choices], + "value": current, # list of selected IDs + "placeholder": f"Choose {rel.key.replace('_',' ').title()}", + "help": "", + }) + + if include: + order = list(include) + fields.sort(key=lambda f: order.index(f["name"]) if f["name"] in include else 10**9) + return fields diff --git a/crudkit/html/ui_fragments.py b/crudkit/html/ui_fragments.py new file mode 100644 index 0000000..525c6e5 --- /dev/null +++ b/crudkit/html/ui_fragments.py @@ -0,0 +1,233 @@ +from __future__ import annotations +from typing import Any, Dict, List, Tuple +from math import ceil +from flask import Blueprint, request, render_template, abort, make_response +from sqlalchemy import select +from sqlalchemy.orm import scoped_session +from sqlalchemy.inspection import inspect +from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric + +from ..dsl import QuerySpec +from ..service import CrudService +from ..eager import default_eager_policy +from .type_map import build_form_schema + +def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, name="frags"): + """ + HTML fragments for HTMX/Alpine. No base pages. Pure partials: + GET //frag/options -> + GET //frag/lis ->
  • ...
  • + GET //frag/rows -> ... + pager markup if wanted + GET //frag/form ->
    ...
    (auto-generated) + """ + bp = Blueprint(name, __name__, template_folder="templates/crudkit") + def session(): return scoped_session(db_session_factory)() + + def _parse_filters(args): + reserved = {"page", "per_page", "sort", "expand", "fields", "value", "label", "label_tpl", "fields_csv", "li_label", "li_sublabel"} + out = {} + for k, v in args.items(): + if k not in reserved and v != "": + out[k] = v + return out + + def _paths_from_csv(csv: str) -> List[str]: + return [p.strip() for p in csv.split(",") if p.strip()] + + def _collect_expand_from_paths(paths: List[str]) -> List[str]: + rels = set() + for p in paths: + bits = p.split(".") + if len(bits) > 1: + rels.add(bits[0]) + return list(rels) + + def _getp(obj, path: str): + cur = obj + for part in path.split("."): + cur = getattr(cur, part, None) if cur is not None else None + return cur + + def _extract_m2m_lists(Model, req_form) -> dict[str, list[int]]: + """Return {'tags': [1,2]} for any _ids fields; caller removes keys from main form.""" + mapper = inspect(Model) + out = {} + for rel in mapper.relationships: + if not rel.uselist or rel.secondary is None: + continue + key = f"{rel.key}_ids" + ids = req_form.getlist(key) + if ids is None: + continue + out[rel.key] = [int(i) for i in ids if i] + return out + + @bp.get("//frag/options") + def options(model): + Model = registry.get(model) or abort(404) + value_attr = request.args.get("value", default="id") + label_path = request.args.get("label", default="name") + filters = _parse_filters(request.args) + + expand = _collect_expand_from_paths([label_path]) + spec = QuerySpec(filters=filters, order_by=[], page=None, per_page=None, expand=expand) + s = session(); svc = CrudService(s, default_eager_policy) + items, _ = svc.list(Model, spec) + + return render_template("options.html", items=items, value_attr=value_attr, label_path=label_path, getp=_getp) + + @bp.get("//frag/lis") + def lis(model): + Model = registry.get(model) or abort(404) + label_path = request.args.get("li_label", default="name") + sublabel_path = request.args.get("li_sublabel") + filters = _parse_filters(request.args) + sort = request.args.get("sort") + page = request.args.get("page", type=int) + per_page = request.args.get("per_page", type=int) + + expand = _collect_expand_from_paths([p for p in (label_path, sublabel_path) if p]) + spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand) + s = session(); svc = CrudService(s, default_eager_policy) + rows, total = svc.list(Model, spec) + pages = (ceil(total / per_page) if page and per_page else 1) + return render_template("lis.html", items=rows, label_path=label_path, sublabel_path=sublabel_path, page=page or 1, per_page=per_page or 1, total=total, model=model, sort=sort, filters=filters, getp=_getp) + + @bp.get("//frag/rows") + def rows(model): + Model = registry.get(model) or abort(404) + fields_csv = request.args.get("fields_csv") or "id,name" + fields = _paths_from_csv(fields_csv) + filters = _parse_filters(request.args) + sort = request.args.get("sort") + page = request.args.get("page", type=int) or 1 + per_page = request.args.get("per_page", type=int) or 20 + + expand = _collect_expand_from_paths(fields) + spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand) + s = session(); svc = CrudService(s, default_eager_policy) + rows, total = svc.list(Model, spec) + pages = max(1, ceil(total / per_page)) + return render_template("rows.html", items=rows, fields=fields, page=page, pages=pages, per_page=per_page, total=total, model=model, sort=sort, filters=filters, getp=_getp) + + @bp.get("//frag/form") + def form(model): + Model = registry.get(model) or abort(404) + id = request.args.get("id", type=int) + include_csv = request.args.get("include") + include = [s.strip() for s in include_csv.split(",")] if include_csv else None + + s = session(); svc = CrudService(s, default_eager_policy) + obj = svc.get(Model, id) if id else None + + schema = build_form_schema(Model, s, obj=obj, include=include) + + hx = request.args.get("hx", type=int) == 1 + return render_template("form.html", model=model, obj=obj, schema=schema, hx=hx) + + def coerce_form_types(Model, data: dict) -> dict: + """Turn HTML string inputs into the Python types your columns expect.""" + mapper = inspect(Model) + for attr in mapper.column_attrs: + col = attr.columns[0] + name = col.key + if name not in data: + continue + v = data[name] + if v == "": + data[name] = None + continue + t = col.type + try: + if isinstance(t, Boolean): + data[name] = v in ("1", "true", "on", "yes", True) + elif isinstance(t, Integer): + data[name] = int(v) + elif isinstance(t, (Float, Numeric)): + data[name] = float(v) + elif isinstance(t, DateTime): + from datetime import datetime + data[name] = datetime.fromisoformat(v) + elif isinstance(t, Date): + from datetime import date + data[name] = date.fromisoformat(v) + except Exception: + # Leave as string; your validator can complain later. + pass + return data + + @bp.post("//frag/save") + def save(model): + Model = registry.get(model) or abort(404) + s = session(); svc = CrudService(s, default_eager_policy) + + # grab the raw form and fields to re-render + raw = request.form + form = raw.to_dict(flat=True) + fields_csv = form.pop("fields_csv", "id,name") + + # many-to-many lists first + m2m = _extract_m2m_lists(Model, raw) + for rel_name in list(m2m.keys()): + form.pop(f"{rel_name}_ids", None) + + # coerce primitives for regular columns + form = coerce_form_types(Model, form) + + id_val = form.pop("id", None) + + if id_val: + obj = svc.get(Model, int(id_val)) or abort(404) + svc.update(obj, form) + else: + obj = svc.create(Model, form) + + # apply many-to-many selections + mapper = inspect(Model) + for rel_name, id_list in m2m.items(): + rel = mapper.relationships[rel_name] + target = rel.mapper.class_ + selected = [] + if id_list: + selected = s.execute(select(target).where(target.id.in_(id_list))).scalars().all() + coll = getattr(obj, rel_name) + coll.clear() + coll.extend(selected) + + s.commit() + + rows_html = render_template( + "crudkit/row.html", + obj=obj, + fields=[p.strip() for p in fields_csv.split(",") if p.strip()], + getp=_getp, + ) + resp = make_response(rows_html) + if id_val: + resp.headers["HX-Trigger"] = '{"toast":{"level":"success","message":"Updated"}}' + resp.headers["HX-Retarget"] = f"#row-{obj.id}" + resp.headers["HX-Reswap"] = "outerHTML" + else: + resp.headers["HX-Trigger"] = '{"toast":{"level":"success","message":"Created"}}' + resp.headers["HX-Retarget"] = "#rows" + resp.headers["HX-Reswap"] = "beforeend" + return resp + + @bp.get("/_debug//schema") + def debug_model(model): + Model = registry[model] + from sqlalchemy.inspection import inspect + m = inspect(Model) + return { + "columns": [c.key for c in m.columns], + "relationships": [ + { + "key": r.key, + "target": r.mapper.class_.__name__, + "uselist": r.uselist, + "local_cols": [c.key for c in r.local_columns], + } for r in m.relationships + ], + } + return bp + diff --git a/crudkit/mixins.py b/crudkit/mixins.py new file mode 100644 index 0000000..38a0e71 --- /dev/null +++ b/crudkit/mixins.py @@ -0,0 +1,23 @@ +import datetime as dt +from sqlalchemy import Column, Integer, DateTime, Boolean +from sqlalchemy.orm import declared_attr +from sqlalchemy.ext.hybrid import hybrid_property + +class CrudMixin: + id = Column(Integer, primary_key=True) + created_at = Column(DateTime, default=dt.datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow, nullable=False) + deleted = Column("deleted", Boolean, default=False, nullable=False) + version = Column(Integer, default=1, nullable=False) + + @hybrid_property + def is_deleted(self): + return self.deleted + + def mark_deleted(self): + self.deleted = True + self.version += 1 + + @declared_attr + def __mapper_args__(cls): + return {"version_id_col": cls.version} diff --git a/crudkit/serialize.py b/crudkit/serialize.py new file mode 100644 index 0000000..3ba6116 --- /dev/null +++ b/crudkit/serialize.py @@ -0,0 +1,22 @@ +def serialize(obj, *, fields=None, expand=None): + expand = set(expand or []) + fields = set(fields or []) + out = {} + # base columns + for col in obj.__table__.columns: + name = col.key + if fields and name not in fields: + continue + out[name] = getattr(obj, name) + # expansions + for rel in obj.__mapper__.relationships: + if rel.key not in expand: + continue + val = getattr(obj, rel.key) + if val is None: + out[rel.key] = None + elif rel.uselist: + out[rel.key] = [serialize(child) for child in val] + else: + out[rel.key] = serialize(val) + return out \ No newline at end of file diff --git a/crudkit/service.py b/crudkit/service.py new file mode 100644 index 0000000..1950d06 --- /dev/null +++ b/crudkit/service.py @@ -0,0 +1,52 @@ +from sqlalchemy import func +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from .dsl import QuerySpec, build_query +from .eager import default_eager_policy + +class CrudService: + def __init__(self, session: Session, eager_policy=default_eager_policy): + self.s = session + self.eager_policy = eager_policy + + def create(self, Model, data, *, before=None, after=None): + if before: data = before(data) or data + obj = Model(**data) + self.s.add(obj) + self.s.flush() + if after: after(obj) + return obj + + def get(self, Model, id, spec: QuerySpec | None = None): + spec = spec or QuerySpec() + stmt = build_query(Model, spec, self.eager_policy).where(Model.id == id) + return self.s.execute(stmt).scalars().first() + + def list(self, Model, spec: QuerySpec): + stmt = build_query(Model, spec, self.eager_policy) + count_stmt = stmt.with_only_columns(func.count()).order_by(None) + total = self.s.execute(count_stmt).scalar_one() + if spec.page and spec.per_page: + stmt = stmt.limit(spec.per_page).offset((spec.page - 1) * spec.per_page) + rows = self.s.execute(stmt).scalars().all() + return rows, total + + def update(self, obj, data, *, before=None, after=None): + if obj.is_deleted: raise ValueError("Cannot update a deleted record") + if before: data = before(obj, data) or data + for k, v in data.items(): setattr(obj, k, v) + obj.version += 1 + if after: after(obj) + return obj + + def soft_delete(self, obj, *, cascade=False, guard=None): + if guard and not guard(obj): raise ValueError("Delete blocked by guard") + # optionsl FK hygiene checks go here + obj.mark_deleted() + return obj + + def undelete(self, obj): + obj.deleted = False + obj.version += 1 + return obj diff --git a/inventory/__init__.py b/inventory/__init__.py index cb125ab..acd5102 100644 --- a/inventory/__init__.py +++ b/inventory/__init__.py @@ -1,6 +1,7 @@ from flask import Flask, current_app from flask_sqlalchemy import SQLAlchemy from sqlalchemy.engine.url import make_url +from sqlalchemy.orm import sessionmaker import logging import os @@ -23,9 +24,8 @@ def is_in_memory_sqlite(): def create_app(): from config import Config app = Flask(__name__) - app.secret_key = os.getenv('SECRET_KEY', 'dev-secret-key-unsafe') # You know what to do for prod + app.secret_key = os.getenv('SECRET_KEY', 'dev-secret-key-unsafe') app.config.from_object(Config) - db.init_app(app) with app.app_context(): @@ -33,16 +33,25 @@ def create_app(): if is_in_memory_sqlite(): db.create_all() - from .routes import main - from .routes.images import image_bp - from .ui.blueprint import bp as ui_bp - app.register_blueprint(main) - app.register_blueprint(image_bp) - app.register_blueprint(ui_bp) + # ✅ db.engine is only safe to touch inside an app context + SessionLocal = sessionmaker(bind=db.engine, expire_on_commit=False) - from .routes.helpers import generate_breadcrumbs - @app.context_processor - def inject_breadcrumbs(): - return {'breadcrumbs': generate_breadcrumbs()} + from .models import registry + from .routes import main + from .routes.images import image_bp + from .ui.blueprint import bp as ui_bp + from crudkit.blueprint import make_blueprint as make_json_bp + from crudkit.html import make_fragments_blueprint as make_html_bp + + app.register_blueprint(main) + app.register_blueprint(image_bp) + app.register_blueprint(ui_bp) + app.register_blueprint(make_json_bp(SessionLocal, registry), url_prefix="/api") + app.register_blueprint(make_html_bp(SessionLocal, registry), url_prefix="/ui") + + from .routes.helpers import generate_breadcrumbs + @app.context_processor + def inject_breadcrumbs(): + return {'breadcrumbs': generate_breadcrumbs()} return app diff --git a/inventory/models/__init__.py b/inventory/models/__init__.py index cc57a87..83d2e78 100644 --- a/inventory/models/__init__.py +++ b/inventory/models/__init__.py @@ -44,11 +44,25 @@ Room.ui_eagerload = ( selectinload(Room.users) ) + +registry = { + "area": Area, + "brand": Brand, + "image": Image, + "inventory": Inventory, + "item": Item, + "room_function": RoomFunction, + "room": Room, + "user": User, + "work_log": WorkLog, + "work_note": WorkNote +} + __all__ = [ "db", "Image", "ImageAttachable", "RoomFunction", "Room", "Area", "Brand", "Item", "Inventory", "WorkLog", "WorkNote", "worklog_images", - "User", + "User", "registry" ] diff --git a/inventory/models/areas.py b/inventory/models/areas.py index 45ad4f5..3a986c5 100644 --- a/inventory/models/areas.py +++ b/inventory/models/areas.py @@ -2,12 +2,13 @@ from typing import List, Optional, TYPE_CHECKING if TYPE_CHECKING: from .rooms import Room +from crudkit import CrudMixin from sqlalchemy import Identity, Integer, Unicode from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class Area(db.Model): +class Area(db.Model, CrudMixin): __tablename__ = 'area' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/brands.py b/inventory/models/brands.py index d2adf17..e4ba1ac 100644 --- a/inventory/models/brands.py +++ b/inventory/models/brands.py @@ -2,12 +2,13 @@ from typing import List, TYPE_CHECKING if TYPE_CHECKING: from .inventory import Inventory +from crudkit import CrudMixin from sqlalchemy import Identity, Integer, Unicode from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class Brand(db.Model): +class Brand(db.Model, CrudMixin): __tablename__ = 'brand' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/image.py b/inventory/models/image.py index ee61d1b..5714064 100644 --- a/inventory/models/image.py +++ b/inventory/models/image.py @@ -12,7 +12,9 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db from .image_links import worklog_images -class Image(db.Model): +from crudkit import CrudMixin + +class Image(db.Model, CrudMixin): __tablename__ = 'images' id: Mapped[int] = mapped_column(Integer, primary_key=True) diff --git a/inventory/models/inventory.py b/inventory/models/inventory.py index 25c138b..5ea2154 100644 --- a/inventory/models/inventory.py +++ b/inventory/models/inventory.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from .image import Image from .users import User +from crudkit import CrudMixin from sqlalchemy import Boolean, ForeignKey, Identity, Index, Integer, Unicode, DateTime, text from sqlalchemy.orm import Mapped, mapped_column, relationship import datetime @@ -18,7 +19,7 @@ from .brands import Brand from .image import ImageAttachable from .users import User -class Inventory(db.Model, ImageAttachable): +class Inventory(db.Model, ImageAttachable, CrudMixin): __tablename__ = 'inventory' __table_args__ = ( Index('Inventory$Barcode', 'barcode'), diff --git a/inventory/models/items.py b/inventory/models/items.py index 75c06ef..b22baa1 100644 --- a/inventory/models/items.py +++ b/inventory/models/items.py @@ -2,12 +2,13 @@ from typing import List, Optional, TYPE_CHECKING if TYPE_CHECKING: from .inventory import Inventory +from crudkit import CrudMixin from sqlalchemy import Identity, Integer, Unicode from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class Item(db.Model): +class Item(db.Model, CrudMixin): __tablename__ = 'item' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/room_functions.py b/inventory/models/room_functions.py index 72c1a1f..c218f80 100644 --- a/inventory/models/room_functions.py +++ b/inventory/models/room_functions.py @@ -2,12 +2,13 @@ from typing import List, Optional, TYPE_CHECKING if TYPE_CHECKING: from .rooms import Room +from crudkit import CrudMixin from sqlalchemy import Identity, Integer, Unicode from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class RoomFunction(db.Model): +class RoomFunction(db.Model, CrudMixin): __tablename__ = 'room_function' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/rooms.py b/inventory/models/rooms.py index 3360272..0f8df32 100644 --- a/inventory/models/rooms.py +++ b/inventory/models/rooms.py @@ -5,12 +5,13 @@ if TYPE_CHECKING: from .inventory import Inventory from .users import User +from crudkit import CrudMixin from sqlalchemy import ForeignKey, Identity, Integer, Unicode from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class Room(db.Model): +class Room(db.Model, CrudMixin): __tablename__ = 'rooms' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/users.py b/inventory/models/users.py index 36a23e2..325d9df 100644 --- a/inventory/models/users.py +++ b/inventory/models/users.py @@ -5,13 +5,14 @@ if TYPE_CHECKING: from .work_log import WorkLog from .image import Image +from crudkit import CrudMixin from sqlalchemy import Boolean, ForeignKey, Identity, Integer, Unicode, text from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db from .image import ImageAttachable -class User(db.Model, ImageAttachable): +class User(db.Model, ImageAttachable, CrudMixin): __tablename__ = 'users' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/work_log.py b/inventory/models/work_log.py index 796cbd9..c365cb3 100644 --- a/inventory/models/work_log.py +++ b/inventory/models/work_log.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from .users import User from .work_note import WorkNote +from crudkit import CrudMixin from sqlalchemy import Boolean, Identity, Integer, ForeignKey, Unicode, DateTime, text from sqlalchemy.orm import Mapped, mapped_column, relationship import datetime @@ -14,7 +15,7 @@ from .image import ImageAttachable from .image_links import worklog_images from .work_note import WorkNote -class WorkLog(db.Model, ImageAttachable): +class WorkLog(db.Model, ImageAttachable, CrudMixin): __tablename__ = 'work_log' id: Mapped[int] = mapped_column(Integer, Identity(start=1, increment=1), primary_key=True) diff --git a/inventory/models/work_note.py b/inventory/models/work_note.py index ce5a642..d954569 100644 --- a/inventory/models/work_note.py +++ b/inventory/models/work_note.py @@ -1,11 +1,12 @@ import datetime +from crudkit import CrudMixin from sqlalchemy import ForeignKey, DateTime, UnicodeText, func from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db -class WorkNote(db.Model): +class WorkNote(db.Model, CrudMixin): __tablename__ = 'work_note' id: Mapped[int] = mapped_column(primary_key=True)