Optimizations and refactoring.

This commit is contained in:
Yaro Kasear 2025-09-24 09:53:25 -05:00
parent 94837e1b6f
commit a0ee1caeb7
4 changed files with 273 additions and 85 deletions

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
@ -10,6 +12,9 @@ from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection
import logging
log = logging.getLogger("crudkit.service")
def _is_rel(model_cls, name: str) -> bool:
try:
prop = model_cls.__mapper__.relationships.get(name)
@ -56,7 +61,7 @@ class CRUDService(Generic[T]):
self.polymorphic = polymorphic
self.supports_soft_delete = hasattr(model, 'is_deleted')
# Cache backend info once. If not provided, derive from session bind.
bind = self.session.get_bind()
bind = session_factory().get_bind()
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
self.backend = backend or make_backend_info(eng)
@ -70,6 +75,11 @@ class CRUDService(Generic[T]):
return self.session.query(poly), poly
return self.session.query(self.model), self.model
def _apply_not_deleted(self, query, root_alias, params) -> Any:
    """
    Restrict *query* to rows that are not soft-deleted.

    The query is returned untouched when the model has no soft-delete
    support, or when ``params`` carries a truthy ``include_deleted`` value.

    :param query: query object to narrow.
    :param root_alias: entity whose ``is_deleted`` attribute is filtered on.
    :param params: optional request parameters (may be ``None``).
    :return: the original query, or the query with the filter applied.
    """
    if not self.supports_soft_delete:
        return query

    options = params or {}
    if _is_truthy(options.get("include_deleted")):
        return query

    # `== False` is kept on purpose: the comparison result is handed to
    # query.filter() rather than being evaluated as a Python boolean.
    not_deleted = getattr(root_alias, "is_deleted") == False  # noqa: E712
    return query.filter(not_deleted)
def _extract_order_spec(self, root_alias, given_order_by):
"""
SQLAlchemy 2.x only:
@ -85,7 +95,7 @@ class CRUDService(Generic[T]):
for ob in given:
# Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
elem = getattr(ob, "element", None)
col = elem if elem is not None else ob # don't use "or" with SA expressions
col = elem if elem is not None else ob
# Detect direction in SA 2.x
is_desc = False
@ -103,27 +113,30 @@ class CRUDService(Generic[T]):
return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
"""
Build lexicographic predicate for keyset seek.
For backward traversal, import comparisons.
"""
if not key_vals:
return None
conds = []
for i, col in enumerate(spec.cols):
# If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
# sent_col = func.coalesce(col, literal("-∞"))
sent_col = col
ties = [spec.cols[j] == key_vals[j] for j in range(i)]
is_desc = spec.desc[i]
if not backward:
op = col < key_vals[i] if is_desc else col > key_vals[i]
op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i])
else:
op = col > key_vals[i] if is_desc else col < key_vals[i]
op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i])
conds.append(and_(*ties, op))
return or_(*conds)
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
out = []
for c in spec.cols:
# Only simple mapped columns supported for key pluck
key = getattr(c, "key", None) or getattr(c, "name", None)
if key is None or not hasattr(obj, key):
raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
out.append(getattr(obj, key))
return out
@ -142,6 +155,7 @@ class CRUDService(Generic[T]):
- forward/backward seek via `key` and `backward`
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
"""
session = self.session
fields = list((params or {}).get("fields", []))
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
query, root_alias = self.get_query()
@ -156,8 +170,9 @@ class CRUDService(Generic[T]):
root_fields, rel_field_names, root_field_names = spec.parse_fields()
# Soft delete filter
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
query = query.filter(getattr(root_alias, "is_deleted") == False)
# if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
# query = query.filter(getattr(root_alias, "is_deleted") == False)
query = self._apply_not_deleted(query, root_alias, params)
# Parse filters first
if filters:
@ -165,6 +180,8 @@ class CRUDService(Generic[T]):
# Includes + joins (so relationship fields like brand.name, location.label work)
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
for _, relationship_attr, target_alias in spec.get_join_paths():
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
@ -178,8 +195,12 @@ class CRUDService(Generic[T]):
# Order + limit
order_spec = self._extract_order_spec(root_alias, order_by) # SA2-only helper
limit, _ = spec.parse_pagination()
if not limit or limit <= 0:
limit = 50 # sensible default
if limit is None:
effective_limit = 50
elif limit == 0:
effective_limit = None
else:
effective_limit = limit
# Keyset predicate
if key:
@ -189,18 +210,19 @@ class CRUDService(Generic[T]):
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
if not backward:
clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
clauses.append(col.desc() if is_desc else col.asc())
query = query.order_by(*clauses).limit(limit)
clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
query = query.order_by(*clauses)
if effective_limit is not None:
query = query.limit(effective_limit)
items = query.all()
else:
inv_clauses = []
for col, is_desc in zip(order_spec.cols, order_spec.desc):
inv_clauses.append(col.asc() if is_desc else col.desc())
query = query.order_by(*inv_clauses).limit(limit)
inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
query = query.order_by(*inv_clauses)
if effective_limit is not None:
query = query.limit(effective_limit)
items = list(reversed(query.all()))
# Tag projection so your renderer knows what fields were requested
if expanded_fields:
proj = list(expanded_fields)
@ -231,23 +253,27 @@ class CRUDService(Generic[T]):
# Optional total that's safe under JOINs (COUNT DISTINCT ids)
total = None
if include_total:
base = self.session.query(getattr(root_alias, "id"))
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
base = base.filter(getattr(root_alias, "is_deleted") == False)
base = session.query(getattr(root_alias, "id"))
base = self._apply_not_deleted(base, root_alias, params)
if filters:
base = base.filter(*filters)
# replicate the same joins used above
for _, relationship_attr, target_alias in spec.get_join_paths():
for _, relationship_attr, target_alias in join_paths: # reuse
rel_attr = cast(InstrumentedAttribute, relationship_attr)
target = cast(Any, target_alias)
base = base.join(target, rel_attr.of_type(target), isouter=True)
total = self.session.query(func.count()).select_from(base.order_by(None).distinct().subquery()).scalar() or 0
print(f"!!! QUERY !!! -> {str(query)}")
total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery()
).scalar() or 0
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50)
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
from crudkit.core.types import SeekWindow # avoid circulars at module top
return SeekWindow(
items=items,
limit=limit,
limit=window_limit_for_body,
first_key=first_key,
last_key=last_key,
order=order_spec,
@ -342,7 +368,9 @@ class CRUDService(Generic[T]):
except Exception:
pass
print(f"!!! QUERY !!! -> {str(query)}")
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return obj or None
def list(self, params=None) -> list[T]:
@ -422,42 +450,51 @@ class CRUDService(Generic[T]):
except Exception:
pass
print(f"!!! QUERY !!! -> {str(query)}")
if log.isEnabledFor(logging.DEBUG):
log.debug("QUERY: %s", str(query))
return rows
def create(self, data: dict, actor=None) -> T:
    """
    Create and persist a new model instance.

    The instance is built from ``data`` as keyword arguments, added to the
    session and committed exactly once; a "create" version entry is then
    recorded through ``_log_version``.

    :param data: keyword arguments passed to the model constructor.
    :param actor: optional originator of the change, forwarded to the
        version log.
    :return: the newly created instance.
    """
    # Resolve the session once and reuse it for the whole operation.
    session = self.session
    obj = self.model(**data)
    session.add(obj)
    session.commit()
    self._log_version("create", obj, actor)
    return obj
def update(self, id: int, data: dict, actor=None) -> T:
    """
    Apply ``data`` to the existing row identified by ``id``.

    Only names of columns on ``self.model.__table__`` are accepted. The
    whole update is rejected before any attribute is touched if ``data``
    contains anything else. The change is committed once and an "update"
    version entry is recorded through ``_log_version``.

    :param id: identifier passed to ``self.get``.
    :param data: mapping of column name to new value.
    :param actor: optional originator of the change, forwarded to the
        version log.
    :return: the updated instance.
    :raises ValueError: if no row is found, or ``data`` contains keys that
        are not column names.
    """
    # Resolve the session once and reuse it for the whole operation.
    session = self.session
    obj = self.get(id)
    if not obj:
        raise ValueError(f"{self.model.__name__} with ID {id} not found.")

    valid_fields = {c.name for c in self.model.__table__.columns}
    unknown = set(data) - valid_fields
    if unknown:
        raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")

    # Every key is a known column at this point, so no per-key check is needed.
    for field, value in data.items():
        setattr(obj, field, value)

    session.commit()
    self._log_version("update", obj, actor)
    return obj
def delete(self, id: int, hard: bool = False, actor=None):
    """
    Delete the row identified by ``id``.

    The row is removed from the database when ``hard`` is true or the model
    has no soft-delete support; otherwise its ``is_deleted`` flag is set to
    ``True``. Either way the change is committed once and a "delete"
    version entry is recorded through ``_log_version``.

    :param id: primary key passed to ``session.get``.
    :param hard: force a physical delete even if soft delete is supported.
    :param actor: optional originator of the change, forwarded to the
        version log. Defaults to ``None``, matching ``create``/``update``.
    :return: the affected instance, or ``None`` if no row was found.
    """
    # Resolve the session once and reuse it for the whole operation.
    session = self.session
    obj = session.get(self.model, id)
    if not obj:
        return None

    if hard or not self.supports_soft_delete:
        session.delete(obj)
    else:
        # The cast only informs the type checker that `is_deleted` exists.
        soft = cast(_SoftDeletable, obj)
        soft.is_deleted = True

    session.commit()
    self._log_version("delete", obj, actor)
    return obj
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict = {}):
def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
session = self.session
try:
data = obj.as_dict()
except Exception:
@ -470,5 +507,5 @@ class CRUDService(Generic[T]):
actor=str(actor) if actor else None,
meta=metadata
)
self.session.add(version)
self.session.commit()
session.add(version)
session.commit()