Updated CRUDkit
This commit is contained in:
parent
f1fa1f2407
commit
571583bcf4
6 changed files with 186 additions and 24 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from .backend import BackendInfo, make_backend_info
|
||||
from .config import Config, DevConfig, TestConfig, ProdConfig, get_config, build_database_url
|
||||
from .engines import CRUDKitRuntime, build_engine, build_sessionmaker
|
||||
|
||||
__all__ = [
|
||||
"Config", "DevConfig", "TestConfig", "ProdConfig", "get_config", "build_database_url",
|
||||
"CRUDKitRuntime", "build_engine", "build_sessionmaker"
|
||||
"CRUDKitRuntime", "build_engine", "build_sessionmaker", "BackendInfo", "make_backend_info"
|
||||
]
|
||||
|
|
|
|||
122
crudkit/backend.py
Normal file
122
crudkit/backend.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional, Iterable
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import text, func
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.elements import ClauseElement
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BackendInfo:
|
||||
name: str
|
||||
version: Tuple[int, ...]
|
||||
paramstyle: str
|
||||
is_sqlite: bool
|
||||
is_postgres: bool
|
||||
is_mysql: bool
|
||||
is_mssql: bool
|
||||
|
||||
supports_returning: bool
|
||||
supports_ilike: bool
|
||||
requires_order_by_for_offset: bool
|
||||
max_bind_params: Optional[int]
|
||||
|
||||
@classmethod
|
||||
def from_engine(cls, engine: Engine) -> "BackendInfo":
|
||||
d = engine.dialect
|
||||
name = d.name
|
||||
version = tuple(getattr(d, "server_version_info", ()) or ())
|
||||
is_pg = name in {"postgresql", "postgres"}
|
||||
is_my = name == "mysql"
|
||||
is_sq = name == "sqlite"
|
||||
is_ms = name == "mssql"
|
||||
|
||||
supports_ilike = is_pg or is_my
|
||||
supports_returning = is_pg or (is_sq and version >= (3, 35))
|
||||
requires_order_by_for_offset = is_ms
|
||||
|
||||
max_bind_params = 999 if is_sq else None
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
version=version,
|
||||
paramstyle=d.paramstyle,
|
||||
is_sqlite=is_sq,
|
||||
is_postgres=is_pg,
|
||||
is_mysql=is_my,
|
||||
is_mssql=is_ms,
|
||||
supports_returning=supports_returning,
|
||||
supports_ilike=supports_ilike,
|
||||
requires_order_by_for_offset=requires_order_by_for_offset,
|
||||
max_bind_params=max_bind_params,
|
||||
)
|
||||
|
||||
def make_backend_info(engine: Engine) -> BackendInfo:
|
||||
return BackendInfo.from_engine(engine)
|
||||
|
||||
def ci_like(column, value: str, backend: BackendInfo) -> ClauseElement:
|
||||
"""
|
||||
Portable save-insensitive LIKE.
|
||||
Uses ILIKE where available, else lower() dance.
|
||||
"""
|
||||
pattern = f"%{value}%"
|
||||
if backend.supports_ilike:
|
||||
return column.ilike(pattern)
|
||||
return func.lower(column).like(func.lower(text(":pattern"))).params(pattern=pattern)
|
||||
|
||||
def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: int, default_order_by=None) -> Select:
|
||||
"""
|
||||
Portable pagination. MSSQL requires ORDER BY when using OFFSET
|
||||
"""
|
||||
page = max(1, int(page))
|
||||
per_page = max(1, int(per_page))
|
||||
offset = (page - 1) * per_page
|
||||
|
||||
if backend.requires_order_by_for_offset and not sel._order_by_clauses:
|
||||
if default_order_by is None:
|
||||
sel = sel.order_by(text("1"))
|
||||
else:
|
||||
sel = sel.order_by(default_order_by)
|
||||
|
||||
return sel.limit(per_page).offset(offset)
|
||||
|
||||
@contextmanager
|
||||
def maybe_identify_insert(session: Session, table, backend: BackendInfo):
|
||||
"""
|
||||
For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs.
|
||||
No-op elsewhere.
|
||||
"""
|
||||
if not backend.is_mssql:
|
||||
yield
|
||||
return
|
||||
|
||||
full_name = f"{table.schema}.{table.name}" if table.schema else table.name
|
||||
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON"))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
session.execute(text(f"SET IDENTITY_INSERT {full_name} OFF"))
|
||||
|
||||
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
|
||||
"""
|
||||
Build a safe large IN() filter respecting bund param limits.
|
||||
Returns a disjunction of chunked IN clauses if needed.
|
||||
"""
|
||||
vals = list(values)
|
||||
if not vals:
|
||||
return text("1=0")
|
||||
|
||||
limit = chunk_size or backend.max_bind_params or len(vals)
|
||||
if len(vals) <= limit:
|
||||
return column.in_(vals)
|
||||
|
||||
parts = []
|
||||
for i in range(0, len(vals), limit):
|
||||
parts.append(column.in_(vals[i:i + limit]))
|
||||
|
||||
expr = parts[0]
|
||||
for p in parts[1:]:
|
||||
expr = expr | p
|
||||
return expr
|
||||
|
|
@ -60,7 +60,7 @@ def build_database_url(
|
|||
return url
|
||||
|
||||
backend = (backend or "").lower().strip()
|
||||
optional = options or {}
|
||||
options = options or {}
|
||||
|
||||
if backend == "sqlite":
|
||||
db_path = database or "app.db"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Type, TypeVar, Generic
|
||||
from typing import Type, TypeVar, Generic, Optional
|
||||
from sqlalchemy.orm import Session, with_polymorphic
|
||||
from sqlalchemy import inspect, text
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
from crudkit.backend import BackendInfo, make_backend_info
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -9,11 +11,20 @@ def _is_truthy(val):
|
|||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||||
|
||||
class CRUDService(Generic[T]):
|
||||
def __init__(self, model: Type[T], session: Session, polymorphic: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[T],
|
||||
session: Session,
|
||||
polymorphic: bool = False,
|
||||
*,
|
||||
backend: Optional[BackendInfo] = None
|
||||
):
|
||||
self.model = model
|
||||
self.session = session
|
||||
self.polymorphic = polymorphic
|
||||
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
||||
# Cache backend info once. If not provided, derive from session bind.
|
||||
self.backend = backend or make_backend_info(self.session.get_bind())
|
||||
|
||||
def get_query(self):
|
||||
if self.polymorphic:
|
||||
|
|
@ -23,14 +34,22 @@ class CRUDService(Generic[T]):
|
|||
base_only = with_polymorphic(self.model, [], flat=True)
|
||||
return self.session.query(base_only), base_only
|
||||
|
||||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||
def _default_order_by(self, root_alias):
|
||||
mapper = inspect(self.model)
|
||||
cols = []
|
||||
for col in mapper.primary_key:
|
||||
try:
|
||||
cols.append(getattr(root_alias, col.key))
|
||||
except AttributeError:
|
||||
cols.append(col)
|
||||
return cols or [text("1")]
|
||||
|
||||
def get(self, id: int, include_deleted: bool = False) -> T | None:
|
||||
query, root_alias = self.get_query()
|
||||
|
||||
if self.supports_soft_delete and not include_deleted:
|
||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||
|
||||
query = query.filter(getattr(root_alias, "id") == id)
|
||||
|
||||
obj = query.first()
|
||||
return obj or None
|
||||
|
||||
|
|
@ -60,9 +79,20 @@ class CRUDService(Generic[T]):
|
|||
|
||||
if filters:
|
||||
query = query.filter(*filters)
|
||||
|
||||
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
|
||||
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||||
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||||
order_by = self._default_order_by(root_alias)
|
||||
|
||||
if order_by:
|
||||
query = query.order_by(*order_by)
|
||||
query = query.offset(offset).limit(limit)
|
||||
|
||||
# Only apply offset/limit when not None.
|
||||
if offset is not None and offset != 0:
|
||||
query = query.offset(offset)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
|
|
@ -70,7 +100,6 @@ class CRUDService(Generic[T]):
|
|||
obj = self.model(**data)
|
||||
self.session.add(obj)
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("create", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -78,13 +107,11 @@ class CRUDService(Generic[T]):
|
|||
obj = self.get(id)
|
||||
if not obj:
|
||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||
|
||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||
for k, v in data.items():
|
||||
if k in valid_fields:
|
||||
setattr(obj, k, v)
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("update", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -92,14 +119,11 @@ class CRUDService(Generic[T]):
|
|||
obj = self.session.get(self.model, id)
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if hard or not self.supports_soft_delete:
|
||||
self.session.delete(obj)
|
||||
else:
|
||||
obj.is_deleted = True
|
||||
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("delete", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -108,7 +132,6 @@ class CRUDService(Generic[T]):
|
|||
data = obj.as_dict()
|
||||
except Exception:
|
||||
data = {"error": "Failed to serialize object."}
|
||||
|
||||
version = Version(
|
||||
model_name=self.model.__name__,
|
||||
object_id=obj.id,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
from typing import Type, Optional
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from .backend import make_backend_info, BackendInfo
|
||||
from .config import Config, get_config
|
||||
from ._sqlite import apply_sqlite_pragmas
|
||||
|
||||
|
|
@ -41,3 +42,9 @@ class CRUDKitRuntime:
|
|||
if self._config and self._engine:
|
||||
self._session_factory = build_sessionmaker(self._config, self._engine)
|
||||
return self._session_factory
|
||||
|
||||
@property
|
||||
def backend(self) -> BackendInfo:
|
||||
if not hasattr(self, "_backend_info") or self._backend_info is None:
|
||||
self._backend_info = make_backend_info(self.engine)
|
||||
return self._backend_info
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue