41 lines
1.5 KiB
Python
41 lines
1.5 KiB
Python
from __future__ import annotations
|
|
from typing import Any, Dict
|
|
|
|
from flask import current_app
|
|
|
|
from sqlalchemy import create_engine, text, event
|
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
|
|
|
from crudkit.core.base import Base
|
|
|
|
_engine = None
|
|
SessionLocal = None
|
|
|
|
def init_db(database_url: str, engine_kwargs: Dict[str, Any], session_kwargs: Dict[str, Any]) -> None:
    """Create the module-level engine and scoped session factory.

    Rebinds the module globals ``_engine`` and ``SessionLocal`` and installs
    backend-specific tuning based on the prefix of ``database_url``.

    Args:
        database_url: SQLAlchemy database URL handed to ``create_engine``.
        engine_kwargs: Extra keyword arguments forwarded to ``create_engine``.
        session_kwargs: Extra keyword arguments forwarded to ``sessionmaker``.
    """
    global _engine, SessionLocal

    # Function-scope import so this block does not depend on a file-level one.
    import logging

    _engine = create_engine(database_url, **engine_kwargs)
    SessionLocal = scoped_session(sessionmaker(bind=_engine, **session_kwargs))

    # Report the target through logging with the password masked; the raw URL
    # may embed credentials and must not be written to stdout.
    logging.getLogger(__name__).info(
        "Initialized database engine for %s",
        _engine.url.render_as_string(hide_password=True),
    )

    if database_url.startswith("sqlite:///"):
        sqlite_pragmas = (
            "PRAGMA journal_mode = WAL;",
            "PRAGMA foreign_keys = ON;",
            "PRAGMA synchronous = NORMAL;",
        )

        # foreign_keys and synchronous are per-connection settings in SQLite,
        # so they have to be issued on every new DBAPI connection the pool
        # opens, not once on a single connection.
        @event.listens_for(_engine, "connect")
        def _apply_sqlite_pragmas(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                for pragma in sqlite_pragmas:
                    cursor.execute(pragma)
            finally:
                cursor.close()

        # Open one connection right away, as the previous implementation did,
        # so the listener runs now and an unusable database fails at init.
        with _engine.connect():
            pass
    elif database_url.startswith("mssql+pyodbc://"):
        # Turn on the cursor's fast_executemany flag for executemany calls.
        @event.listens_for(_engine, "before_cursor_execute")
        def _enable_fastexecutemany(conn, cursor, statement, parameters, context, executemany):
            if executemany and hasattr(cursor, "fast_executemany"):
                cursor.fast_executemany = True
|
|
|
|
def get_runtime():
    """Return the ``"runtime"`` entry of the current app's ``"crudkit"`` extension."""
    crudkit_ext = current_app.extensions["crudkit"]
    return crudkit_ext["runtime"]
|
|
|
|
def get_session():
    """Build and return a new session for the current app.

    Uses the ``"Session"`` entry of the ``"crudkit"`` extension mapping when it
    is present and truthy; otherwise falls back to the runtime's
    ``session_factory``.
    """
    crudkit_ext = current_app.extensions["crudkit"]
    factory = crudkit_ext.get("Session")
    if not factory:
        # Only consult the runtime when no usable factory is registered.
        factory = get_runtime().session_factory
    return factory()
|
|
|
|
def create_all_tables():
    """Call ``create_all`` on the models' ``Base.metadata``, bound to the runtime engine."""
    # Imported inside the function, as in the original, rather than at module load.
    from . import models as _models

    metadata = _models.Base.metadata
    metadata.create_all(bind=get_runtime().engine)
|