Backend behavior and a minor fix for config implemented.

This commit is contained in:
Yaro Kasear 2025-09-08 15:06:49 -05:00
parent a871b9c5fe
commit f9458a429e
5 changed files with 168 additions and 15 deletions

View file

@ -1,7 +1,8 @@
from .backend import BackendInfo, make_backend_info
from .config import Config, DevConfig, TestConfig, ProdConfig, get_config, build_database_url
from .engines import CRUDKitRuntime, build_engine, build_sessionmaker
__all__ = [
"Config", "DevConfig", "TestConfig", "ProdConfig", "get_config", "build_database_url",
"CRUDKitRuntime", "build_engine", "build_sessionmaker"
"CRUDKitRuntime", "build_engine", "build_sessionmaker", "BackendInfo", "make_backend_info"
]

122
crudkit/backend.py Normal file
View file

@ -0,0 +1,122 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, Optional, Iterable
from contextlib import contextmanager
from sqlalchemy import text, func
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql import Select
@dataclass(frozen=True)
class BackendInfo:
name: str
version: Tuple[int, ...]
paramstyle: str
is_sqlite: bool
is_postgres: bool
is_mysql: bool
is_mssql: bool
supports_returning: bool
supports_ilike: bool
requires_order_by_for_offset: bool
max_bind_params: Optional[int]
@classmethod
def from_engine(cls, engine: Engine) -> "BackendInfo":
d = engine.dialect
name = d.name
version = tuple(getattr(d, "server_version_info", ()) or ())
is_pg = name in {"postgresql", "postgres"}
is_my = name == "mysql"
is_sq = name == "sqlite"
is_ms = name == "mssql"
supports_ilike = is_pg or is_my
supports_returning = is_pg or (is_sq and version >= (3, 35))
requires_order_by_for_offset = is_ms
max_bind_params = 999 if is_sq else None
return cls(
name=name,
version=version,
paramstyle=d.paramstyle,
is_sqlite=is_sq,
is_postgres=is_pg,
is_mysql=is_my,
is_mssql=is_ms,
supports_returning=supports_returning,
supports_ilike=supports_ilike,
requires_order_by_for_offset=requires_order_by_for_offset,
max_bind_params=max_bind_params,
)
def make_backend_info(engine: Engine) -> BackendInfo:
return BackendInfo.from_engine(engine)
def ci_like(column, value: str, backend: BackendInfo) -> ClauseElement:
"""
Portable save-insensitive LIKE.
Uses ILIKE where available, else lower() dance.
"""
pattern = f"%{value}%"
if backend.supports_ilike:
return column.ilike(pattern)
return func.lower(column).like(func.lower(text(":pattern"))).params(pattern=pattern)
def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: int, default_order_by=None) -> Select:
"""
Portable pagination. MSSQL requires ORDER BY when using OFFSET
"""
page = max(1, int(page))
per_page = max(1, int(per_page))
offset = (page - 1) * per_page
if backend.requires_order_by_for_offset and not sel._order_by_clauses:
if default_order_by is None:
sel = sel.order_by(text("1"))
else:
sel = sel.order_by(default_order_by)
return sel.limit(per_page).offset(offset)
@contextmanager
def maybe_identify_insert(session: Session, table, backend: BackendInfo):
"""
For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs.
No-op elsewhere.
"""
if not backend.is_mssql:
yield
return
full_name = f"{table.schema}.{table.name}" if table.schema else table.name
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON"))
try:
yield
finally:
session.execute(text(f"SET IDENTITY_INSERT {full_name} OFF"))
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
"""
Build a safe large IN() filter respecting bund param limits.
Returns a disjunction of chunked IN clauses if needed.
"""
vals = list(values)
if not vals:
return text("1=0")
limit = chunk_size or backend.max_bind_params or len(vals)
if len(vals) <= limit:
return column.in_(vals)
parts = []
for i in range(0, len(vals), limit):
parts.append(column.in_(vals[i:i + limit]))
expr = parts[0]
for p in parts[1:]:
expr = expr | p
return expr

View file

@ -60,7 +60,7 @@ def build_database_url(
return url
backend = (backend or "").lower().strip()
optional = options or {}
options = options or {}
if backend == "sqlite":
db_path = database or "app.db"

View file

@ -1,7 +1,9 @@
from typing import Type, TypeVar, Generic
from typing import Type, TypeVar, Generic, Optional
from sqlalchemy.orm import Session, with_polymorphic
from sqlalchemy import inspect, text
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.backend import BackendInfo, make_backend_info
T = TypeVar("T")
@ -9,11 +11,20 @@ def _is_truthy(val):
return str(val).lower() in ('1', 'true', 'yes', 'on')
class CRUDService(Generic[T]):
def __init__(self, model: Type[T], session: Session, polymorphic: bool = False):
def __init__(
self,
model: Type[T],
session: Session,
polymorphic: bool = False,
*,
backend: Optional[BackendInfo] = None
):
self.model = model
self.session = session
self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted')
# Cache backend info once. If not provided, derive from session bind.
self.backend = backend or make_backend_info(self.session.get_bind())
def get_query(self):
if self.polymorphic:
@ -23,14 +34,22 @@ class CRUDService(Generic[T]):
base_only = with_polymorphic(self.model, [], flat=True)
return self.session.query(base_only), base_only
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias):
mapper = inspect(self.model)
cols = []
for col in mapper.primary_key:
try:
cols.append(getattr(root_alias, col.key))
except AttributeError:
cols.append(col)
return cols or [text("1")]
def get(self, id: int, include_deleted: bool = False) -> T | None:
query, root_alias = self.get_query()
if self.supports_soft_delete and not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
query = query.filter(getattr(root_alias, "id") == id)
obj = query.first()
return obj or None
@ -60,9 +79,20 @@ class CRUDService(Generic[T]):
if filters:
query = query.filter(*filters)
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
paginating = (limit is not None) or (offset is not None and offset != 0)
if paginating and not order_by and self.backend.requires_order_by_for_offset:
order_by = self._default_order_by(root_alias)
if order_by:
query = query.order_by(*order_by)
query = query.offset(offset).limit(limit)
# Only apply offset/limit when not None.
if offset is not None and offset != 0:
query = query.offset(offset)
if limit is not None:
query = query.limit(limit)
return query.all()
@ -70,7 +100,6 @@ class CRUDService(Generic[T]):
obj = self.model(**data)
self.session.add(obj)
self.session.commit()
self._log_version("create", obj, actor)
return obj
@ -78,13 +107,11 @@ class CRUDService(Generic[T]):
obj = self.get(id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
valid_fields = {c.name for c in self.model.__table__.columns}
for k, v in data.items():
if k in valid_fields:
setattr(obj, k, v)
self.session.commit()
self._log_version("update", obj, actor)
return obj
@ -92,14 +119,11 @@ class CRUDService(Generic[T]):
obj = self.session.get(self.model, id)
if not obj:
return None
if hard or not self.supports_soft_delete:
self.session.delete(obj)
else:
obj.is_deleted = True
self.session.commit()
self._log_version("delete", obj, actor)
return obj
@ -108,7 +132,6 @@ class CRUDService(Generic[T]):
data = obj.as_dict()
except Exception:
data = {"error": "Failed to serialize object."}
version = Version(
model_name=self.model.__name__,
object_id=obj.id,

View file

@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Type, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .backend import make_backend_info, BackendInfo
from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas
@ -41,3 +42,9 @@ class CRUDKitRuntime:
if self._config and self._engine:
self._session_factory = build_sessionmaker(self._config, self._engine)
return self._session_factory
@property
def backend(self) -> BackendInfo:
if not hasattr(self, "_backend_info") or self._backend_info is None:
self._backend_info = make_backend_info(self.engine)
return self._backend_info