Updated CRUDkit
This commit is contained in:
parent
f1fa1f2407
commit
571583bcf4
6 changed files with 186 additions and 24 deletions
|
|
@ -1,7 +1,9 @@
|
|||
from typing import Type, TypeVar, Generic
|
||||
from typing import Type, TypeVar, Generic, Optional
|
||||
from sqlalchemy.orm import Session, with_polymorphic
|
||||
from sqlalchemy import inspect, text
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
from crudkit.backend import BackendInfo, make_backend_info
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -9,11 +11,20 @@ def _is_truthy(val):
|
|||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||||
|
||||
class CRUDService(Generic[T]):
|
||||
def __init__(self, model: Type[T], session: Session, polymorphic: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
model: Type[T],
|
||||
session: Session,
|
||||
polymorphic: bool = False,
|
||||
*,
|
||||
backend: Optional[BackendInfo] = None
|
||||
):
|
||||
self.model = model
|
||||
self.session = session
|
||||
self.polymorphic = polymorphic
|
||||
self.supports_soft_delete = hasattr(model, 'is_deleted')
|
||||
# Cache backend info once. If not provided, derive from session bind.
|
||||
self.backend = backend or make_backend_info(self.session.get_bind())
|
||||
|
||||
def get_query(self):
|
||||
if self.polymorphic:
|
||||
|
|
@ -23,14 +34,22 @@ class CRUDService(Generic[T]):
|
|||
base_only = with_polymorphic(self.model, [], flat=True)
|
||||
return self.session.query(base_only), base_only
|
||||
|
||||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||
def _default_order_by(self, root_alias):
|
||||
mapper = inspect(self.model)
|
||||
cols = []
|
||||
for col in mapper.primary_key:
|
||||
try:
|
||||
cols.append(getattr(root_alias, col.key))
|
||||
except AttributeError:
|
||||
cols.append(col)
|
||||
return cols or [text("1")]
|
||||
|
||||
def get(self, id: int, include_deleted: bool = False) -> T | None:
|
||||
query, root_alias = self.get_query()
|
||||
|
||||
if self.supports_soft_delete and not include_deleted:
|
||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||
|
||||
query = query.filter(getattr(root_alias, "id") == id)
|
||||
|
||||
obj = query.first()
|
||||
return obj or None
|
||||
|
||||
|
|
@ -60,9 +79,20 @@ class CRUDService(Generic[T]):
|
|||
|
||||
if filters:
|
||||
query = query.filter(*filters)
|
||||
|
||||
# MSSQL: requires ORDER BY when using OFFSET (and SQLA will use OFFSET for limit+offset).
|
||||
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||||
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||||
order_by = self._default_order_by(root_alias)
|
||||
|
||||
if order_by:
|
||||
query = query.order_by(*order_by)
|
||||
query = query.offset(offset).limit(limit)
|
||||
|
||||
# Only apply offset/limit when not None.
|
||||
if offset is not None and offset != 0:
|
||||
query = query.offset(offset)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
|
|
@ -70,7 +100,6 @@ class CRUDService(Generic[T]):
|
|||
obj = self.model(**data)
|
||||
self.session.add(obj)
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("create", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -78,13 +107,11 @@ class CRUDService(Generic[T]):
|
|||
obj = self.get(id)
|
||||
if not obj:
|
||||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||||
|
||||
valid_fields = {c.name for c in self.model.__table__.columns}
|
||||
for k, v in data.items():
|
||||
if k in valid_fields:
|
||||
setattr(obj, k, v)
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("update", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -92,14 +119,11 @@ class CRUDService(Generic[T]):
|
|||
obj = self.session.get(self.model, id)
|
||||
if not obj:
|
||||
return None
|
||||
|
||||
if hard or not self.supports_soft_delete:
|
||||
self.session.delete(obj)
|
||||
else:
|
||||
obj.is_deleted = True
|
||||
|
||||
self.session.commit()
|
||||
|
||||
self._log_version("delete", obj, actor)
|
||||
return obj
|
||||
|
||||
|
|
@ -108,7 +132,6 @@ class CRUDService(Generic[T]):
|
|||
data = obj.as_dict()
|
||||
except Exception:
|
||||
data = {"error": "Failed to serialize object."}
|
||||
|
||||
version = Version(
|
||||
model_name=self.model.__name__,
|
||||
object_id=obj.id,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue