Redesign1 #1
5 changed files with 168 additions and 15 deletions
|
|
@ -1,7 +1,8 @@
|
||||||
|
from .backend import BackendInfo, make_backend_info
|
||||||
from .config import Config, DevConfig, TestConfig, ProdConfig, get_config, build_database_url
|
from .config import Config, DevConfig, TestConfig, ProdConfig, get_config, build_database_url
|
||||||
from .engines import CRUDKitRuntime, build_engine, build_sessionmaker
|
from .engines import CRUDKitRuntime, build_engine, build_sessionmaker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Config", "DevConfig", "TestConfig", "ProdConfig", "get_config", "build_database_url",
|
"Config", "DevConfig", "TestConfig", "ProdConfig", "get_config", "build_database_url",
|
||||||
"CRUDKitRuntime", "build_engine", "build_sessionmaker"
|
"CRUDKitRuntime", "build_engine", "build_sessionmaker", "BackendInfo", "make_backend_info"
|
||||||
]
|
]
|
||||||
|
|
|
||||||
122
crudkit/backend.py
Normal file
122
crudkit/backend.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple, Optional, Iterable
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from sqlalchemy import text, func
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql.elements import ClauseElement
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class BackendInfo:
|
||||||
|
name: str
|
||||||
|
version: Tuple[int, ...]
|
||||||
|
paramstyle: str
|
||||||
|
is_sqlite: bool
|
||||||
|
is_postgres: bool
|
||||||
|
is_mysql: bool
|
||||||
|
is_mssql: bool
|
||||||
|
|
||||||
|
supports_returning: bool
|
||||||
|
supports_ilike: bool
|
||||||
|
requires_order_by_for_offset: bool
|
||||||
|
max_bind_params: Optional[int]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_engine(cls, engine: Engine) -> "BackendInfo":
|
||||||
|
d = engine.dialect
|
||||||
|
name = d.name
|
||||||
|
version = tuple(getattr(d, "server_version_info", ()) or ())
|
||||||
|
is_pg = name in {"postgresql", "postgres"}
|
||||||
|
is_my = name == "mysql"
|
||||||
|
is_sq = name == "sqlite"
|
||||||
|
is_ms = name == "mssql"
|
||||||
|
|
||||||
|
supports_ilike = is_pg or is_my
|
||||||
|
supports_returning = is_pg or (is_sq and version >= (3, 35))
|
||||||
|
requires_order_by_for_offset = is_ms
|
||||||
|
|
||||||
|
max_bind_params = 999 if is_sq else None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
version=version,
|
||||||
|
paramstyle=d.paramstyle,
|
||||||
|
is_sqlite=is_sq,
|
||||||
|
is_postgres=is_pg,
|
||||||
|
is_mysql=is_my,
|
||||||
|
is_mssql=is_ms,
|
||||||
|
supports_returning=supports_returning,
|
||||||
|
supports_ilike=supports_ilike,
|
||||||
|
requires_order_by_for_offset=requires_order_by_for_offset,
|
||||||
|
max_bind_params=max_bind_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def make_backend_info(engine: Engine) -> BackendInfo:
|
||||||
|
return BackendInfo.from_engine(engine)
|
||||||
|
|
||||||
|
def ci_like(column, value: str, backend: BackendInfo) -> ClauseElement:
|
||||||
|
"""
|
||||||
|
Portable save-insensitive LIKE.
|
||||||
|
Uses ILIKE where available, else lower() dance.
|
||||||
|
"""
|
||||||
|
pattern = f"%{value}%"
|
||||||
|
if backend.supports_ilike:
|
||||||
|
return column.ilike(pattern)
|
||||||
|
return func.lower(column).like(func.lower(text(":pattern"))).params(pattern=pattern)
|
||||||
|
|
||||||
|
def apply_pagination(sel: Select, backend: BackendInfo, *, page: int, per_page: int, default_order_by=None) -> Select:
|
||||||
|
"""
|
||||||
|
Portable pagination. MSSQL requires ORDER BY when using OFFSET
|
||||||
|
"""
|
||||||
|
page = max(1, int(page))
|
||||||
|
per_page = max(1, int(per_page))
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
|
||||||
|
if backend.requires_order_by_for_offset and not sel._order_by_clauses:
|
||||||
|
if default_order_by is None:
|
||||||
|
sel = sel.order_by(text("1"))
|
||||||
|
else:
|
||||||
|
sel = sel.order_by(default_order_by)
|
||||||
|
|
||||||
|
return sel.limit(per_page).offset(offset)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def maybe_identify_insert(session: Session, table, backend: BackendInfo):
|
||||||
|
"""
|
||||||
|
For MSSQL tables with IDENTIFY PK when you need to insert explicit IDs.
|
||||||
|
No-op elsewhere.
|
||||||
|
"""
|
||||||
|
if not backend.is_mssql:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
full_name = f"{table.schema}.{table.name}" if table.schema else table.name
|
||||||
|
session.execute(text(f"SET IDENTIFY_INSERT {full_name} ON"))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
session.execute(text(f"SET IDENTITY_INSERT {full_name} OFF"))
|
||||||
|
|
||||||
|
def chunked_in(column, values: Iterable, backend: BackendInfo, chunk_size: Optional[int] = None) -> ClauseElement:
|
||||||
|
"""
|
||||||
|
Build a safe large IN() filter respecting bund param limits.
|
||||||
|
Returns a disjunction of chunked IN clauses if needed.
|
||||||
|
"""
|
||||||
|
vals = list(values)
|
||||||
|
if not vals:
|
||||||
|
return text("1=0")
|
||||||
|
|
||||||
|
limit = chunk_size or backend.max_bind_params or len(vals)
|
||||||
|
if len(vals) <= limit:
|
||||||
|
return column.in_(vals)
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for i in range(0, len(vals), limit):
|
||||||
|
parts.append(column.in_(vals[i:i + limit]))
|
||||||
|
|
||||||
|
expr = parts[0]
|
||||||
|
for p in parts[1:]:
|
||||||
|
expr = expr | p
|
||||||
|
return expr
|
||||||
|
|
@ -60,7 +60,7 @@ def build_database_url(
|
||||||
return url
|
return url
|
||||||
|
|
||||||
backend = (backend or "").lower().strip()
|
backend = (backend or "").lower().strip()
|
||||||
optional = options or {}
|
options = options or {}
|
||||||
|
|
||||||
if backend == "sqlite":
|
if backend == "sqlite":
|
||||||
db_path = database or "app.db"
|
db_path = database or "app.db"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from typing import Type, TypeVar, Generic
|
from typing import Type, TypeVar, Generic, Optional
|
||||||
from sqlalchemy.orm import Session, with_polymorphic
|
from sqlalchemy.orm import Session, with_polymorphic
|
||||||
|
from sqlalchemy import inspect, text
|
||||||
from crudkit.core.base import Version
|
from crudkit.core.base import Version
|
||||||
from crudkit.core.spec import CRUDSpec
|
from crudkit.core.spec import CRUDSpec
|
||||||
|
from crudkit.backend import BackendInfo, make_backend_info
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
@ -9,11 +11,20 @@ def _is_truthy(val):
|
||||||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||||||
|
|
||||||
class CRUDService(Generic[T]):
|
class CRUDService(Generic[T]):
|
||||||
def __init__(self, model: Type[T], session: Session, polymorphic: bool = False):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Type[T],
|
||||||
|
session: Session,
|
||||||
|
polymorphic: bool = False,
|
||||||
|
*,
|
||||||
|
backend: Optional[BackendInfo] = None
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.session = session
|
self.session = session
|
||||||
self.polymorphic = polymorphic
|
self.polymorphic = polymorphic
|
||||||
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
||||||
|
# Cache backend info once. If not provided, derive from session bind.
|
||||||
|
self.backend = backend or make_backend_info(self.session.get_bind())
|
||||||
|
|
||||||
def get_query(self):
|
def get_query(self):
|
||||||
if self.polymorphic:
|
if self.polymorphic:
|
||||||
|
|
@ -23,14 +34,22 @@ class CRUDService(Generic[T]):
|
||||||
base_only = with_polymorphic(self.model, [], flat=True)
|
base_only = with_polymorphic(self.model, [], flat=True)
|
||||||
return self.session.query(base_only), base_only
|
return self.session.query(base_only), base_only
|
||||||
|
|
||||||
|
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||||
|
def _default_order_by(self, root_alias):
|
||||||
|
mapper = inspect(self.model)
|
||||||
|
cols = []
|
||||||
|
for col in mapper.primary_key:
|
||||||
|
try:
|
||||||
|
cols.append(getattr(root_alias, col.key))
|
||||||
|
except AttributeError:
|
||||||
|
cols.append(col)
|
||||||
|
return cols or [text("1")]
|
||||||
|
|
||||||
def get(self, id: int, include_deleted: bool = False) -> T | None:
|
def get(self, id: int, include_deleted: bool = False) -> T | None:
|
||||||
query, root_alias = self.get_query()
|
query, root_alias = self.get_query()
|
||||||
|
|
||||||
if self.supports_soft_delete and not include_deleted:
|
if self.supports_soft_delete and not include_deleted:
|
||||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||||
|
|
||||||
query = query.filter(getattr(root_alias, "id") == id)
|
query = query.filter(getattr(root_alias, "id") == id)
|
||||||
|
|
||||||
obj = query.first()
|
obj = query.first()
|
||||||
return obj or None
|
return obj or None
|
||||||
|
|
||||||
|
|
@ -60,9 +79,20 @@ class CRUDService(Generic[T]):
|
||||||
|
|
||||||
if filters:
|
if filters:
|
||||||
query = query.filter(*filters)
|
query = query.filter(*filters)
|
||||||
|
|
||||||
|
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
|
||||||
|
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||||||
|
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||||||
|
order_by = self._default_order_by(root_alias)
|
||||||
|
|
||||||
if order_by:
|
if order_by:
|
||||||
query = query.order_by(*order_by)
|
query = query.order_by(*order_by)
|
||||||
query = query.offset(offset).limit(limit)
|
|
||||||
|
# Only apply offset/limit when not None.
|
||||||
|
if offset is not None and offset != 0:
|
||||||
|
query = query.offset(offset)
|
||||||
|
if limit is not None:
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
return query.all()
|
return query.all()
|
||||||
|
|
||||||
|
|
@ -70,7 +100,6 @@ class CRUDService(Generic[T]):
|
||||||
obj = self.model(**data)
|
obj = self.model(**data)
|
||||||
self.session.add(obj)
|
self.session.add(obj)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
self._log_version("create", obj, actor)
|
self._log_version("create", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -78,13 +107,11 @@ class CRUDService(Generic[T]):
|
||||||
obj = self.get(id)
|
obj = self.get(id)
|
||||||
if not obj:
|
if not obj:
|
||||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||||
|
|
||||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if k in valid_fields:
|
if k in valid_fields:
|
||||||
setattr(obj, k, v)
|
setattr(obj, k, v)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
self._log_version("update", obj, actor)
|
self._log_version("update", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -92,14 +119,11 @@ class CRUDService(Generic[T]):
|
||||||
obj = self.session.get(self.model, id)
|
obj = self.session.get(self.model, id)
|
||||||
if not obj:
|
if not obj:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if hard or not self.supports_soft_delete:
|
if hard or not self.supports_soft_delete:
|
||||||
self.session.delete(obj)
|
self.session.delete(obj)
|
||||||
else:
|
else:
|
||||||
obj.is_deleted = True
|
obj.is_deleted = True
|
||||||
|
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
self._log_version("delete", obj, actor)
|
self._log_version("delete", obj, actor)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
@ -108,7 +132,6 @@ class CRUDService(Generic[T]):
|
||||||
data = obj.as_dict()
|
data = obj.as_dict()
|
||||||
except Exception:
|
except Exception:
|
||||||
data = {"error": "Failed to serialize object."}
|
data = {"error": "Failed to serialize object."}
|
||||||
|
|
||||||
version = Version(
|
version = Version(
|
||||||
model_name=self.model.__name__,
|
model_name=self.model.__name__,
|
||||||
object_id=obj.id,
|
object_id=obj.id,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||||
from typing import Type, Optional
|
from typing import Type, Optional
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from .backend import make_backend_info, BackendInfo
|
||||||
from .config import Config, get_config
|
from .config import Config, get_config
|
||||||
from ._sqlite import apply_sqlite_pragmas
|
from ._sqlite import apply_sqlite_pragmas
|
||||||
|
|
||||||
|
|
@ -41,3 +42,9 @@ class CRUDKitRuntime:
|
||||||
if self._config and self._engine:
|
if self._config and self._engine:
|
||||||
self._session_factory = build_sessionmaker(self._config, self._engine)
|
self._session_factory = build_sessionmaker(self._config, self._engine)
|
||||||
return self._session_factory
|
return self._session_factory
|
||||||
|
|
||||||
|
@property
|
||||||
|
def backend(self) -> BackendInfo:
|
||||||
|
if not hasattr(self, "_backend_info") or self._backend_info is None:
|
||||||
|
self._backend_info = make_backend_info(self.engine)
|
||||||
|
return self._backend_info
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue