Compare commits

..

No commits in common. "2d837210c12636c68da1bfa2eb9fb306cbfcb604" and "091db0b443cca4127e7bf1dc0b858dce3a3f457f" have entirely different histories.

52 changed files with 1168 additions and 2800 deletions

View file

@ -1,17 +1,8 @@
from .backend import BackendInfo, make_backend_info
from .config import Config, DevConfig, TestConfig, ProdConfig, get_config, build_database_url
from .engines import CRUDKitRuntime, build_engine, build_sessionmaker
from .integration import CRUDKit
from .mixins import CrudMixin
from .dsl import QuerySpec
from .eager import default_eager_policy
from .service import CrudService
from .serialize import serialize
from .blueprint import make_blueprint
__all__ = [
"Config", "DevConfig", "TestConfig", "ProdConfig", "get_config", "build_database_url",
"CRUDKitRuntime", "build_engine", "build_sessionmaker", "BackendInfo", "make_backend_info"
]
runtime = CRUDKitRuntime()
crud: CRUDKit | None = None
def init_crud(app):
global crud
crud = CRUDKit(app, runtime)
return crud
__all__ = ["CrudMixin", "QuerySpec", "default_eager_policy", "CrudService", "serialize", "make_blueprint"]

View file

@ -1,14 +0,0 @@
from __future__ import annotations
from sqlalchemy import event
from sqlalchemy.engine import Engine
def apply_sqlite_pragmas(engine: Engine, pragmas: dict[str, str]) -> None:
if not str(engine.url).startswith("sqlite://"):
return
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
for key, value in pragmas.items():
cursor.execute(f"PRAGMA {key}={value}")
cursor.close()

View file

@ -1,21 +0,0 @@
import base64, json
from typing import Any
def encode_cursor(values: list[Any] | None, desc_flags: list[bool], backward: bool) -> str | None:
if not values:
return None
payload = {"v": values, "d": desc_flags, "b": backward}
return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
def decode_cursor(token: str | None) -> tuple[list[Any] | None, bool] | tuple[None, bool]:
if not token:
return None, False
try:
obj = json.loads(base64.urlsafe_b64decode(token.encode()).decode())
vals = obj.get("v")
backward = bool(obj.get("b", False))
if isinstance(vals, list):
return vals, backward
except Exception:
pass
return None, False

View file

@ -1,90 +0,0 @@
from flask import Blueprint, jsonify, request
from crudkit.api._cursor import encode_cursor, decode_cursor
from crudkit.core.service import _is_truthy
def generate_crud_blueprint(model, service):
bp = Blueprint(model.__name__.lower(), __name__)
@bp.get('/')
def list_items():
args = request.args.to_dict(flat=True)
# legacy detection
legacy_offset = "offset" in args or "page" in args
# sane limit default
try:
limit = int(args.get("limit", 50))
except Exception:
limit = 50
args["limit"] = limit
if legacy_offset:
# Old behavior: honor limit/offset, same CRUDSpec goodies
items = service.list(args)
return jsonify([obj.as_dict() for obj in items])
# New behavior: keyset seek with cursors
key, backward = decode_cursor(args.get("cursor"))
window = service.seek_window(
args,
key=key,
backward=backward,
include_total=_is_truthy(args.get("include_total", "1")),
)
desc_flags = list(window.order.desc)
body = {
"items": [obj.as_dict() for obj in window.items],
"limit": window.limit,
"next_cursor": encode_cursor(window.last_key, desc_flags, backward=False),
"prev_cursor": encode_cursor(window.first_key, desc_flags, backward=True),
"total": window.total,
}
resp = jsonify(body)
# Optional Link header
links = []
if body["next_cursor"]:
links.append(f'<{request.base_url}?cursor={body["next_cursor"]}&limit={window.limit}>; rel="next"')
if body["prev_cursor"]:
links.append(f'<{request.base_url}?cursor={body["prev_cursor"]}&limit={window.limit}>; rel="prev"')
if links:
resp.headers["Link"] = ", ".join(links)
return resp
@bp.get('/<int:id>')
def get_item(id):
item = service.get(id, request.args)
try:
return jsonify(item.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.post('/')
def create_item():
obj = service.create(request.json)
try:
return jsonify(obj.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.patch('/<int:id>')
def update_item(id):
obj = service.update(id, request.json)
try:
return jsonify(obj.as_dict())
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
@bp.delete('/<int:id>')
def delete_item(id):
service.delete(id)
try:
return jsonify({"status": "success"}), 204
except Exception as e:
return jsonify({"status": "error", "error": str(e)})
return bp

View file

@ -1,122 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Optional, Iterable
from contextlib import contextmanager
from sqlalchemy import text, func
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql import Select
@dataclass(frozen=True)
class BackendInfo:
name: str
version: Tuple[int, ...]
paramstyle: str
is_sqlite: bool
is_postgres: bool
is_mysql: bool
is_mssql: bool
supports_returning: bool
supports_ilike: bool
requires_order_by_for_offset: bool
max_bind_params: Optional[int]
@classmethod
def from_engine(cls, engine: Engine) -> "BackendInfo":
d = engine.dialect
name = d.name
version = tuple(getattr(d, "server_version_info", ()) or ())
is_pg = name in {"postgresql", "postgres"}
is_my = name == "mysql"
is_sq = name == "sqlite"
is_ms = name == "mssql"
supports_ilike = is_pg or is_my
supports_returning = is_pg or (is_sq and version >= (3, 35))
requires_order_by_for_offset = is_ms
max_bind_params = 999 if is_sq else None
return cls(
name=name,
version=version,
paramstyle=d.paramstyle,
is_sqlite=is_sq,
is_postgres=is_pg,
is_mysql=is_my,
is_mssql=is_ms,
supports_returning=supports_returning,
supports_ilike=supports_ilike,
requires_order_by_for_offset=requires_order_by_for_offset,
max_bind_params=max_bind_params,
)
def make_backend_info(engine: Engine) -> BackendInfo:
return BackendInfo.from_engine(engine)
def ci_like(column, value: str, backend: BackendInfo) -> ClauseElement:
"""
Portable save-insensitive LIKE.
Uses ILIKE where available, else lower() dance.
"""
pattern = f"%{value}%"
if backend.supports_ilike:
return column.ilike(pattern)
return func.lower(column).like(func.lower(text(":pattern"))).params(pattern=pattern)
def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: int, default_order_by=None) -> Select:
"""
Portable pagination. MSSQL requires ORDER BY when using OFFSET
"""
page = max(1, int(page))
per_page = max(1, int(per_page))
offset = (page - 1) * per_page
if backend.requires_order_by_for_offset and not sel._order_by_clauses:
if default_order_by is None:
sel = sel.order_by(text("1"))
else:
sel = sel.order_by(default_order_by)
return sel.limit(per_page).offset(offset)
@contextmanager
def maybe_identify_insert(session: Session, table, backend: BackendInfo):
"""
For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs.
No-op elsewhere.
"""
if not backend.is_mssql:
yield
return
full_name = f"{table.schema}.{table.name}" if table.schema else table.name
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON"))
try:
yield
finally:
session.execute(text(f"SET IDENTITY_INSERT {full_name} OFF"))
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
"""
Build a safe large IN() filter respecting bund param limits.
Returns a disjunction of chunked IN clauses if needed.
"""
vals = list(values)
if not vals:
return text("1=0")
limit = chunk_size or backend.max_bind_params or len(vals)
if len(vals) <= limit:
return column.in_(vals)
parts = []
for i in range(0, len(vals), limit):
parts.append(column.in_(vals[i:i + limit]))
expr = parts[0]
for p in parts[1:]:
expr = expr | p
return expr

81
crudkit/blueprint.py Normal file
View file

@ -0,0 +1,81 @@
from flask import Blueprint, request, jsonify, abort
from sqlalchemy.orm import scoped_session
from .dsl import QuerySpec
from .service import CrudService
from .eager import default_eager_policy
from .serialize import serialize
def make_blueprint(db_session_factory, registry):
bp = Blueprint("crud", __name__)
def session(): return scoped_session(db_session_factory)()
@bp.get("/<model>/list")
def list_items(model):
Model = registry.get(model) or abort(404)
spec = QuerySpec(
filters=_parse_filters(request.args),
order_by=request.args.getlist("sort"),
page=request.args.get("page", type=int),
per_page=request.args.get("per_page", type=int),
expand=request.args.getlist("expand"),
fields=request.args.get("fields", type=lambda s: [x.strip() for x in s.split(",")] if s else None),
)
s = session(); svc = CrudService(s, default_eager_policy)
rows, total = svc.list(Model, spec)
data = [serialize(r, fields=spec.fields, expand=spec.expand) for r in rows]
return jsonify({"data": data, "total": total})
@bp.post("/<model>")
def create_item(model):
Model = registry.get(model) or abort(404)
payload = request.get_json() or {}
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.create(Model, payload)
s.commit()
return jsonify(serialize(obj)), 201
@bp.get("/<model>/<int:id>")
def read_item(model, id):
Model = registry.get(model) or abort(404)
spec = QuerySpec(expand=request.args.getlist("expand"),
fields=request.args.get("fields", type=lambda s: s.split(",")))
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.get(Model, id, spec) or abort(404)
return jsonify(serialize(obj, fields=spec.fields, expand=spec.expand))
@bp.patch("/<model>/<int:id>")
def update_item(model, id):
Model = registry.get(model) or abort(404)
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.get(Model, id, QuerySpec()) or abort(404)
payload = request.get_json() or {}
svc.update(obj, payload)
s.commit()
return jsonify(serialize(obj))
@bp.delete("/<model>/<int:id>")
def delete_item(model, id):
Model = registry.get(model) or abort(404)
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.get(Model, id, QuerySpec()) or abort(404)
svc.soft_delete(obj)
s.commit()
return jsonify({"status": "deleted"})
@bp.post("/<model>/<int:id>/undelete")
def undelete_item(model, id):
Model = registry.get(model) or abort(404)
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.get(Model, id, QuerySpec()) or abort(404)
svc.undelete(obj)
s.commit()
return jsonify({"status": "restored"})
return bp
def _parse_filters(args):
out = {}
for k, v in args.items():
if k in {"page", "per_page", "sort", "expand", "fields"}:
continue
out[k] = v
return out

View file

@ -1,243 +0,0 @@
from __future__ import annotations
import os
import os
from typing import Dict, Any, Optional, Type
from urllib.parse import quote_plus
from pathlib import Path
try:
from dotenv import load_dotenv
except Exception:
load_dotenv = None
def _load_dotenv_if_present() -> None:
"""
Load .env once if present. Priority rules:
1) CRUDKIT_DOTENV points to a file
2) Project root's .env (two dirs up from this file)
3) Current working directory .env
Env already present in the process takes precedence.
"""
if load_dotenv is None:
return
path_hint = os.getenv("CRUDKIT_DOTENV")
if path_hint:
p = Path(path_hint).resolve()
if p.exists():
load_dotenv(dotenv_path=p, override=True)
os.environ["CRUDKIT_DOTENV_LOADED"] = str(p)
return
repo_env = Path(__file__).resolve().parents[1] / ".env"
if repo_env.exists():
load_dotenv(dotenv_path=repo_env, override=True)
os.environ["CRUDKIT_DOTENV_LOADED"] = str(repo_env)
return
cwd_env = Path.cwd() / ".env"
if cwd_env.exists():
load_dotenv(dotenv_path=cwd_env, override=True)
os.environ["CRUDKIT_DOTENV_LOADED"] = str(cwd_env)
def _getenv(name: str, default: Optional[str] = None) -> Optional[str]:
"""Treat empty strings as missing. Hekos when OS env has DB_BACKEND=''."""
val = os.getenv(name)
if val is None or val.strip() == "":
return default
return val
def build_database_url(
*,
backend: Optional[str] = None,
url: Optional[str] = None,
user: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
database: Optional[str] = None,
driver: Optional[str] = None,
dsn: Optional[str] = None,
trusted: Optional[bool] = None,
options: Optional[Dict[str, str]] = None,
) -> str:
"""
Build a SQLAlchemy URL string. If "url" is provided, it wins.
Supported: sqlite, postgresql, mysql, mssql (pyodbc)
"""
if url:
return url
backend = (backend or "").lower().strip()
options = options or {}
if backend == "sqlite":
db_path = database or "app.db"
if db_path == ":memory:":
return "sqlite:///:memory:"
return f"sqlite:///{db_path}"
if backend in {"postgres", "postgresql"}:
driver = driver or "psycopg"
user = user or ""
password = password or ""
creds = f"{quote_plus(user)}:{quote_plus(password)}@" if user or password else ""
host = host or "localhost"
port = port or "5432"
database = database or "app"
qs = ""
if options:
qs = "?" + "&".join(f"{k}={quote_plus(v)}" for k, v in options.items())
return f"postgresql+{driver}://{creds}{host}:{port}/{database}{qs}"
if backend == "mysql":
driver = driver or "pymysql"
user = user or ""
password = password or ""
creds = f"{quote_plus(user)}:{quote_plus(password)}@" if user or password else ""
host = host or "localhost"
port = port or "3306"
database = database or "app"
qs = ""
if options:
qs = "?" + "&".join(f"{k}={quote_plus(v)}" for k, v in options.items())
return f"mysql+{driver}://{creds}{host}:{port}/{database}{qs}"
if backend in {"mssql", "sqlserver", "sqlsrv"}:
if dsn:
qs = ""
if options:
qs = "?" + "&".join(f"{k}={quote_plus(v)}" for k, v in options.items())
return f"mssql+pyodbc://@{quote_plus(dsn)}{qs}"
driver = driver or "ODBC Driver 18 for SQL Server"
host = host or "localhost"
port = port or "1433"
database = database or "app"
if trusted:
base_opts = {
"driver": driver,
"Trusted_Connection": "yes",
"Encrypt": "yes",
"TrustServerCertificate": "yes",
"MARS_Connection": "yes",
}
base_opts.update(options)
qs = "?" + "&".join(f"{k}={quote_plus(v)}" for k, v in base_opts.items())
return f"mssql+pyodbc://{host}:{port}/{database}{qs}"
user = user or ""
password = password or ""
creds = f"{quote_plus(user)}:{quote_plus(password)}@" if user or password else ""
base_opts = {
"driver": driver,
"Encrypt": "yes",
"TrustServerCertificate": "yes",
"MARS_Connection": "yes",
}
base_opts.update(options)
qs = "?" + "&".join(f"{k}={quote_plus(v)}" for k, v in base_opts.items())
return f"mssql+pyodbc://{creds}{host}:{port}/{database}{qs}"
raise ValueError(f"Unsupported backend: {backend!r}")
class Config:
"""
CRUDKit config: environment-first with sane defaults.
Designed to be subclassed by apps, but fine as-is.
"""
_dotenv_loaded = False
DEBUG = False
TESTING = False
SECRET_KEY = _getenv("SECRET_KEY", "dev-not-secret")
if not _dotenv_loaded:
_load_dotenv_if_present()
_dotenv_loaded = True
DATABASE_URL = build_database_url(
url=_getenv("DATABASE_URL"),
backend=_getenv("DB_BACKEND"),
user=_getenv("DB_USER"),
password=_getenv("DB_PASS"),
host=_getenv("DB_HOST"),
port=_getenv("DB_PORT"),
database=_getenv("DB_NAME"),
driver=_getenv("DB_DRIVER"),
dsn=_getenv("DB_DSN"),
trusted=bool(int(_getenv("DB_TRUSTED", "0"))),
options=None,
)
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0")))
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "5"))
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1000"))
POOL_PRE_PING = True
SQLITE_PRAGMAS = {
"journal_mode": os.getenv("SQLITE_JOURNAL_MODE", "WAL"),
"foreign_keys": os.getenv("SQLITE_FOREIGN_KEYS", "ON"),
"synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"),
}
@classmethod
def engine_kwargs(cls) -> Dict[str, Any]:
url = cls.DATABASE_URL
kwargs: Dict[str, Any] = {
"echo": cls.SQLALCHEMY_ECHO,
"pool_pre_ping": cls.POOL_PRE_PING,
"future": True,
}
if url.startswith("sqlite://"):
kwargs["connect_args"] = {"check_same_thread": False}
kwargs.update(
{
"pool_size": cls.POOL_SIZE,
"max_overflow": cls.MAX_OVERFLOW,
"pool_timeout": cls.POOL_TIMEOUT,
"pool_recycle": cls.POOL_RECYCLE,
}
)
return kwargs
@classmethod
def session_kwargs(cls) -> Dict[str, Any]:
return {
"autoflush": False,
"autocommit": False,
"expire_on_commit": False,
"future": True,
}
class DevConfig(Config):
DEBUG = True
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1")))
class TestConfig(Config):
TESTING = True
DATABASE_URL = build_database_url(backend="sqlite", database=":memory:")
SQLALCHEMY_ECHO = False
class ProdConfig(Config):
DEBUG = False
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0")))
def get_config(name: str | None) -> Type[Config]:
"""
Resolve config by name. None -> environment variable CRUDKIT_ENV or 'dev'.
"""
env = (name or os.getenv("CRUDKIT_ENV") or "dev").lower()
if env in {"prod", "production"}:
return ProdConfig
if env in {"test", "testing", "ci"}:
return TestConfig
return DevConfig

View file

@ -1,56 +0,0 @@
from sqlalchemy import Column, Integer, DateTime, Boolean, String, JSON, func
from sqlalchemy.orm import declarative_mixin, declarative_base
Base = declarative_base()
@declarative_mixin
class CRUDMixin:
id = Column(Integer, primary_key=True)
created_at = Column(DateTime, default=func.now(), nullable=False)
updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
def as_dict(self, fields: list[str] | None = None):
"""
Serialize the instance.
- If 'fields' (possibly dotted) is provided, emit exactly those keys.
- Else, if '__crudkit_projection__' is set on the instance, emit those keys.
- Else, fall back to all mapped columns on this class hierarchy.
Always includes 'id' when present unless explicitly excluded.
"""
if fields is None:
fields = getattr(self, "__crudkit_projection__", None)
if fields:
out = {}
if "id" not in fields and hasattr(self, "id"):
out["id"] = getattr(self, "id")
for f in fields:
cur = self
for part in f.split("."):
if cur is None:
break
cur = getattr(cur, part, None)
out[f] = cur
return out
result = {}
for cls in self.__class__.__mro__:
if hasattr(cls, "__table__"):
for column in cls.__table__.columns:
name = column.name
result[name] = getattr(self, name)
return result
class Version(Base):
__tablename__ = "versions"
id = Column(Integer, primary_key=True)
model_name = Column(String, nullable=False)
object_id = Column(Integer, nullable=False)
change_type = Column(String, nullable=False)
data = Column(JSON, nullable=True)
timestamp = Column(DateTime, default=func.now())
actor = Column(String, nullable=True)
meta = Column('metadata', JSON, nullable=True)

View file

@ -1,556 +0,0 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
"""
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
and only fetch the related PK. This is enough for preselecting <select> inputs
without projecting the FK column on the root model.
"""
opts: list[Load] = []
if not fields:
return opts
mapper = class_mapper(model_cls)
for name in fields:
prop = mapper.relationships.get(name)
if not isinstance(prop, RelationshipProperty):
continue
if prop.direction.name != "MANYTOONE":
continue
rel_attr = getattr(root_alias, name)
target_cls = prop.mapper.class_
# load_only PK if present; else just selectinload
id_attr = getattr(target_cls, "id", None)
if id_attr is not None:
opts.append(selectinload(rel_attr).load_only(id_attr))
else:
opts.append(selectinload(rel_attr))
return opts
@runtime_checkable
class _HasID(Protocol):
id: int
@runtime_checkable
class _HasTable(Protocol):
__table__: Any
@runtime_checkable
class _HasADict(Protocol):
def as_dict(self) -> dict: ...
@runtime_checkable
class _SoftDeletable(Protocol):
is_deleted: bool
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
"""Minimal surface that our CRUD service relies on. Soft-delete is optional."""
pass
T = TypeVar("T", bound=_CRUDModelProto)
def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
class CRUDService(Generic[T]):
def __init__(
self,
model: Type[T],
session_factory: Callable[[], Session],
polymorphic: bool = False,
*,
backend: Optional[BackendInfo] = None
):
self.model = model
self._session_factory = session_factory
self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted')
# Cache backend info once. If not provided, derive from session bind.
bind = self.session.get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self.backend = backend or make_backend_info(eng)
@property
def session(self) -> Session:
return self._session_factory()
def get_query(self):
if self.polymorphic:
poly = with_polymorphic(self.model, "*")
return self.session.query(poly), poly
return self.session.query(self.model), self.model
def _resolve_required_includes(self, root_alias: Any, rel_field_names: Dict[Tuple[str, ...], List[str]]) -> List[Any]:
"""
For each dotted path like ("location"), -> ["label"], look up the target
model's __crudkit_field_requires__ for the terminal field and produce
selectinload options prefixed with the relationship path, e.g.:
Room.__crudkit_field_requires__['label'] = ['room_function']
=> selectinload(root.location).selectinload(Room.room_function)
"""
opts: List[Any] = []
root_mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for path, names in (rel_field_names or {}).items():
if not path:
continue
current_alias = root_alias
current_mapper = root_mapper
rel_props: List[RelationshipProperty] = []
valid = True
for step in path:
rel = current_mapper.relationships.get(step)
if rel is None:
valid = False
break
rel_props.append(rel)
current_mapper = cast(Mapper[Any], inspect(rel.entity.entity))
if not valid:
continue
target_cls = current_mapper.class_
requires = getattr(target_cls, "__crudkit_field_requires__", None)
if not isinstance(requires, dict):
continue
for field_name in names:
needed: Iterable[str] = requires.get(field_name, [])
for rel_need in needed:
loader = selectinload(getattr(root_alias, rel_props[0].key))
for rp in rel_props[1:]:
loader = loader.selectinload(getattr(getattr(root_alias, rp.parent.class_.__name__.lower(), None) or rp.parent.class_, rp.key))
loader = loader.selectinload(getattr(target_cls, rel_need))
opts.append(loader)
return opts
def _extract_order_spec(self, root_alias, given_order_by):
"""
SQLAlchemy 2.x only:
Normalize order_by into (cols, desc_flags). Supports plain columns and
col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
"""
from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import UnaryExpression
given = self._stable_order_by(root_alias, given_order_by)
cols, desc_flags = [], []
for ob in given:
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
elem = getattr(ob, "element", None)
col = elem if elem is not None else ob # don't use "or" with SA expressions
# Detect direction in SA 2.x
is_desc = False
dir_attr = getattr(ob, "_direction", None)
if dir_attr is not None:
is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
elif isinstance(ob, UnaryExpression):
op = getattr(ob, "operator", None)
is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")
cols.append(col)
desc_flags.append(bool(is_desc))
from crudkit.core.types import OrderSpec
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
"""
Build lexicographic predicate for keyset seek.
For backward traversal, import comparisons.
"""
if not key_vals:
return None
conds = []
for i, col in enumerate(spec.cols):
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i]
if not backward:
op = col < key_vals[i] if is_desc else col > key_vals[i]
else:
op = col > key_vals[i] if is_desc else col < key_vals[i]
conds.append(and_(*ties, op))
return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
out = []
for c in spec.cols:
key = getattr(c, "key", None) or getattr(c, "name", None)
out.append(getattr(obj, key))
return out
def seek_window(
self,
params: dict | None = None,
*,
key: list[Any] | None = None,
backward: bool = False,
include_total: bool = True,
) -> "SeekWindow[T]":
"""
Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
- filters, includes, joins, field projection, eager loading, soft-delete
- deterministic ordering (user sort + PK tiebreakers)
- forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
"""
session = self.session
query, root_alias = self.get_query()
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
root_fields, rel_field_names, root_field_names = spec.parse_fields()
for path, names in (rel_field_names or {}).items():
if "label" in names:
rel_name = path[0]
rel_attr = getattr(root_alias, rel_name, None)
if rel_attr is not None:
query = query.options(selectinload(rel_attr))
# Soft delete filter
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False)
# Parse filters first
if filters:
query = query.filter(*filters)
# Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
# Fields/projection: load_only for root columns, eager loads for relationships
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
for opt in self._resolve_required_includes(root_alias, rel_field_names):
query = query.options(opt)
# Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination()
if not limit or limit <= 0:
limit = 50 # sensible default
# Keyset predicate
if key:
pred = self._key_predicate(order_spec, key, backward)
if pred is not None:
query = query.filter(pred)
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
if not backward:
clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
clauses.append(col.desc() if is_desc else col.asc())
query = query.order_by(*clauses).limit(limit)
items = query.all()
else:
inv_clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
inv_clauses.append(col.asc() if is_desc else col.desc())
query = query.order_by(*inv_clauses).limit(limit)
items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
for obj in items:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
# Boundary keys for cursor encoding in the API layer
first_key = self._pluck_key(items[0], order_spec) if items else None
last_key = self._pluck_key(items[-1], order_spec) if items else None
# Optional total thats safe under JOINs (COUNT DISTINCT ids)
total = None
if include_total:
base = self.session.query(getattr(root_alias, "id"))
if self.supports_soft_delete and not _is_truthy(params.get("include_deleted")):
base = base.filter(getattr(root_alias, "is_deleted") == False)
if filters:
base = base.filter(*filters)
# replicate the same joins used above
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True)
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow(
items=items,
limit=limit,
first_key=first_key,
last_key=last_key,
order=order_spec,
total=total,
)
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias):
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
cols = []
for col in mapper.primary_key:
try:
cols.append(getattr(root_alias, col.key))
except AttributeError:
cols.append(col)
return cols or [text("1")]
def _stable_order_by(self, root_alias, given_order_by):
"""
Ensure deterministic ordering by appending PK columns as tiebreakers.
If no order is provided, fall back to default primary-key order.
"""
order_by = list(given_order_by or [])
if not order_by:
return self._default_order_by(root_alias)
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
pk_cols = []
for col in mapper.primary_key:
try:
pk_cols.append(getattr(root_alias, col.key))
except ArithmeticError:
pk_cols.append(col)
return [*order_by, *pk_cols]
def get(self, id: int, params=None) -> T | None:
query, root_alias = self.get_query()
include_deleted = False
root_fields = []
root_field_names = {}
rel_field_names = {}
spec = CRUDSpec(self.model, params or {}, root_alias)
if params:
if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if self.supports_soft_delete and not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
query = query.filter(getattr(root_alias, "id") == id)
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
fields = (params or {}).get("fields") if isinstance(params, dict) else None
if fields:
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
obj = query.first()
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
return obj or None
def list(self, params=None) -> list[T]:
query, root_alias = self.get_query()
root_fields = []
root_field_names = {}
rel_field_names = {}
if params:
if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
limit, offset = spec.parse_pagination()
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
query = query.join(target, rel_attr.of_type(target), isouter=True)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
if filters:
query = query.filter(*filters)
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias)
if order_by:
query = query.order_by(*order_by)
# Only apply offset/limit when not None.
if offset is not None and offset != 0:
query = query.offset(offset)
if limit is not None and limit > 0:
query = query.limit(limit)
rows = query.all()
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
for obj in rows:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
return rows
def create(self, data: dict, actor=None) -> T:
obj = self.model(**data)
self.session.add(obj)
self.session.commit()
self._log_version("create", obj, actor)
return obj
def update(self, id: int, data: dict, actor=None) -> T:
obj = self.get(id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns}
for k, v in data.items():
if k in valid_fields:
setattr(obj, k, v)
self.session.commit()
self._log_version("update", obj, actor)
return obj
def delete(self, id: int, hard: bool = False, actor = False):
obj = self.session.get(self.model, id)
if not obj:
return None
if hard or not self.supports_soft_delete:
self.session.delete(obj)
else:
soft = cast(_SoftDeletable, obj)
soft.is_deleted = True
self.session.commit()
self._log_version("delete", obj, actor)
return obj
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}):
try:
data = obj.as_dict()
except Exception:
data = {"error": "Failed to serialize object."}
version = Version(
model_name=self.model.__name__,
object_id=obj.id,
change_type=change_type,
data=data,
actor=str(actor) if actor else None,
meta=metadata
)
self.session.add(version)
self.session.commit()

View file

@ -1,189 +0,0 @@
from typing import List, Tuple, Set, Dict, Optional
from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
OPERATORS = {
'eq': lambda col, val: col == val,
'lt': lambda col, val: col < val,
'lte': lambda col, val: col <= val,
'gt': lambda col, val: col > val,
'gte': lambda col, val: col >= val,
'ne': lambda col, val: col != val,
'icontains': lambda col, val: col.ilike(f"%{val}%"),
}
class CRUDSpec:
def __init__(self, model, params, root_alias):
self.model = model
self.params = params
self.root_alias = root_alias
self.eager_paths: Set[Tuple[str, ...]] = set()
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = []
self._rel_field_names: Dict[Tuple[str, ...], object] = {}
self.include_paths: Set[Tuple[str, ...]] = set()
def _resolve_column(self, path: str):
current_alias = self.root_alias
parts = path.split('.')
join_path: list[str] = []
for i, attr in enumerate(parts):
try:
attr_obj = getattr(current_alias, attr)
except AttributeError:
return None, None
prop = getattr(attr_obj, "property", None)
if prop is not None and hasattr(prop, "direction"):
join_path.append(attr)
path_key = tuple(join_path)
alias = self.alias_map.get(path_key)
if not alias:
alias = aliased(prop.mapper.class_)
self.alias_map[path_key] = alias
self.join_paths.append((current_alias, attr_obj, alias))
current_alias = alias
continue
if isinstance(attr_obj, InstrumentedAttribute) or getattr(attr_obj, "comparator", None) is not None or hasattr(attr_obj, "clauses"):
return attr_obj, tuple(join_path) if join_path else None
return None, None
def parse_includes(self):
raw = self.params.get('include')
if not raw:
return
tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
for token in tokens:
_, join_path = self._resolve_column(token)
if join_path:
self.eager_paths.add(join_path)
else:
col, maybe = self._resolve_column(token + '.id')
if maybe:
self.eager_paths.add(maybe)
def parse_filters(self):
filters = []
for key, value in self.params.items():
if key in ('sort', 'limit', 'offset'):
continue
if '__' in key:
path_op = key.rsplit('__', 1)
if len(path_op) != 2:
continue
path, op = path_op
else:
path, op = key, 'eq'
col, join_path = self._resolve_column(path)
if col and op in OPERATORS:
filters.append(OPERATORS[op](col, value))
if join_path:
self.eager_paths.add(join_path)
return filters
def parse_sort(self):
sort_args = self.params.get('sort', '')
result = []
for part in sort_args.split(','):
part = part.strip()
if not part:
continue
if part.startswith('-'):
field = part[1:]
order = desc
else:
field = part
order = asc
col, join_path = self._resolve_column(field)
if col:
result.append(order(col))
if join_path:
self.eager_paths.add(join_path)
return result
def parse_pagination(self):
limit = int(self.params.get('limit', 100))
offset = int(self.params.get('offset', 0))
return limit, offset
def parse_fields(self):
"""
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
Returns (root_fields, rel_field_names).
"""
raw = self.params.get('fields')
if not raw:
return [], {}, {}
if isinstance(raw, list):
tokens = []
for chunk in raw:
tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip())
else:
tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
for token in tokens:
col, join_path = self._resolve_column(token)
if not col:
continue
if join_path:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else:
root_fields.append(col)
root_field_names.append(getattr(col, "key", token))
seen = set()
root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))]
for k, names in rel_field_names.items():
seen2 = set()
rel_field_names[k] = [n for n in names if not (n in seen2 or seen2.add(n))]
self._root_fields = root_fields
self._rel_field_names = rel_field_names
return root_fields, rel_field_names, root_field_names
def get_eager_loads(self, root_alias, *, fields_map=None):
loads = []
for path in self.eager_paths:
current = root_alias
loader = None
for idx, name in enumerate(path):
rel_attr = getattr(current, name)
loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)
# step to target class for the next hop
target_cls = rel_attr.property.mapper.class_
current = target_cls
# if final hop and we have a fields map, narrow columns
if fields_map and idx == len(path) - 1 and path in fields_map:
cols = []
for n in fields_map[path]:
attr = getattr(target_cls, n, None)
# Only include real column attributes; skip hybrids/expressions
if isinstance(attr, InstrumentedAttribute):
cols.append(attr)
# Only apply load_only if we have at least one real column
if cols:
loader = loader.load_only(*cols)
if loader is not None:
loads.append(loader)
return loads
def get_join_paths(self):
return self.join_paths

View file

@ -1,16 +0,0 @@
from dataclasses import dataclass
from typing import Any, Sequence
@dataclass(frozen=True)
class OrderSpec:
cols: Sequence[Any]
desc: Sequence[bool]
@dataclass
class SeekWindow:
items: list[Any]
limit: int
first_key: list[Any] | None
last_key: list[Any] | None
order: OrderSpec
total: int | None = None

147
crudkit/dsl.py Normal file
View file

@ -0,0 +1,147 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from sqlalchemy import asc, desc, select, false
from sqlalchemy.inspection import inspect
@dataclass
class QuerySpec:
filters: Dict[str, Any] = field(default_factory=dict)
order_by: List[str] = field(default_factory=list)
page: Optional[int] = None
per_page: Optional[int] = None
expand: List[str] = field(default_factory=list)
fields: Optional[List[str]] = None
FILTER_OPS = {
"__eq": lambda c, v: c == v,
"__ne": lambda c, v: c != v,
"__lt": lambda c, v: c < v,
"__lte": lambda c, v: c <= v,
"__gt": lambda c, v: c > v,
"__gte": lambda c, v: c >= v,
"__ilike": lambda c, v: c.ilike(v),
"__in": lambda c, v: c.in_(v),
"__isnull": lambda c, v: (c.is_(None) if v else c.is_not(None))
}
def _split_filter_key(raw_key: str):
for op in sorted(FILTER_OPS.keys(), key=len, reverse=True):
if raw_key.endswith(op):
return raw_key[: -len(op)], op
return raw_key, None
def _ensure_wildcards(op_key, value):
if op_key == "__ilike" and isinstance(value, str) and "%" not in value and "_" not in value:
return f"%{value}%"
return value
def _related_predicate(Model, path_parts, op_key, value):
"""
Build EXISTS subqueries for dotted filters:
- scalar rels -> attr.has(inner_predicate)
- collection -> attr.any(inner_predicate)
"""
head, *rest = path_parts
# class-bound relationship attribute (InstrumentedAttribute)
attr = getattr(Model, head, None)
if attr is None:
return None
# relationship metadata if you need uselist + target model
rel = inspect(Model).relationships.get(head)
if rel is None:
return None
Target = rel.mapper.class_
if not rest:
# filtering directly on a relationship without a leaf column isn't supported
return None
if len(rest) == 1:
# final hop is a column on the related model
leaf = rest[0]
col = getattr(Target, leaf, None)
if col is None:
return None
pred = FILTER_OPS[op_key](col, value) if op_key else (col == value)
else:
# recurse deeper: owner.room.area.name__ilike=...
pred = _related_predicate(Target, rest, op_key, value)
if pred is None:
return None
# wrap at this hop using the *attribute*, not the RelationshipProperty
return attr.any(pred) if rel.uselist else attr.has(pred)
def split_sort_tokens(tokens):
simple, dotted = [], []
for tok in (tokens or []):
if not tok:
continue
key = tok.lstrip("-")
if ":" in key:
key = key.split(":", 1)[0]
(dotted if "." in key else simple).append(tok)
return simple, dotted
def build_query(Model, spec: QuerySpec, eager_policy=None):
stmt = select(Model)
# filter out soft-deleted rows
deleted_attr = getattr(Model, "deleted", None)
if deleted_attr is not None:
stmt = stmt.where(deleted_attr == false())
else:
is_deleted_attr = getattr(Model, "is_deleted", None)
if is_deleted_attr is not None:
stmt = stmt.where(is_deleted_attr == false())
# filters
for raw_key, val in spec.filters.items():
path, op_key = _split_filter_key(raw_key)
val = _ensure_wildcards(op_key, val)
if "." in path:
pred = _related_predicate(Model, path.split("."), op_key, val)
if pred is not None:
stmt = stmt.where(pred)
continue
col = getattr(Model, path, None)
if col is None:
continue
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
simple_sorts, _ = split_sort_tokens(spec.order_by)
for token in simple_sorts:
direction = "asc"
key = token
if token.startswith("-"):
direction = "desc"
key = token[1:]
if ":" in key:
key, d = key.rsplit(":", 1)
direction = "desc" if d.lower().startswith("d") else "asc"
if "." in key:
continue
col = getattr(Model, key, None)
if col is None:
continue
stmt = stmt.order_by(desc(col) if direction == "desc" else asc(col))
if not spec.order_by and spec.page and spec.per_page:
pk_cols = inspect(Model).primary_key
if pk_cols:
stmt = stmt.order_by(*(asc(c) for c in pk_cols))
# eager loading
if eager_policy:
opts = eager_policy(Model, spec.expand)
if opts:
stmt = stmt.options(*opts)
return stmt

75
crudkit/eager.py Normal file
View file

@ -0,0 +1,75 @@
from __future__ import annotations
from typing import Iterable, List, Sequence, Set
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty
class EagerConfig:
def __init__(self, strict: bool = False, max_depth: int = 4):
self.strict = strict
self.max_depth = max_depth
def _rel(cls, name: str) -> RelationshipProperty | None:
return inspect(cls).relationships.get(name)
def _is_expandable(rel: RelationshipProperty) -> bool:
# Skip dynamic or viewonly collections; they dont support eagerload
return rel.lazy != "dynamic"
def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]:
"""
Heuristic:
- many-to-one / one-to-one: joinedload
- one-to-many / many-to-many: selectinload
Accepts dotted paths like "author.publisher".
"""
if not expand:
return []
cfg = cfg or EagerConfig()
# normalize, dedupe, and prefer longer paths over their prefixes
raw: Set[str] = {p.strip() for p in expand if p and p.strip()}
# drop prefixes if a longer path exists (author, author.publisher -> keep only author.publisher)
pruned: Set[str] = set(raw)
for p in raw:
parts = p.split(".")
for i in range(1, len(parts)):
pruned.discard(".".join(parts[:i]))
opts: List[Load] = []
seen: Set[tuple] = set()
for path in sorted(pruned):
parts = path.split(".")
if len(parts) > cfg.max_depth:
if cfg.strict:
raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})")
continue
current_model = Model
# build the chain incrementally
loader: Load | None = None
ok = True
for i, name in enumerate(parts):
rel = _rel(current_model, name)
if not rel or not _is_expandable(rel):
ok = False
break
attr = getattr(current_model, name)
if loader is None:
loader = selectinload(attr) if rel.uselist else joinedload(attr)
else:
loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr)
current_model = rel.mapper.class_
if not ok:
if cfg.strict:
raise ValueError(f"unknown or non-expandable relationship in expand path: {path}")
continue
key = (tuple(parts),)
if loader is not None and key not in seen:
opts.append(loader)
seen.add(key)
return opts

View file

@ -1,50 +0,0 @@
from __future__ import annotations
from typing import Type, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .backend import make_backend_info, BackendInfo
from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas
def build_engine(config_cls: Type[Config] | None = None):
config_cls = config_cls or get_config(None)
engine = create_engine(config_cls.DATABASE_URL, **config_cls.engine_kwargs())
apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS)
return engine
def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None):
config_cls = config_cls or get_config(None)
engine = engine or build_engine(config_cls)
return sessionmaker(bind=engine, **config_cls.session_kwargs())
class CRUDKitRuntime:
"""
Lightweight container so CRUDKit can be given either:
- prebuild engine/sessionmaker, or
- a Config to build them lazily
"""
def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None):
if engine is None and session_factory is None and config is None:
config = get_config(None)
self._config = config
self._engine = engine or (build_engine(config) if config else None)
self._session_factory = session_factory or (build_sessionmaker(config, self._engine) if config else None)
@property
def engine(self):
if self._engine is None and self._config:
self._engine = build_engine(self._config)
return self._engine
@property
def session_factory(self):
if self._session_factory is None:
if self._config and self._engine:
self._session_factory = build_sessionmaker(self._config, self._engine)
return self._session_factory
@property
def backend(self) -> BackendInfo:
if not hasattr(self, "_backend_info") or self._backend_info is None:
self._backend_info = make_backend_info(self.engine)
return self._backend_info

3
crudkit/html/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from .ui_fragments import make_fragments_blueprint
__all__ = ["make_fragments_blueprint"]

View file

@ -0,0 +1,140 @@
{% macro options(items, value_attr="id", label_path="name", getp=None) -%}
{%- for obj in items -%}
<option value="{{ getp(obj, value_attr) }}">{{ getp(obj, label_path) }}</option>
{%- endfor -%}
{% endmacro %}
{% macro lis(items, label_path="name", sublabel_path=None, getp=None) -%}
{%- for obj in items -%}
<li data-id="{{ obj.id }}">
<div class="li-main">{{ getp(obj, label_path) }}</div>
{%- if sublabel_path %}
<div class="li-sub">{{ getp(obj, sublabel_path) }}</div>
{%- endif %}
</li>
{%- else -%}
<li class="empty"><em>No results.</em></li>
{%- endfor -%}
{% endmacro %}
{% macro rows(items, fields, getp=None) -%}
{%- for obj in items -%}
<tr id="row-{{ obj.id }}">
{%- for f in fields -%}
<td data-field="{{ f }}">{{ getp(obj, f) }}</td>
{%- endfor -%}
</tr>
{%- else -%}
<tr>
<td colspan="{{ fields|length }}"><em>No results.</em></td>
</tr>
{%- endfor -%}
{%- endmacro %}
{# helper: centralize the query string once #}
{% macro _q(model, page, per_page, sort, filters, fields_csv) -%}
/ui/{{ model }}/frag/rows
?page={{ page }}&per_page={{ per_page }}
{%- if sort %}&sort={{ sort }}{% endif -%}
{%- if fields_csv %}&fields_csv={{ fields_csv|urlencode }}{% endif -%}
{%- for k, v in (filters or {}).items() %}&{{ k }}={{ v|urlencode }}{% endfor -%}
{%- endmacro %}
{% macro pager(model, page, pages, per_page, sort, filters, fields_csv) -%}
{% set p = page|int %}
{% set pg = pages|int %}
{% set prev = 1 if p <= 1 else p - 1 %} {% set nxt=pg if p>= pg else p + 1 %}
<nav class="pager-nav" aria-label="Pagination">
{% if p > 1 %}
<button type="button" class="page-btn" data-page="1"
hx-get="{{ _q(model, 1, per_page, sort, filters, fields_csv) }}" hx-target="#rows" hx-swap="innerHTML"
aria-label="First page">First</button>
<button type="button" class="page-btn" data-page="{{ prev }}"
hx-get="{{ _q(model, prev, per_page, sort, filters, fields_csv) }}" hx-target="#rows" hx-swap="innerHTML"
rel="prev">Prev</button>
{% else %}
<button type="button" class="page-btn" disabled>First</button>
<button type="button" class="page-btn" disabled>Prev</button>
{% endif %}
<span aria-live="polite">Page {{ p }} / {{ pg }}</span>
{% if p < pg %} <button type="button" class="page-btn" data-page="{{ p + 1 }}"
hx-get="{{ _q(model, p + 1, per_page, sort, filters, fields_csv) }}" hx-target="#rows" hx-swap="innerHTML"
rel="next">Next</button>
<button type="button" class="page-btn" data-page="{{ pg }}"
hx-get="{{ _q(model, pg, per_page, sort, filters, fields_csv) }}" hx-target="#rows" hx-swap="innerHTML"
aria-label="Last page">Last</button>
{% else %}
<button type="button" class="page-btn" disabled>Next</button>
<button type="button" class="page-btn" disabled>Last</button>
{% endif %}
</nav>
{# one tiny listener to keep #pager-state in sync for every button #}
<script>
(function () {
const nav = document.currentScript.previousElementSibling;
nav.addEventListener('click', function (ev) {
const btn = ev.target.closest('.page-btn');
if (!btn || btn.disabled) return;
const page = btn.getAttribute('data-page');
if (!page) return;
const inp = document.querySelector('#pager-state input[name=page]');
if (inp) inp.value = page;
}, { capture: true });
})();
</script>
{%- endmacro %}
{% macro form(schema, action, method="POST", obj_id=None, hx=False, csrf_token=None) -%}
<form action="{{ action }}" method="post" {%- if hx %} hx-{{ "patch" if obj_id else "post" }}="{{ action }}"
hx-target="closest dialog, #modal-body, body" hx-swap="innerHTML" hx-disabled-elt="button[type=submit]" {%-
endif -%}>
{%- if csrf_token %}<input type="hidden" name="csrf_token" value="{{ csrf_token() }}">{% endif -%}
{%- if obj_id %}<input type="hidden" name="id" value="{{ obj_id }}">{% endif -%}
<input type="hidden" name="fields_csv" value="{{ request.args.get('fields_csv','id,name') }}">
{%- for f in schema -%}
<div class="field" data-name="{{ f.name }}">
{% set fid = 'f-' ~ f.name ~ '-' ~ (obj_id or 'new') %}
<label for="{{ fid }}">{{ f.label or f.name|replace('_',' ')|title }}</label>
{%- if f.type == "textarea" -%}
<textarea id="{{ fid }}" name="{{ f.name }}" {%- if f.required %} required{% endif %}{% if f.maxlength %}
maxlength="{{ f.maxlength }}" {% endif %}>{{ f.value or "" }}</textarea>
{%- elif f.type == "select" -%}
<select id="{{ fid }}" name="{{ f.name }}" {% if f.required %}required{% endif %}>
<option value="">{{ f.placeholder or ("Choose " ~ (f.label or f.name|replace('_',' ')|title)) }}
</option>
{% if f.multiple %}
{% set selected = (f.value or [])|list %}
{% for val, lbl in f.choices %}
<option value="{{ val }}" {{ 'selected' if val in selected else '' }}>{{ lbl }}</option>
{% endfor %}
{% else %}
{% for val, lbl in f.choices %}
<option value="{{ val }}" {{ 'selected' if (f.value|string)==(val|string) else '' }}>{{ lbl }}</option>
{% endfor %}
{% endif %}
</select>
{%- elif f.type == "checkbox" -%}
<input type="hidden" name="{{ f.name }}" value="0">
<input id="{{ fid }}" type="checkbox" name="{{ f.name }}" value="1" {{ "checked" if f.value else "" }}>
{%- else -%}
<input id="{{ fid }}" type="{{ f.type }}" name="{{ f.name }}"
value="{{ f.value if f.value is not none else '' }}" {%- if f.required %} required{% endif %} {%- if
f.maxlength %} maxlength="{{ f.maxlength }}" {% endif %}>
{%- endif -%}
{%- if f.help %}<div class="help">{{ f.help }}</div>{% endif -%}
</div>
{%- endfor -%}
<div class="actions">
<button type="submit">Save</button>
</div>
</form>
{%- endmacro %}

View file

@ -0,0 +1,3 @@
{% import "crudkit/_macros.html" as ui %}
{% set action = url_for('frags.save', model=model) %}
{{ ui.form(schema, action, method="POST", obj_id=obj.id if obj else None, hx=true) }}

View file

@ -0,0 +1,2 @@
{% import "crudkit/_macros.html" as ui %}
{{ ui.lis(items, label_path=label_path, sublabel_path=sublabel_path, getp=getp) }}

View file

@ -0,0 +1,3 @@
{# Renders only <option>...</option> rows #}
{% import "crudkit/_macros.html" as ui %}
{{ ui.options(items, value_attr=value_attr, label_path=label_path, getp=getp) }}

View file

@ -0,0 +1,2 @@
{% import 'crudkit/_macros.html' as ui %}
{{ ui.pager(model, page, pages, per_page, sort, filters, fields_csv) }}

View file

@ -0,0 +1,2 @@
{% import "crudkit/_macros.html" as ui %}
{{ ui.rows([obj], fields, getp=getp) }}

View file

@ -0,0 +1,2 @@
{% import "crudkit/_macros.html" as ui %}
{{ ui.rows(items, fields, getp=getp) }}

137
crudkit/html/type_map.py Normal file
View file

@ -0,0 +1,137 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapper, RelationshipProperty
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import (
String, Text, Unicode, UnicodeText,
Integer, BigInteger, SmallInteger, Float, Numeric, Boolean,
Date, DateTime, Time, JSON, Enum
)
CANDIDATE_LABELS = ("name", "title", "label", "display_name")
def _guess_label_attr(model_cls) -> str:
for cand in CANDIDATE_LABELS:
if hasattr(model_cls, cand):
return cand
return "id"
def _pretty(label: str) -> str:
return label.replace("_", " ").title()
def _column_input_type(col: Column) -> str:
t = col.type
if isinstance(t, (String, Unicode)):
return "text"
if isinstance(t, (Text, UnicodeText, JSON)):
return "textarea"
if isinstance(t, (Integer, SmallInteger, BigInteger)):
return "number"
if isinstance(t, (Float, Numeric)):
return "number"
if isinstance(t, Boolean):
return "checkbox"
if isinstance(t, Date):
return "date"
if isinstance(t, DateTime):
return "datetime-local"
if isinstance(t, Time):
return "time"
if isinstance(t, Enum):
return "select"
return "text"
def _enum_choices(col: Column) -> Optional[List[Tuple[str, str]]]:
t = col.type
if isinstance(t, Enum):
if t.enum_class:
return [(e.name, e.value) for e in t.enum_class]
if t.enums:
return [(v, v) for v in t.enums]
return None
def build_form_schema(model_cls, session, obj=None, *, include=None, exclude=None, fk_limit=200):
mapper: Mapper = inspect(model_cls)
include = set(include or [])
exclude = set(exclude or {"id", "created_at", "updated_at", "deleted", "version"})
fields = []
fields: List[Dict[str, Any]] = []
fk_map = {}
for rel in mapper.relationships:
for lc in rel.local_columns:
fk_map[lc.key] = rel
for attr in mapper.column_attrs:
col = attr.columns[0]
name = col.key
if include and name not in include:
continue
if name in exclude:
continue
field = {
"name": name,
"type": _column_input_type(col),
"required": not col.nullable,
"value": getattr(obj, name, None) if obj is not None else None,
"placeholder": "",
"help": "",
# default label from column name
"label": _pretty(name),
}
enum_choices = _enum_choices(col)
if enum_choices:
field["type"] = "select"
field["choices"] = enum_choices
if name in fk_map:
rel = fk_map[name]
target = rel.mapper.class_
label_attr = _guess_label_attr(target)
rows = session.execute(select(target).limit(fk_limit)).scalars().all()
field["type"] = "select"
field["choices"] = [(getattr(r, "id"), getattr(r, label_attr)) for r in rows]
field["rel"] = {"target": target.__name__, "label_attr": label_attr}
field["label"] = _pretty(rel.key)
if getattr(col.type, "length", None):
field["maxlength"] = col.type.length
fields.append(field)
for rel in mapper.relationships:
if not rel.uselist or rel.secondary is None:
continue # only true many-to-many
if include and f"{rel.key}_ids" not in include:
continue
target = rel.mapper.class_
label_attr = _guess_label_attr(target)
choices = session.execute(select(target).limit(fk_limit)).scalars().all()
current = []
if obj is not None:
current = [getattr(x, "id") for x in getattr(obj, rel.key, []) or []]
fields.append({
"name": f"{rel.key}_ids", # e.g. "tags_ids"
"label": rel.key.replace("_"," ").title(),
"type": "select",
"multiple": True,
"required": False,
"choices": [(getattr(r,"id"), getattr(r,label_attr)) for r in choices],
"value": current, # list of selected IDs
"placeholder": f"Choose {rel.key.replace('_',' ').title()}",
"help": "",
})
if include:
order = list(include)
fields.sort(key=lambda f: order.index(f["name"]) if f["name"] in include else 10**9)
return fields

View file

@ -0,0 +1,269 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from math import ceil
from flask import Blueprint, request, render_template, abort, make_response
from sqlalchemy import select
from sqlalchemy.orm import scoped_session
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric
from ..dsl import QuerySpec
from ..service import CrudService
from ..eager import default_eager_policy
from .type_map import build_form_schema
Session = None
def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, name="frags"):
"""
HTML fragments for HTMX/Alpine. No base pages. Pure partials:
GET /<model>/frag/options -> <option>...</option>
GET /<model>/frag/lis -> <li>...</li>
GET /<model>/frag/rows -> <tr>...</tr> + pager markup if wanted
GET /<model>/frag/form -> <form>...</form> (auto-generated)
"""
global Session
if Session is None:
Session = scoped_session(db_session_factory)
bp = Blueprint(name, __name__, template_folder="templates")
def session():
return Session
@bp.teardown_app_request
def remove_session(exc=None):
Session.remove()
def _parse_filters(args):
reserved = {"page", "per_page", "sort", "expand", "fields", "value", "label", "label_tpl", "fields_csv", "li_label", "li_sublabel"}
out = {}
for k, v in args.items():
if k not in reserved and v != "":
out[k] = v
return out
def _paths_from_csv(csv: str) -> List[str]:
return [p.strip() for p in csv.split(",") if p.strip()]
def _collect_expand_from_paths(paths: List[str]) -> List[str]:
rels = set()
for p in paths:
bits = p.split(".")
if len(bits) > 1:
rels.add(bits[0])
return list(rels)
def _getp(obj, path: str):
cur = obj
for part in path.split("."):
cur = getattr(cur, part, None) if cur is not None else None
return cur
def _extract_m2m_lists(Model, req_form) -> dict[str, list[int]]:
"""Return {'tags': [1,2]} for any <rel>_ids fields; caller removes keys from main form."""
mapper = inspect(Model)
out = {}
for rel in mapper.relationships:
if not rel.uselist or rel.secondary is None:
continue
key = f"{rel.key}_ids"
ids = req_form.getlist(key)
if ids is None:
continue
out[rel.key] = [int(i) for i in ids if i]
return out
@bp.get("/<model>/frag/options")
def options(model):
Model = registry.get(model) or abort(404)
value_attr = request.args.get("value", default="id")
label_path = request.args.get("label", default="name")
filters = _parse_filters(request.args)
expand = _collect_expand_from_paths([label_path])
spec = QuerySpec(filters=filters, order_by=[], page=None, per_page=None, expand=expand)
s = session(); svc = CrudService(s, default_eager_policy)
items, _ = svc.list(Model, spec)
return render_template("crudkit/options.html", items=items, value_attr=value_attr, label_path=label_path, getp=_getp)
@bp.get("/<model>/frag/lis")
def lis(model):
Model = registry.get(model) or abort(404)
label_path = request.args.get("li_label", default="name")
sublabel_path = request.args.get("li_sublabel")
filters = _parse_filters(request.args)
sort = request.args.get("sort")
page = request.args.get("page", type=int)
per_page = request.args.get("per_page", type=int)
expand = _collect_expand_from_paths([p for p in (label_path, sublabel_path) if p])
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
s = session(); svc = CrudService(s, default_eager_policy)
rows, total = svc.list(Model, spec)
pages = (ceil(total / per_page) if page and per_page else 1)
return render_template("crudkit/lis.html", items=rows, label_path=label_path, sublabel_path=sublabel_path, page=page or 1, per_page=per_page or 1, total=total, model=model, sort=sort, filters=filters, getp=_getp)
@bp.get("/<model>/frag/rows")
def rows(model):
Model = registry.get(model) or abort(404)
fields_csv = request.args.get("fields_csv") or "id,name"
fields = _paths_from_csv(fields_csv)
filters = _parse_filters(request.args)
sort = request.args.get("sort")
page = request.args.get("page", type=int) or 1
per_page = request.args.get("per_page", type=int) or 20
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
s = session(); svc = CrudService(s, default_eager_policy)
rows, _ = svc.list(Model, spec)
html = render_template("crudkit/rows.html", items=rows, fields=fields, getp=_getp, model=model)
return html
@bp.get("/<model>/frag/pager")
def pager(model):
Model = registry.get(model) or abort(404)
page = request.args.get("page", type=int) or 1
print(page)
per_page = request.args.get("per_page", type=int) or 20
filters = _parse_filters(request.args)
sort = request.args.get("sort")
fields_csv = request.args.get("fields_csv") or "id,name"
fields = _paths_from_csv(fields_csv)
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
s = session(); svc = CrudService(s, default_eager_policy)
_, total = svc.list(Model, spec)
pages = max(1, ceil(total / per_page))
html = render_template("crudkit/pager.html", model=model, page=page, pages=pages,
per_page=per_page, sort=sort, filters=filters, fields_csv=fields_csv)
return html
@bp.get("/<model>/frag/form")
def form(model):
Model = registry.get(model) or abort(404)
id = request.args.get("id", type=int)
include_csv = request.args.get("include")
include = [s.strip() for s in include_csv.split(",")] if include_csv else None
s = session(); svc = CrudService(s, default_eager_policy)
obj = svc.get(Model, id) if id else None
schema = build_form_schema(Model, s, obj=obj, include=include)
hx = request.args.get("hx", type=int) == 1
return render_template("crudkit/form.html", model=model, obj=obj, schema=schema, hx=hx)
def coerce_form_types(Model, data: dict) -> dict:
"""Turn HTML string inputs into the Python types your columns expect."""
mapper = inspect(Model)
for attr in mapper.column_attrs:
col = attr.columns[0]
name = col.key
if name not in data:
continue
v = data[name]
if v == "":
data[name] = None
continue
t = col.type
try:
if isinstance(t, Boolean):
data[name] = v in ("1", "true", "on", "yes", True)
elif isinstance(t, Integer):
data[name] = int(v)
elif isinstance(t, (Float, Numeric)):
data[name] = float(v)
elif isinstance(t, DateTime):
from datetime import datetime
data[name] = datetime.fromisoformat(v)
elif isinstance(t, Date):
from datetime import date
data[name] = date.fromisoformat(v)
except Exception:
# Leave as string; your validator can complain later.
pass
return data
@bp.post("/<model>/frag/save")
def save(model):
Model = registry.get(model) or abort(404)
s = session(); svc = CrudService(s, default_eager_policy)
# grab the raw form and fields to re-render
raw = request.form
form = raw.to_dict(flat=True)
fields_csv = form.pop("fields_csv", "id,name")
# many-to-many lists first
m2m = _extract_m2m_lists(Model, raw)
for rel_name in list(m2m.keys()):
form.pop(f"{rel_name}_ids", None)
# coerce primitives for regular columns
form = coerce_form_types(Model, form)
id_val = form.pop("id", None)
if id_val:
obj = svc.get(Model, int(id_val)) or abort(404)
svc.update(obj, form)
else:
obj = svc.create(Model, form)
# apply many-to-many selections
mapper = inspect(Model)
for rel_name, id_list in m2m.items():
rel = mapper.relationships[rel_name]
target = rel.mapper.class_
selected = []
if id_list:
selected = s.execute(select(target).where(target.id.in_(id_list))).scalars().all()
coll = getattr(obj, rel_name)
coll.clear()
coll.extend(selected)
s.commit()
rows_html = render_template(
"crudkit/row.html",
obj=obj,
fields=[p.strip() for p in fields_csv.split(",") if p.strip()],
getp=_getp,
)
resp = make_response(rows_html)
if id_val:
resp.headers["HX-Trigger"] = '{"toast":{"level":"success","message":"Updated"}}'
resp.headers["HX-Retarget"] = f"#row-{obj.id}"
resp.headers["HX-Reswap"] = "outerHTML"
else:
resp.headers["HX-Trigger"] = '{"toast":{"level":"success","message":"Created"}}'
resp.headers["HX-Retarget"] = "#rows"
resp.headers["HX-Reswap"] = "beforeend"
return resp
@bp.get("/_debug/<model>/schema")
def debug_model(model):
Model = registry[model]
from sqlalchemy.inspection import inspect
m = inspect(Model)
return {
"columns": [c.key for c in m.columns],
"relationships": [
{
"key": r.key,
"target": r.mapper.class_.__name__,
"uselist": r.uselist,
"local_cols": [c.key for c in r.local_columns],
} for r in m.relationships
],
}
return bp

View file

@ -1,25 +0,0 @@
from __future__ import annotations
from typing import Type
from flask import Flask
from crudkit.engines import CRUDKitRuntime
from .registry import CRUDRegistry
class CRUDKit:
def __init__(self, app: Flask, runtime: CRUDKitRuntime):
self.app = app
self.runtime = runtime
self.registry = CRUDRegistry(runtime)
def register(self, model: Type, **kwargs):
return self.registry.register_class(self.app, model, **kwargs)
def register_many(self, models: list[Type], **kwargs):
return self.registry.register_many(self.app, models, **kwargs)
def get_model(self, key: str):
return self.registry.get_model(key)
def get_service(self, model: Type):
return self.registry.get_service(model)

View file

@ -1,24 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from fastapi import Depends
from sqlalchemy.orm import Session
from ..engines import CRUDKitRuntime
_runtime = CRUDKitRuntime()
@contextmanager
def _session_scope():
SessionLocal = _runtime.session_factory
session: Session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def get_db():
with _session_scope() as s:
yield s

View file

@ -1,20 +0,0 @@
from __future__ import annotations
from flask import Flask
from sqlalchemy.orm import scoped_session
from ..engines import CRUDKitRuntime
from ..config import Config
def init_app(app: Flask, *, runtime: CRUDKitRuntime | None = None, config: type[Config] | None == None):
"""
Initializes CRUDKit for a Flask app. Provies `app.extensions['crudkit']`
with a runtime (engine + session_factory). Caller manages session lifecycle.
"""
runtime = runtime or CRUDKitRuntime(config=config)
app.extensions.setdefault("crudkit", {})
app.extensions["crudkit"]["runtime"] = runtime
Session = runtime.session_factory
if Session is not None:
app.extensions["crudkit"]["Session"] = scoped_session(Session)
return runtime

23
crudkit/mixins.py Normal file
View file

@ -0,0 +1,23 @@
import datetime as dt
from sqlalchemy import Column, Integer, DateTime, Boolean
from sqlalchemy.orm import declared_attr
from sqlalchemy.ext.hybrid import hybrid_property
class CrudMixin:
id = Column(Integer, primary_key=True)
created_at = Column(DateTime, default=dt.datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=dt.datetime.utcnow, onupdate=dt.datetime.utcnow, nullable=False)
deleted = Column("deleted", Boolean, default=False, nullable=False)
version = Column(Integer, default=1, nullable=False)
@hybrid_property
def is_deleted(self):
return self.deleted
def mark_deleted(self):
self.deleted = True
self.version += 1
@declared_attr
def __mapper_args__(cls):
return {"version_id_col": cls.version}

View file

@ -1,120 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type, TypeVar, cast
from flask import Flask
from sqlalchemy.orm import Session
from crudkit.core.service import CRUDService
from crudkit.api.flask_api import generate_crud_blueprint
from crudkit.engines import CRUDKitRuntime
T = TypeVar("T")
@dataclass
class Registered:
model: Type[Any]
service: CRUDService[Any]
blueprint_name: str
url_prefix: str
class CRUDRegistry:
"""
Binds:
- name -> model class
- model class -> CRUDService (using CRUDKitRuntime.session_factory)
- model class -> Flask blueprint (via generate_crud_blueprint)
"""
def __init__(self, runtime: CRUDKitRuntime):
self._rt = runtime
self._models_by_key: Dict[str, Type[Any]] = {}
self._services_by_model: Dict[Type[Any], CRUDService[Any]] = {}
self._bps_by_model: Dict[Type[Any], Registered] = {}
@staticmethod
def _key(model_or_name: Type[Any] | str) -> str:
return model_or_name.lower() if isinstance(model_or_name, str) else model_or_name.__name__.lower()
def get_model(self, key: str) -> Optional[Type[Any]]:
return self._models_by_key.get(key.lower())
def get_service(self, model: Type[T]) -> Optional[CRUDService[T]]:
return cast(Optional[CRUDService[T]], self._services_by_model.get(model))
def is_registered(self, model: Type[Any]) -> bool:
return model in self._services_by_model
def register_class(
self,
app: Flask,
model: Type[Any],
*,
url_prefix: Optional[str] = None,
blueprint_name: Optional[str] = None,
polymorphic: bool = False,
service_kwargs: Optional[dict] = None
) -> Registered:
"""
Register a model:
- store name -> class
- create a CRUDService bound to a Session from runtime.session_factory
- attach backend into from runtime.backend
- mount Flask blueprint at /api/<modelname> by default
Idempotent for each model.
"""
key = self._key(model)
self._models_by_key.setdefault(key, model)
svc = self._services_by_model.get(model)
if svc is None:
SessionMaker = self._rt.session_factory
if SessionMaker is None:
raise RuntimeError("CRUDKitRuntime.session_factory is not initialized.")
svc = CRUDService(
model,
session_factory=SessionMaker,
polymorphic=polymorphic,
backend=self._rt.backend,
**(service_kwargs or {}),
)
self._services_by_model[model] = svc
reg = self._bps_by_model.get(model)
if reg:
return reg
prefix = url_prefix or f"/api/{key}"
bp_name = blueprint_name or f"crudkit.{key}"
bp = generate_crud_blueprint(model, svc)
bp.name = bp_name
app.register_blueprint(bp, url_prefix=prefix)
reg = Registered(model=model, service=svc, blueprint_name=bp_name, url_prefix=prefix)
self._bps_by_model[model] = reg
return reg
def register_many(
self,
app: Flask,
models: list[Type[Any]],
*,
base_prefix: str = "/api",
polymorphic: bool = False,
service_kwargs: Optional[dict] = None,
) -> list[Registered]:
out: list[Registered] = []
for m in models:
key = self._key(m)
out.append(
self.register_class(
app,
m,
url_prefix=f"{base_prefix}/{key}",
polymorphic=polymorphic,
service_kwargs=service_kwargs,
)
)
return out

22
crudkit/serialize.py Normal file
View file

@ -0,0 +1,22 @@
def serialize(obj, *, fields=None, expand=None):
expand = set(expand or [])
fields = set(fields or [])
out = {}
# base columns
for col in obj.__table__.columns:
name = col.key
if fields and name not in fields:
continue
out[name] = getattr(obj, name)
# expansions
for rel in obj.__mapper__.relationships:
if rel.key not in expand:
continue
val = getattr(obj, rel.key)
if val is None:
out[rel.key] = None
elif rel.uselist:
out[rel.key] = [serialize(child) for child in val]
else:
out[rel.key] = serialize(val)
return out

169
crudkit/service.py Normal file
View file

@ -0,0 +1,169 @@
import sqlalchemy as sa
from sqlalchemy import func, asc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, aliased
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.elements import UnaryExpression
from .dsl import QuerySpec, build_query, split_sort_tokens
from .eager import default_eager_policy
def _dedup_order_by(ordering):
seen = set()
result = []
for ob in ordering:
col = ob.element if isinstance(ob, UnaryExpression) else ob
key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}"
if key in seen:
continue
seen.add(key)
result.append(ob)
return result
def _parse_sort_token(token: str):
token = token.strip()
direction = "asc"
if token.startswith('-'):
direction = "desc"
token = token[1:]
if ":" in token:
key, dirpart = token.rsplit(":", 1)
direction = "desc" if dirpart.lower().startswith("d") else "asc"
return key, direction
return token, direction
def _apply_dotted_ordering(stmt, Model, sort_tokens):
"""
stmt: a select(Model) statement
sort_tokens: list[str] like ["owner.identifier", "-brand.name"]
Returns: (stmt, alias_cache)
"""
mapper = inspect(Model)
alias_cache = {} # maps a path like "owner" or "brand" to its alias
for tok in sort_tokens:
path, direction = _parse_sort_token(tok)
parts = [p for p in path.split(".") if p]
if not parts:
continue
entity = Model
current_mapper = mapper
alias_path = []
# Walk relationships for all but the last part
for rel_name in parts[:-1]:
rel = current_mapper.relationships.get(rel_name)
if rel is None:
# invalid sort key; skip quietly or raise
# raise ValueError(f"Unknown relationship {current_mapper.class_.__name__}.{rel_name}")
entity = None
break
alias_path.append(rel_name)
key = ".".join(alias_path)
if key in alias_cache:
entity_alias = alias_cache[key]
else:
# build an alias and join
entity_alias = aliased(rel.mapper.class_)
stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
alias_cache[key] = entity_alias
entity = entity_alias
current_mapper = inspect(rel.mapper.class_)
if entity is None:
continue
col_name = parts[-1]
# Validate final column
if col_name not in current_mapper.columns:
# raise ValueError(f"Unknown column {current_mapper.class_.__name__}.{col_name}")
continue
col = getattr(entity, col_name) if entity is not Model else getattr(Model, col_name)
stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())
return stmt
class CrudService:
def __init__(self, session: Session, eager_policy=default_eager_policy):
self.s = session
self.eager_policy = eager_policy
def create(self, Model, data, *, before=None, after=None):
if before: data = before(data) or data
obj = Model(**data)
self.s.add(obj)
self.s.flush()
if after: after(obj)
return obj
def get(self, Model, id, spec: QuerySpec | None = None):
spec = spec or QuerySpec()
stmt = build_query(Model, spec, self.eager_policy).where(Model.id == id)
return self.s.execute(stmt).scalars().first()
def list(self, Model, spec: QuerySpec):
stmt = build_query(Model, spec, self.eager_policy)
simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
if dotted_sorts:
stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)
# count query
pk = getattr(Model, "id") # adjust if not 'id'
count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None)
total = self.s.execute(
sa.select(sa.func.count()).select_from(count_base.subquery())
).scalar_one()
if spec.page and spec.per_page:
offset = (spec.page - 1) * spec.per_page
stmt = stmt.limit(spec.per_page).offset(offset)
# ---- ORDER BY handling ----
mapper = inspect(Model)
pk_cols = mapper.primary_key
# Gather all clauses added so far
ordering = list(stmt._order_by_clauses)
# Append pk tie-breakers if not already present
existing_cols = {
str(ob.element if isinstance(ob, UnaryExpression) else ob)
for ob in ordering
}
for c in pk_cols:
if str(c) not in existing_cols:
ordering.append(asc(c))
# Dedup *before* applying
ordering = _dedup_order_by(ordering)
# Now wipe old order_bys and set once
stmt = stmt.order_by(None).order_by(*ordering)
rows = self.s.execute(stmt).scalars().all()
return rows, total
def update(self, obj, data, *, before=None, after=None):
if obj.is_deleted: raise ValueError("Cannot update a deleted record")
if before: data = before(obj, data) or data
for k, v in data.items(): setattr(obj, k, v)
obj.version += 1
if after: after(obj)
return obj
def soft_delete(self, obj, *, cascade=False, guard=None):
if guard and not guard(obj): raise ValueError("Delete blocked by guard")
# optionsl FK hygiene checks go here
obj.mark_deleted()
return obj
def undelete(self, obj):
obj.deleted = False
obj.version += 1
return obj

View file

View file

@ -1,839 +0,0 @@
import os
import re
from collections import OrderedDict
from flask import current_app, url_for
from jinja2 import Environment, FileSystemLoader, ChoiceLoader
from sqlalchemy import inspect
from sqlalchemy.orm import Load, RelationshipProperty, class_mapper, load_only, selectinload
from sqlalchemy.orm.base import NO_VALUE
from typing import Any, Dict, List, Optional, Tuple
_ALLOWED_ATTRS = {
"class", "placeholder", "autocomplete", "inputmode", "pattern",
"min", "max", "step", "maxlength", "minlength",
"required", "readonly", "disabled",
"multiple", "size",
"id", "name", "value",
}
def get_env():
app = current_app
default_path = os.path.join(os.path.dirname(__file__), 'templates')
fallback_loader = FileSystemLoader(default_path)
return app.jinja_env.overlay(
loader=ChoiceLoader([app.jinja_loader, fallback_loader])
)
def _normalize_rows_layout(layout: Optional[List[dict]]) -> Dict[str, dict]:
"""
Create node dicts for each row and link parent->children.
Node shape:
{
'name': str,
'legend': Optional[str],
'attrs': dict, # sanitized
'order': int,
'parent': Optional[str],
'children': list, # list of node names (we'll expand later)
'fields': list, # filled later
}
Always ensures a 'main' node exists.
"""
nodes: Dict[str, dict] = {}
def make_node(name: str) -> dict:
node = nodes.get(name)
if node is None:
node = nodes[name] = {
"name": name,
"legend": None,
"attrs": {},
"order": 0,
"parent": None,
"children": [],
"fields": [],
}
return node
# seed nodes from layout
if isinstance(layout, list):
for item in layout:
name = item.get("name")
if not isinstance(name, str) or not name:
continue
node = make_node(name)
node["legend"] = item.get("legend")
node["attrs"] = _sanitize_attrs(item.get("attrs") or {})
try:
node["order"] = int(item.get("order") or 0)
except Exception:
node["order"] = 0
parent = item.get("parent")
node["parent"] = parent if isinstance(parent, str) and parent else None
# ensure main exists and is early-ordered
main = make_node("main")
if "order" not in main or main["order"] == 0:
main["order"] = -10
# default any unknown parents to main (except main itself)
for n in list(nodes.values()):
if n["name"] == "main":
n["parent"] = None
continue
p = n["parent"]
if p is None or p not in nodes or p == n["name"]:
n["parent"] = "main"
# detect cycles defensively; break by reparenting to main
visiting = set()
visited = set()
def visit(name: str):
if name in visited:
return
if name in visiting:
# cycle; break this node to main
nodes[name]["parent"] = "main"
return
visiting.add(name)
parent = nodes[name]["parent"]
if parent is not None:
visit(parent)
visiting.remove(name)
visited.add(name)
for nm in list(nodes.keys()):
visit(nm)
# compute children lists
for n in nodes.values():
n["children"] = []
for n in nodes.values():
p = n["parent"]
if p is not None:
nodes[p]["children"].append(n["name"])
# sort children by (order, name) for deterministic rendering
for n in nodes.values():
n["children"].sort(key=lambda nm: (nodes[nm]["order"], nodes[nm]["name"]))
return nodes
def _assign_fields_to_rows(fields: List[dict], rows: Dict[str, dict]) -> List[dict]:
"""
Put fields into their target row buckets by name (default 'main'),
then return a list of root nodes expanded with nested dicts ready for templates.
"""
# assign fields
for f in fields:
row_name = f.get("row") or "main"
node = rows.get(row_name) or rows["main"]
node["fields"].append(f)
# expand tree into nested structures
def expand(name: str) -> dict:
n = rows[name]
return {
"name": n["name"],
"legend": n["legend"],
"attrs": n["attrs"],
"order": n["order"],
"fields": n["fields"],
"children": [expand(ch) for ch in n["children"]],
}
# roots are nodes with parent == None
roots = [expand(nm) for nm, n in rows.items() if n["parent"] is None]
roots.sort(key=lambda r: (r["order"], r["name"]))
return roots
def _sanitize_attrs(attrs: Any) -> dict[str, Any]:
"""
Whitelist attributes; allow data-* and aria-*; render True as boolean attr.
Drop False/None and anything not whitelisted.
"""
if not isinstance(attrs, dict):
return {}
out: dict[str, Any] = {}
for k, v in attrs.items():
if not isinstance(k, str):
continue
elif isinstance(v, str):
if len(v) > 512:
v = v[:512]
if k.startswith("data-") or k.startswith("aria-") or k in _ALLOWED_ATTRS:
if isinstance(v, bool):
if v:
out[k] = True
elif v is not None:
out[k] = str(v)
return out
class _SafeObj:
"""Attribute access that returns '' for missing/None instead of exploding."""
__slots__ = ("_obj",)
def __init__(self, obj): self._obj = obj
def __str__(self): return "" if self._obj is None else str(self._obj)
def __getattr__(self, name):
if self._obj is None:
return ""
val = getattr(self._obj, name, None)
if val is None:
return ""
return _SafeObj(val)
def _coerce_fk_value(values: dict | None, instance: Any, base: str):
"""
Resolve current selection for relationship 'base':
1) values['<base>_id']
2) values['<base>']['id'] or values['<base>'] if it's an int or numeric string
3) instance.<base> (if already loaded) -> use its .id [safe for detached]
4) instance.<base>_id (if already loaded and not expired)
Never trigger a lazy load.
"""
# 1) explicit *_id from values
if isinstance(values, dict):
key = f"{base}_id"
if key in values:
return values.get(key)
rel = values.get(base)
# 2a) nested dict with id
if isinstance(rel, dict):
vid = rel.get("id") or rel.get(key)
if vid is not None:
return vid
# 2b) scalar id
if isinstance(rel, int):
return rel
if isinstance(rel, str):
s = rel.strip()
if s.isdigit():
return s # template compares as strings, so this is fine
# 3) use loaded relationship object (safe even if instance is detached)
if instance is not None:
try:
state = inspect(instance)
rel_attr = state.attrs.get(base)
if rel_attr is not None and rel_attr.loaded_value is not NO_VALUE:
rel_obj = rel_attr.loaded_value
if rel_obj is not None:
rid = getattr(rel_obj, "id", None)
if rid is not None:
return rid
# 4) use loaded fk column if present and not expired
id_attr = state.attrs.get(f"{base}_id")
if id_attr is not None and id_attr.loaded_value is not NO_VALUE:
return id_attr.loaded_value
except Exception:
pass
return None
def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]:
try:
prop = mapper.relationships[name]
except Exception:
return None
if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE':
return prop
return None
def _rel_for_id_name(mapper, name: str) -> tuple[Optional[str], Optional[RelationshipProperty]]:
if name.endswith("_id"):
base = name[:-3]
prop = _is_many_to_one(mapper, base)
return (base, prop) if prop else (None, None)
else:
prop = _is_many_to_one(mapper, name)
return (name, prop) if prop else (None, None)
def _fk_options(session, related_model, label_spec):
simple_cols, rel_paths = _extract_label_requirements(label_spec)
q = session.query(related_model)
col_attrs = []
if hasattr(related_model, "id"):
col_attrs.append(getattr(related_model, "id"))
for name in simple_cols:
if hasattr(related_model, name):
col_attrs.append(getattr(related_model, name))
if col_attrs:
q = q.options(load_only(*col_attrs))
for rel_name, col_name in rel_paths:
rel_prop = getattr(related_model, rel_name, None)
if rel_prop is None:
continue
try:
target_cls = related_model.__mapper__.relationships[rel_name].mapper.class_
col_attr = getattr(target_cls, col_name, None)
if col_attr is None:
q = q.options(selectinload(rel_prop))
else:
q = q.options(selectinload(rel_prop).load_only(col_attr))
except Exception:
q = q.options(selectinload(rel_prop))
if simple_cols:
first = simple_cols[0]
if hasattr(related_model, first):
q = q.order_by(getattr(related_model, first))
rows = q.all()
return [
{
'value': getattr(opt, 'id'),
'label': _label_from_obj(opt, label_spec),
}
for opt in rows
]
def _normalize_field_spec(spec, mapper, session, label_specs_model_default):
"""
Turn a user field spec into a concrete field dict the template understands.
"""
name = spec['name']
base_rel_name, rel_prop = _rel_for_id_name(mapper, name)
field = {
"name": name if not base_rel_name else f"{base_rel_name}_id",
"label": spec.get("label", name),
"type": spec.get("type"),
"options": spec.get("options"),
"attrs": spec.get("attrs"),
"label_attrs": spec.get("label_attrs"),
"wrap": spec.get("wrap"),
"row": spec.get("row"),
"help": spec.get("help"),
"template": spec.get("template"),
"template_name": spec.get("template_name"),
"template_ctx": spec.get("template_ctx"),
}
if rel_prop:
if field["type"] is None:
field["type"] = "select"
if field["type"] == "select" and field.get("options") is None and session is not None:
related_model = rel_prop.mapper.class_
label_spec = (
spec.get("label_spec")
or label_specs_model_default.get(base_rel_name)
or getattr(related_model, "__crud_label__", None)
or "id"
)
field["options"] = _fk_options(session, related_model, label_spec)
return field
col = mapper.columns.get(name)
if field["type"] is None:
if col is not None and hasattr(col.type, "python_type"):
py = None
try:
py = col.type.python_type
except Exception:
pass
if py is bool:
field["type"] = "checkbox"
else:
field["type"] = "text"
else:
field["type"] = "text"
return field
def _extract_label_requirements(spec: Any) -> tuple[list[str], list[tuple[str, str]]]:
"""
From a label spec, return:
- simple_cols: ["name", "code"]
- rel_paths: [("room_function", "description"), ("owner", "last_name")]
"""
simple_cols: list[str] = []
rel_paths: list[tuple[str, str]] = []
def ingest(token: str) -> None:
token = str(token).strip()
if not token:
return
if "." in token:
rel, col = token.split(".", 1)
if rel and col:
rel_paths.append((rel, col))
else:
simple_cols.append(token)
if spec is None or callable(spec):
return simple_cols, rel_paths
if isinstance(spec, (list, tuple)):
for a in spec:
ingest(a)
return simple_cols, rel_paths
if isinstance(spec, str):
# format string like "{first} {last}" or "{room_function.description} · {name}"
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
for n in names:
ingest(n)
else:
ingest(spec)
return simple_cols, rel_paths
return simple_cols, rel_paths
def _attrs_from_label_spec(spec: Any) -> list[str]:
"""
Return a list of attribute names needed from the related model to compute the label.
Only simple attribute names are returned; dotted paths return just the first segment.
"""
if spec is None:
return []
if callable(spec):
return []
if isinstance(spec, (list, tuple)):
return [str(a).split(".", 1)[0] for a in spec]
if isinstance(spec, str):
if "{" in spec and "}" in spec:
names = re.findall(r"{\s*([^}:\s]+)", spec)
return [n.split(".", 1)[0] for n in names]
return [spec.split(".", 1)[0]]
return []
def _label_from_obj(obj: Any, spec: Any) -> str:
if spec is None:
for attr in ("label", "name", "title", "description"):
if hasattr(obj, attr):
val = getattr(obj, attr)
if not callable(val) and val is not None:
return str(val)
if hasattr(obj, "id"):
return str(getattr(obj, "id"))
return object.__repr__(obj)
if isinstance(spec, (list, tuple)):
parts = []
for a in spec:
cur = obj
for part in str(a).split("."):
cur = getattr(cur, part, None)
if cur is None:
break
parts.append("" if cur is None else str(cur))
return " ".join(p for p in parts if p)
if isinstance(spec, str) and "{" in spec and "}" in spec:
fields = re.findall(r"{\s*([^}:\s]+)", spec)
data: dict[str, Any] = {}
for f in fields:
root = f.split(".", 1)[0]
if root not in data:
val = getattr(obj, root, None)
data[root] = _SafeObj(val)
try:
return spec.format(**data)
except Exception:
return str(obj)
cur = obj
for part in str(spec).split("."):
cur = getattr(cur, part, None)
if cur is None:
return ""
return str(cur)
def _val_from_row_or_obj(row: Dict[str, Any], obj: Any, dotted: str) -> Any:
"""Best-effort deep get: try the projected row first, then the ORM object."""
val = _deep_get(row, dotted)
if val is None:
val = _deep_get_from_obj(obj, dotted)
return val
def _matches_simple_condition(row: Dict[str, Any], obj: Any, cond: Dict[str, Any]) -> bool:
"""
Supports:
{"field": "foo.bar", "eq": 10}
{"field": "foo", "ne": None}
{"field": "count", "gt": 0} (also lt, gte, lte)
{"field": "name", "in": ["a","b"]}
{"field": "thing", "is": None, | True | False}
{"any": [ ...subconds... ]} # OR
{"all": [ ...subconds... ]} # AND
{"not": { ...subcond... }} # NOT
"""
if "any" in cond:
return any(_matches_simple_condition(row, obj, c) for c in cond["any"])
if "all" in cond:
return all(_matches_simple_condition(row, obj, c) for c in cond["all"])
if "not" in cond:
return not _matches_simple_condition(row, obj, cond["not"])
field = cond.get("field")
val = _val_from_row_or_obj(row, obj, field) if field else None
if "is" in cond:
target = cond["is"]
if target is None:
return val is None
if isinstance(target, bool):
return bool(val) is target
return val is target
if "eq" in cond:
return val == cond["eq"]
if "ne" in cond:
return val != cond["ne"]
if "gt" in cond:
try: return val > cond["gt"]
except Exception: return False
if "lt" in cond:
try: return val < cond["lt"]
except Exception: return False
if "gte" in cond:
try: return val >= cond["gte"]
except Exception: return False
if "lte" in cond:
try: return val <= cond["lte"]
except Exception: return False
if "in" in cond:
try: return val in cond["in"]
except Exception: return False
return False
def _row_class_for(row: Dict[str, Any], obj: Any, rules: Optional[List[Dict[str, Any]]]) -> Optional[str]:
"""
rules is a list of:
{"when": <condition-dict>, "class": "table-warning fw-semibold"}
Multiple matching rules stack classes. Later wins on duplicates by normal CSS rules.
"""
if not rules:
return None
classes = []
for rule in rules:
when = rule.get("when") or {}
if _matches_simple_condition(row, obj, when):
cls = rule.get("class")
if cls:
classes.append(cls)
return " ".join(dict.fromkeys(classes)) or None
def _is_rel_loaded(obj, rel_name: str) -> bool:
try:
state = inspect(obj)
attr = state.attrs[rel_name]
return attr.loaded_value is not NO_VALUE
except Exception:
return False
def _deep_get_from_obj(obj, dotted: str):
cur = obj
parts = dotted.split(".")
for i, part in enumerate(parts):
if i < len(parts) - 1 and not _is_rel_loaded(cur, part):
return None
cur = getattr(cur, part, None)
if cur is None:
return None
return cur
def _deep_get(row: Dict[str, Any], dotted: str) -> Any:
if dotted in row:
return row[dotted]
cur = row
for part in dotted.split('.'):
if isinstance(cur, dict) and part in cur:
cur = cur[part]
else:
return None
return cur
def _format_value(val: Any, fmt: Optional[str]) -> Any:
if fmt is None:
return val
try:
if fmt == "yesno":
return "Yes" if bool(val) else "No"
if fmt == "date":
return val.strftime("%Y-%m-%d") if hasattr(val, "strftime") else val
if fmt == "datetime":
return val.strftime("%Y-%m-%d %H:%M") if hasattr(val, "strftime") else val
if fmt == "time":
return val.strftime("%H:%M") if hasattr(val, "strftime") else val
except Exception:
return val
return val
def _class_for(val: Any, classes: Optional[Dict[str, str]]) -> Optional[str]:
if not classes:
return None
key = "none" if val is None else str(val).lower()
return classes.get(key, classes.get("default"))
def _build_href(spec: Dict[str, Any], row: Dict[str, Any], obj) -> Optional[str]:
if not spec:
return None
params = {}
for k, v in (spec.get("params") or {}).items():
if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
key = v[1:-1]
val = _deep_get(row, key)
if val is None:
val = _deep_get_from_obj(obj, key)
params[k] = val
else:
params[k] = v
if any(v is None for v in params.values()):
return None
try:
return url_for('crudkit.' + spec["endpoint"], **params)
except Exception as e:
return None
def _humanize(field: str) -> str:
return field.replace(".", " > ").replace("_", " ").title()
def _normalize_columns(columns: Optional[List[Dict[str, Any]]], default_fields: List[str]) -> List[Dict[str, Any]]:
if not columns:
return [{"field": f, "label": _humanize(f)} for f in default_fields]
norm = []
for col in columns:
c = dict(col)
c.setdefault("label", _humanize(c["field"]))
norm.append(c)
return norm
def _normalize_opts(opts: Dict[str, Any]) -> Dict[str, Any]:
"""
Accept either:
render_table(..., object_class='user', row_classe[...])
or:
render_table(..., opts={'object_class': 'user', 'row_classes': [...]})
Returns a flat dict with top-level keys for convenience, while preserving
all original keys for the template.
"""
if not isinstance(opts, dict):
return {}
flat = dict(opts)
nested = flat.get("opts")
if isinstance(nested, dict):
for k, v in nested.items():
flat.setdefault(k, v)
return flat
def get_crudkit_template(env, name):
try:
return env.get_template(f'crudkit/{name}')
except Exception:
return env.get_template(name)
def render_field(field, value):
env = get_env()
# 1) custom template field
field_type = field.get('type', 'text')
if field_type == 'template':
tname = field.get('template') or field.get('template_name')
if not tname:
return "" # nothing to render
t = get_crudkit_template(env, tname)
# merge ctx with some sensible defaults
ctx = dict(field.get('template_ctx') or {})
# make sure templates always see these
ctx.setdefault('field', field)
ctx.setdefault('value', value)
return t.render(**ctx)
# 2) normal controls
template = get_crudkit_template(env, 'field.html')
return template.render(
field_name=field['name'],
field_label=field.get('label', field['name']),
value=value,
field_type=field_type,
options=field.get('options', None),
attrs=_sanitize_attrs(field.get('attrs') or {}),
label_attrs=_sanitize_attrs(field.get('label_attrs') or {}),
help=field.get('help'),
)
def render_table(objects: List[Any], columns: Optional[List[Dict[str, Any]]] = None, **opts):
env = get_env()
template = get_crudkit_template(env, 'table.html')
if not objects:
return template.render(fields=[], rows=[])
flat_opts = _normalize_opts(opts)
proj = getattr(objects[0], "__crudkit_projection__", None)
row_dicts = [obj.as_dict(proj) for obj in objects]
default_fields = [k for k in row_dicts[0].keys() if k != "id"]
cols = _normalize_columns(columns, default_fields)
row_rules = (flat_opts.get("row_classes") or [])
disp_rows = []
for obj, rd in zip(objects, row_dicts):
cells = []
for col in cols:
field = col["field"]
raw = _deep_get(rd, field)
text = _format_value(raw, col.get("format"))
href = _build_href(col.get("link"), rd, obj) if col.get("link") else None
cls = _class_for(raw, col.get("classes"))
cells.append({"text": text, "href": href, "class": cls})
row_cls = _row_class_for(rd, obj, row_rules)
disp_rows.append({"id": rd.get("id"), "class": row_cls, "cells": cells})
return template.render(columns=cols, rows=disp_rows, kwargs=flat_opts)
def render_form(
model_cls,
values,
session=None,
*,
fields_spec: Optional[list[dict]] = None,
label_specs: Optional[Dict[str, Any]] = None,
exclude: Optional[set[str]] = None,
overrides: Optional[Dict[str, Dict[str, Any]]] = None,
instance: Any = None,
layout: Optional[list[dict]] = None,
submit_attrs: Optional[dict[str, Any]] = None,
submit_label: Optional[str] = None,
):
"""
fields_spec: list of dicts describing fields in order. Each dict supports:
- name: "first_name" | "location" | "location_id" (required)
- label: override_label
- type: "text" | "textarea" | "checkbox" | "select" | "hidden" | ...
- label_spec: for relationship selects, e.g. "{name} - {room_function.description}"
- options: prebuilt list of {"value","label"}; skips querying if provided
- attrs: dict of arbitrary HTML attributes, e.g. {"required": True, "placeholder": "Jane"}
- help: small help text under the field
label_specs: legacy per-relationship label spec fallback ({"location": "..."}).
exclude: set of field names to hide.
overrides: legacy quick overrides keyed by field name (label/type/etc.)
instance: the ORM object backing the form; used to populate *_id values
layout: A list of dicts describing layouts for fields.
submit_attrs: A dict of attributes to apply to the submit button.
"""
env = get_env()
template = get_crudkit_template(env, "form.html")
exclude = exclude or set()
overrides = overrides or {}
label_specs = label_specs or {}
mapper = class_mapper(model_cls)
fields: list[dict] = []
values_map = dict(values or {}) # we'll augment this with *_id selections
if fields_spec:
# Spec-driven path
for spec in fields_spec:
if spec["name"] in exclude:
continue
field = _normalize_field_spec(
{**spec, **overrides.get(spec["name"], {})},
mapper, session, label_specs
)
fields.append(field)
# After building fields, inject current values for any M2O selects
for f in fields:
name = f.get("name")
if isinstance(name, str) and name.endswith("_id"):
base = name[:-3]
rel_prop = mapper.relationships.get(base)
if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE":
values_map[name] = _coerce_fk_value(values, instance, base)
else:
# Auto-generate path (your original behavior)
fk_fields = set()
# Relationships first
for prop in mapper.iterate_properties:
if isinstance(prop, RelationshipProperty) and prop.direction.name == 'MANYTOONE':
base = prop.key
if base in exclude or f"{base}_id" in exclude:
continue
if session is None:
continue
related_model = prop.mapper.class_
rel_label_spec = (
label_specs.get(base)
or getattr(related_model, "__crud_label__", None)
or "id"
)
options = _fk_options(session, related_model, rel_label_spec)
base_field = {
"name": f"{base}_id",
"label": base,
"type": "select",
"options": options,
}
field = {**base_field, **overrides.get(f"{base}_id", {})}
fields.append(field)
fk_fields.add(f"{base}_id")
# NEW: set the current selection for this dropdown
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base)
# Then plain columns
for col in model_cls.__table__.columns:
if col.name in fk_fields or col.name in exclude:
continue
if col.name in ('id', 'created_at', 'updated_at'):
continue
if col.default or col.server_default or col.onupdate:
continue
base_field = {
"name": col.name,
"label": col.name,
"type": "checkbox" if getattr(col.type, "python_type", None) is bool else "text",
}
field = {**base_field, **overrides.get(col.name, {})}
if field.get("wrap"):
field["wrap"] = _sanitize_attrs(field["wrap"])
fields.append(field)
if submit_attrs:
submit_attrs = _sanitize_attrs(submit_attrs)
common_ctx = {"values": values_map, "instance": instance, "model_cls": model_cls, "session": session}
for f in fields:
if f.get("type") == "template":
base = dict(common_ctx)
base.update(f.get("template_ctx") or {})
f["template_ctx"] = base
# Build rows (supports nested layout with parents)
rows_map = _normalize_rows_layout(layout)
rows_tree = _assign_fields_to_rows(fields, rows_map)
return template.render(
rows=rows_tree,
fields=fields, # keep for backward compatibility
values=values_map,
render_field=render_field,
submit_attrs=submit_attrs,
submit_label=submit_label
)

View file

@ -1,59 +0,0 @@
{# show label unless hidden/custom #}
{% if field_type != 'hidden' and field_label %}
<label for="{{ field_name }}"
{% if label_attrs %}{% for k,v in label_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{{ field_label }}
</label>
{% endif %}
{% if field_type == 'select' %}
<select name="{{ field_name }}" id="{{ field_name }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}
{%- if not options %} disabled{% endif %}>
{% if options %}
<option value="">-- Select --</option>
{% for opt in options %}
<option value="{{ opt.value }}" {% if opt.value|string == value|string %}selected{% endif %}>
{{ opt.label }}
</option>
{% endfor %}
{% else %}
<option value="">-- No selection available --</option>
{% endif %}
</select>
{% elif field_type == 'textarea' %}
<textarea name="{{ field_name }}" id="{{ field_name }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value }}</textarea>
{% elif field_type == 'checkbox' %}
<input type="checkbox" name="{{ field_name }}" id="{{ field_name }}" value="1"
{% if value %}checked{% endif %}
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% elif field_type == 'hidden' %}
<input type="hidden" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}">
{% elif field_type == 'display' %}
<div {% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>{{ value }}</div>
{% else %}
<input type="text" name="{{ field_name }}" id="{{ field_name }}" value="{{ value }}"
{% if attrs %}{% for k,v in attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% endif %}
{% if help %}
<div class="form-text">{{ help }}</div>
{% endif %}

View file

@ -1,40 +0,0 @@
<form method="POST">
{% macro render_row(row) %}
<!-- {{ row.name }} -->
{% if row.fields or row.children or row.legend %}
{% if row.legend %}<legend>{{ row.legend }}</legend>{% endif %}
<fieldset
{% if row.attrs %}{% for k,v in row.attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{% for field in row.fields %}
<div
{% if field.wrap %}{% for k,v in field.wrap.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}>
{{ render_field(field, values.get(field.name, '')) | safe }}
</div>
{% endfor %}
{% for child in row.children %}
{{ render_row(child) }}
{% endfor %}
</fieldset>
{% endif %}
{% endmacro %}
{% if rows %}
{% for row in rows %}
{{ render_row(row) }}
{% endfor %}
{% else %}
{% for field in fields %}
{{ render_field(field, values.get(field.name, '')) | safe }}
{% endfor %}
{% endif %}
<button type="submit"
{% if submit_attrs %}{% for k,v in submit_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}
>{{ submit_label if label else 'Save' }}</button>
</form>

View file

@ -1,26 +0,0 @@
<table>
<thead>
<tr>
{% for col in columns %}
<th>{{ col.label }}</th>
{% endfor %}
</tr>
</thead>
<tbody>
{% if rows %}
{% for row in rows %}
<tr class="{{ row.class or '' }}">
{% for cell in row.cells %}
{% if cell.href %}
<td class="{{ cell.class or '' }}"><a href="{{ cell.href }}">{{ cell.text if cell.text is not none else '-' }}</a></td>
{% else %}
<td class="{{ cell.class or '' }}">{{ cell.text if cell.text is not none else '-' }}</td>
{% endif %}
{% endfor %}
</tr>
{% endfor %}
{% else %}
<tr><td colspan="{{ columns|length }}">No data.</td></tr>
{% endif %}
</tbody>
</table>

27
example_app/app.py Normal file
View file

@ -0,0 +1,27 @@
from flask import Flask, render_template
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .models import Base, Author, Book
from crudkit.blueprint import make_blueprint as make_json_blueprint
from crudkit.html import make_fragments_blueprint
engine = create_engine("sqlite:///example.db", echo=True, future=True)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
def session_factory():
return SessionLocal()
registry = {"author": Author, "book": Book}
def create_app():
app = Flask(__name__)
Base.metadata.create_all(engine)
app.register_blueprint(make_json_blueprint(session_factory, registry), url_prefix="/api")
app.register_blueprint(make_fragments_blueprint(session_factory, registry), url_prefix="/ui")
@app.get("/demo")
def demo():
return render_template("demo.html")
return app
if __name__ == "__main__":
create_app().run(debug=True)

18
example_app/models.py Normal file
View file

@ -0,0 +1,18 @@
from typing import List
from sqlalchemy import String, ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from crudkit import CrudMixin
class Base(DeclarativeBase):
pass
class Author(CrudMixin, Base):
__tablename__ = "author"
name: Mapped[str] = mapped_column(String(200), nullable=False)
books: Mapped[List["Book"]] = relationship(back_populates="author", cascade="all, delete-orphan")
class Book(CrudMixin, Base):
__tablename__ = "book"
title: Mapped[str] = mapped_column(String(200), nullable=False)
author_id: Mapped[int] = mapped_column(ForeignKey("author.id"), nullable=False)
author: Mapped[Author] = relationship(back_populates="books")

19
example_app/seed.py Normal file
View file

@ -0,0 +1,19 @@
from .app import SessionLocal, engine
from .models import Base, Author, Book
def run():
Base.metadata.create_all(engine)
s = SessionLocal()
a1 = Author(name="Ursula K. Le Guin")
a2 = Author(name="Octavia E. Butler")
s.add_all([
a1, a2,
Book(title="The Left Hand of Darkness", author=a1),
Book(title="A Wizard of Earthsea", author=a1),
Book(title="Parable of the Sower", author=a2),
])
s.commit()
s.close()
if __name__ == "__main__":
run()

View file

@ -0,0 +1,17 @@
<!-- templates/demo.html -->
<!doctype html><meta charset="utf-8">
<script src="https://unpkg.com/htmx.org@2.0.0"></script>
<body>
<table class="table-auto w-full border">
<thead><tr><th class="px-3 py-2">ID</th><th class="px-3 py-2">Title</th><th class="px-3 py-2">Author</th><th></th></tr></thead>
<tbody id="rows"
hx-get="/ui/book/frag/rows?fields_csv=id,title,author.name&page=1&per_page=20"
hx-trigger="load" hx-target="this" hx-swap="innerHTML"></tbody>
</table>
<button hx-get="/ui/book/frag/form?hx=1&fields_csv=id,title,author.name"
hx-target="#modal-body" hx-swap="innerHTML"
onclick="document.getElementById('modal').showModal()">New Book</button>
<dialog id="modal"><div id="modal-body"></div></dialog>
</body>

View file

@ -1,48 +0,0 @@
from flask import Flask
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from crudkit import ProdConfig
from crudkit.api.flask_api import generate_crud_blueprint
from crudkit.core.service import CRUDService
from crudkit.integrations.flask import init_app
from muck.models.dbref import Base, Dbref
from muck.models.exit import Exit
from muck.models.player import Player
from muck.models.program import Program
from muck.models.room import Room
from muck.models.thing import Thing
from muck.init import bootstrap_world
DATABASE_URL = "sqlite:///muck.db"
engine = create_engine(DATABASE_URL, echo=True)
SessionLocal = scoped_session(sessionmaker(bind=engine))
Base.metadata.create_all(engine)
session = SessionLocal()
bootstrap_world(session)
app = Flask(__name__)
dbref_service = CRUDService(Dbref, session, polymorphic=True)
exit_service = CRUDService(Exit, session)
player_service = CRUDService(Player, session)
program_service = CRUDService(Program, session)
room_service = CRUDService(Room, session)
thing_service = CRUDService(Thing, session)
app.register_blueprint(generate_crud_blueprint(Dbref, dbref_service), url_prefix="/api/dbref")
app.register_blueprint(generate_crud_blueprint(Exit, exit_service), url_prefix="/api/exits")
app.register_blueprint(generate_crud_blueprint(Player, player_service), url_prefix="/api/players")
app.register_blueprint(generate_crud_blueprint(Program, program_service), url_prefix="/api/programs")
app.register_blueprint(generate_crud_blueprint(Room, room_service), url_prefix="/api/rooms")
app.register_blueprint(generate_crud_blueprint(Thing, thing_service), url_prefix="/api/things")
if __name__ == "__main__":
init_app(app, config=ProdConfig)
# app.run(debug=True, port=5050)

View file

@ -1,39 +0,0 @@
from muck.models.room import Room
from muck.models.player import Player
def bootstrap_world(session):
if session.query(Room).first() or session.query(Player).first():
print("World already initialized.")
return
print("Bootstrapping world...")
room_zero = Room(
id=0,
name="Room Zero",
props={"_": {"de": "You are in Room Zero. It is very dark in here."}}
)
the_one = Player(
id=1,
name="One",
password="potrzebie",
props={"_": {"de": "You see The One."}}
)
the_one.location = room_zero
the_one.home = room_zero
the_one.creator = the_one
the_one.owner = the_one
the_one.modifier = the_one
the_one.last_user = the_one
room_zero.owner = the_one
room_zero.creator = the_one
room_zero.modifier = the_one
room_zero.last_user = the_one
session.add_all([room_zero, the_one])
session.commit()
print("World initialized.")

View file

@ -1,5 +0,0 @@
from muck.models.dbref import Dbref
from muck.models.exit import Exit
from sqlalchemy.orm import relationship
Dbref.exits = relationship("Exit", back_populates="source", foreign_keys=[Exit.location_id])

View file

@ -1,81 +0,0 @@
from sqlalchemy import Column, Integer, String, ForeignKey, Boolean, DateTime, JSON, Enum as SQLEnum, func
from sqlalchemy.orm import relationship, foreign, remote
from crudkit.core.base import CRUDMixin, Base
from enum import Enum
class ObjectType(str, Enum):
ROOM = "room"
THING = "thing"
EXIT = "exit"
PLAYER = "player"
PROGRAM = "program"
TYPE_SUFFIXES = {
ObjectType.ROOM: "R",
ObjectType.EXIT: "E",
ObjectType.PLAYER: "P",
ObjectType.PROGRAM: "F",
ObjectType.THING: "",
}
class Dbref(Base, CRUDMixin):
__tablename__ = "dbref"
type = Column(SQLEnum(ObjectType, name="object_type_enum"), nullable=False)
name = Column(String, nullable=False)
props = Column(JSON, nullable=False, default={})
is_deleted = Column(Boolean, nullable=False, default=False)
last_used = Column(DateTime, nullable=False, default=func.now())
use_count = Column(Integer, nullable=False, default=0)
location_id = Column(Integer, ForeignKey("dbref.id"), nullable=True, default=0)
location = relationship("Dbref", foreign_keys=[location_id], back_populates="contents", primaryjoin=lambda: foreign(Dbref.location_id) == remote(Dbref.id), remote_side=lambda: Dbref.id)
contents = relationship("Dbref", foreign_keys=[location_id], back_populates="location")
owner_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
owner = relationship("Player", remote_side=[CRUDMixin.id], foreign_keys=[owner_id], primaryjoin=lambda: Dbref.owner_id == remote(Dbref.id), post_update=True)
creator_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
creator = relationship("Player", remote_side=[CRUDMixin.id], foreign_keys=[creator_id], primaryjoin=lambda: Dbref.creator_id == remote(Dbref.id), post_update=True)
modifier_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
modifier = relationship("Player", remote_side=[CRUDMixin.id], foreign_keys=[modifier_id], primaryjoin=lambda: Dbref.modifier_id == remote(Dbref.id), post_update=True)
last_user_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
last_user = relationship("Player", remote_side=[CRUDMixin.id], foreign_keys=[last_user_id], primaryjoin=lambda: Dbref.last_user_id == remote(Dbref.id), post_update=True)
__mapper_args__ = {
"polymorphic_on": type,
"polymorphic_identity": "dbref"
}
def __str__(self):
suffix = TYPE_SUFFIXES.get(self.type, "")
return f"#{self.id}{suffix}"
def __repr__(self):
suffix = TYPE_SUFFIXES.get(self.type, "")
return f"<Dbref #{self.id}{suffix}>"
def is_type(self, *types: ObjectType) -> bool:
return self.type in types
def display_type(self) -> str:
return self.type.value.upper()
@property
def is_room(self): return self.is_type(ObjectType.ROOM)
@property
def is_thing(self): return self.is_type(ObjectType.THING)
@property
def is_exit(self): return self.is_type(ObjectType.EXIT)
@property
def is_player(self): return self.is_type(ObjectType.PLAYER)
@property
def is_program(self): return self.is_type(ObjectType.PROGRAM)

View file

@ -1,21 +0,0 @@
from sqlalchemy import Column, Integer, ForeignKey
from sqlalchemy.orm import relationship, foreign, remote
from crudkit.core.base import CRUDMixin
from muck.models.dbref import Dbref, ObjectType
class Exit(Dbref):
__tablename__ = "exits"
id = Column(Integer, ForeignKey("dbref.id"), primary_key=True)
destination_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
destination = relationship("Dbref", remote_side=[CRUDMixin.id], foreign_keys=[destination_id], primaryjoin=lambda: foreign(Exit.destination_id) == remote(Dbref.id))
source = relationship("Dbref", back_populates="exits", foreign_keys=[Dbref.location_id], remote_side=[Dbref.id])
__mapper_args__ = {
"polymorphic_identity": ObjectType.EXIT,
"inherit_condition": id == Dbref.id
}

View file

@ -1,27 +0,0 @@
from sqlalchemy import Column, Integer, Boolean, String, ForeignKey
from sqlalchemy.orm import relationship, foreign, remote
from crudkit.core.base import CRUDMixin
from muck.models.dbref import Dbref, ObjectType
class Player(Dbref):
__tablename__ = "players"
id = Column(Integer, ForeignKey("dbref.id"), primary_key=True)
pennies = Column(Integer, nullable=False, default=0)
insert_mode = Column(Boolean, nullable=False, default=False)
block = Column(Integer, nullable=True)
password = Column(String, nullable=False)
home_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
home = relationship("Dbref", remote_side=[CRUDMixin.id], foreign_keys=[home_id], primaryjoin=lambda: foreign(Player.home_id) == remote(Dbref.id))
current_program_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
current_program = relationship("Dbref", remote_side=[CRUDMixin.id], foreign_keys=[current_program_id], primaryjoin=lambda: foreign(Player.current_program_id) == remote(Dbref.id))
__mapper_args__ = {
"polymorphic_identity": ObjectType.PLAYER,
"inherit_condition": id == Dbref.id
}

View file

@ -1,13 +0,0 @@
from sqlalchemy import Column, Integer, ForeignKey
from muck.models.dbref import Dbref, ObjectType
class Program(Dbref):
__tablename__ = "programs"
id = Column(Integer, ForeignKey("dbref.id"), primary_key=True)
__mapper_args__ = {
"polymorphic_identity": ObjectType.PROGRAM,
"inherit_condition": id == Dbref.id
}

View file

@ -1,19 +0,0 @@
from sqlalchemy import Column, Integer, ForeignKey
from sqlalchemy.orm import relationship, foreign, remote
from crudkit.core.base import CRUDMixin
from muck.models.dbref import Dbref, ObjectType
class Room(Dbref):
__tablename__ = "rooms"
id = Column(Integer, ForeignKey("dbref.id"), primary_key=True)
dropto_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
dropto = relationship("Dbref", remote_side=[CRUDMixin.id], foreign_keys=[dropto_id], primaryjoin=lambda: foreign(Room.dropto_id) == remote(Dbref.id))
__mapper_args__ = {
"polymorphic_identity": ObjectType.ROOM,
"inherit_condition": id == Dbref.id
}

View file

@ -1,21 +0,0 @@
from sqlalchemy import Column, Integer, ForeignKey
from sqlalchemy.orm import relationship, foreign, remote
from crudkit.core.base import CRUDMixin
from muck.models.dbref import Dbref, ObjectType
class Thing(Dbref):
__tablename__ = "things"
id = Column(Integer, ForeignKey("dbref.id"), primary_key=True)
value = Column(Integer, nullable=False, default=0)
home_id = Column(Integer, ForeignKey("dbref.id"), nullable=True)
home = relationship("Dbref", remote_side=[CRUDMixin.id], foreign_keys=[home_id], primaryjoin=lambda: foreign(Thing.home_id) == remote(Dbref.id))
__mapper_args__ = {
"polymorphic_identity": ObjectType.THING,
"inherit_condition": id == Dbref.id
}