748 lines
30 KiB
Python
748 lines
30 KiB
Python
from __future__ import annotations
|
||
|
||
from collections.abc import Iterable
|
||
from flask import current_app
|
||
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
|
||
from sqlalchemy import and_, func, inspect, or_, text
|
||
from sqlalchemy.engine import Engine, Connection
|
||
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, contains_eager, selectinload
|
||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||
from sqlalchemy.sql import operators
|
||
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
|
||
|
||
from crudkit.core.base import Version
|
||
from crudkit.core.spec import CRUDSpec
|
||
from crudkit.core.types import OrderSpec, SeekWindow
|
||
from crudkit.backend import BackendInfo, make_backend_info
|
||
from crudkit.projection import compile_projection
|
||
|
||
import logging
|
||
log = logging.getLogger("crudkit.service")
|
||
|
||
@runtime_checkable
class _HasID(Protocol):
    """Structural type: any object exposing an integer primary-key ``id``."""
    # Integer surrogate key; read by _log_version and the cursor helpers.
    id: int
|
||
|
||
@runtime_checkable
class _HasTable(Protocol):
    """Structural type: mapped classes exposing a SQLAlchemy ``__table__``."""
    # Used by update() to enumerate valid column names.
    __table__: Any
|
||
|
||
@runtime_checkable
class _HasADict(Protocol):
    """Structural type: objects serializable to a plain dict via ``as_dict()``."""
    def as_dict(self) -> dict: ...
|
||
|
||
@runtime_checkable
class _SoftDeletable(Protocol):
    """Structural type: models supporting soft delete via an ``is_deleted`` flag."""
    is_deleted: bool
|
||
|
||
# Composite protocol: id + __table__ + as_dict(). Soft delete is intentionally
# NOT required here; CRUDService checks for `is_deleted` dynamically.
class _CRUDModelProto(_HasID, _HasTable, _HasADict, Protocol):
    """Minimal surface that our CRUD service relies on. Soft-delete is optional."""
    pass
|
||
|
||
T = TypeVar("T", bound=_CRUDModelProto)
|
||
|
||
def _hops_from_sort(params: dict | None) -> set[str]:
|
||
"""Extract first-hop relationship names from a sort spec like 'owner.first_name,-brand.name'."""
|
||
if not params:
|
||
return set()
|
||
raw = params.get("sort")
|
||
tokens: list[str] = []
|
||
if isinstance(raw, str):
|
||
tokens = [t.strip() for t in raw.split(",") if t.strip()]
|
||
elif isinstance(raw, (list, tuple)):
|
||
for item in raw:
|
||
if isinstance(item, str):
|
||
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
|
||
hops: set[str] = set()
|
||
for tok in tokens:
|
||
tok = tok.lstrip("+-")
|
||
if "." in tok:
|
||
hops.add(tok.split(".", 1)[0])
|
||
return hops
|
||
|
||
def _belongs_to_alias(col: Any, alias: Any) -> bool:
|
||
# Try to detect if a column/expression ultimately comes from this alias.
|
||
# Works for most ORM columns; complex expressions may need more.
|
||
t = getattr(col, "table", None)
|
||
selectable = getattr(alias, "selectable", None)
|
||
return t is not None and selectable is not None and t is selectable
|
||
|
||
def _paths_needed_for_sql(order_by: Iterable[Any], filters: Iterable[Any], join_paths: tuple) -> set[str]:
    """Return relationship names whose aliases are referenced by ORDER BY or WHERE.

    join_paths is an iterable of (path, relationship_attr, target_alias); a
    relationship is "SQL-needed" when a sort or filter column belongs to its
    target alias.
    """

    def _match_hops(col: Any) -> Iterable[str]:
        # Yield the relationship key of every join path whose alias owns *col*.
        for _path, rel_attr, target_alias in join_paths:
            if _belongs_to_alias(col, target_alias):
                yield rel_attr.key

    def _walk(expr: Any) -> Iterable[Any]:
        # Depth-first walk over column expressions and clause containers.
        if isinstance(expr, ColumnElement):
            yield expr
            for child in getattr(expr, "get_children", lambda: [])():
                yield from _walk(child)
        elif hasattr(expr, "clauses"):
            for child in expr.clauses:
                yield from _walk(child)

    hops: set[str] = set()

    # Sort columns: unwrap UnaryExpression (.asc()/.desc()) and match directly.
    for ob in order_by or []:
        hops.update(_match_hops(getattr(ob, "element", ob)))

    # Filter columns (best-effort): walk each expression tree.
    for flt in filters or []:
        for col in _walk(flt):
            hops.update(_match_hops(col))

    return hops
|
||
|
||
def _paths_from_fields(req_fields: list[str]) -> set[str]:
|
||
out: set[str] = set()
|
||
for f in req_fields:
|
||
if "." in f:
|
||
parent = f.split(".", 1)[0]
|
||
if parent:
|
||
out.add(parent)
|
||
return out
|
||
|
||
def _is_truthy(val):
|
||
return str(val).lower() in ('1', 'true', 'yes', 'on')
|
||
|
||
def _normalize_fields_param(params: dict | None) -> list[str]:
|
||
if not params:
|
||
return []
|
||
raw = params.get("fields")
|
||
if isinstance(raw, (list, tuple)):
|
||
out: list[str] = []
|
||
for item in raw:
|
||
if isinstance(item, str):
|
||
out.extend([p for p in (s.strip() for s in item.split(",")) if p])
|
||
return out
|
||
if isinstance(raw, str):
|
||
return [p for p in (s.strip() for s in raw.split(",")) if p]
|
||
return []
|
||
|
||
class CRUDService(Generic[T]):
    """Generic CRUD/query service over a single SQLAlchemy-mapped model.

    Supports offset/limit listing, keyset (seek) pagination, single-row fetch
    with projection and eager loading, create/update/delete with optional
    soft delete, and version logging via `_log_version`.
    """

    def __init__(
        self,
        model: Type[T],
        session_factory: Callable[[], Session],
        polymorphic: bool = False,
        *,
        backend: Optional[BackendInfo] = None
    ):
        """Store the model, session factory, and query options.

        Args:
            model: Mapped class this service operates on.
            session_factory: Zero-arg callable returning a Session; used as a
                fallback when the Flask-scoped session is unavailable.
            polymorphic: Query via ``with_polymorphic(model, "*")`` when True.
            backend: Pre-resolved backend info; resolved lazily from the
                active engine when omitted.
        """
        self.model = model
        self._session_factory = session_factory
        self.polymorphic = polymorphic
        # Soft delete is supported iff the model exposes an `is_deleted` flag.
        self.supports_soft_delete = hasattr(model, 'is_deleted')

        self._backend: Optional[BackendInfo] = backend
|
||
|
||
@property
|
||
def session(self) -> Session:
|
||
"""Always return the Flask-scoped Session if available; otherwise the provided factory."""
|
||
try:
|
||
sess = current_app.extensions["crudkit"]["Session"]
|
||
return sess
|
||
except Exception:
|
||
return self._session_factory()
|
||
|
||
@property
|
||
def backend(self) -> BackendInfo:
|
||
"""Resolve backend info lazily against the active session's engine."""
|
||
if self._backend is None:
|
||
bind = self.session.get_bind()
|
||
eng: Engine = bind.engine if isinstance(bind, Connection) else cast(Engine, bind)
|
||
self._backend = make_backend_info(eng)
|
||
return self._backend
|
||
|
||
def get_query(self):
|
||
if self.polymorphic:
|
||
poly = with_polymorphic(self.model, "*")
|
||
return self.session.query(poly), poly
|
||
return self.session.query(self.model), self.model
|
||
|
||
def _debug_bind(self, where: str):
|
||
try:
|
||
bind = self.session.get_bind()
|
||
eng = getattr(bind, "engine", bind)
|
||
print(f"SERVICE BIND [{where}]: engine_id={id(eng)} url={getattr(eng, 'url', '?')} session={type(self.session).__name__}")
|
||
except Exception as e:
|
||
print(f"SERVICE BIND [{where}]: failed to introspect bind: {e}")
|
||
|
||
def _apply_not_deleted(self, query, root_alias, params) -> Any:
|
||
if self.supports_soft_delete and not _is_truthy((params or {}).get("include_deleted")):
|
||
return query.filter(getattr(root_alias, "is_deleted") == False)
|
||
return query
|
||
|
||
    def _extract_order_spec(self, root_alias, given_order_by):
        """
        SQLAlchemy 2.x only:
        Normalize order_by into (cols, desc_flags). Supports plain columns and
        col.asc()/col.desc() (UnaryExpression). Avoids boolean evaluation of clauses.
        """

        # Append PK tiebreakers (or fall back to PK order) for determinism.
        given = self._stable_order_by(root_alias, given_order_by)

        cols, desc_flags = [], []
        for ob in given:
            # Unwrap column if this is a UnaryExpression produced by .asc()/.desc()
            elem = getattr(ob, "element", None)
            col = elem if elem is not None else ob

            # Detect direction in SA 2.x
            # NOTE(review): `_direction` is a private SA attribute — compare by
            # operator identity first, then by symbol name as a fallback.
            is_desc = False
            dir_attr = getattr(ob, "_direction", None)
            if dir_attr is not None:
                is_desc = (dir_attr is operators.desc_op) or (getattr(dir_attr, "name", "").upper() == "DESC")
            elif isinstance(ob, UnaryExpression):
                op = getattr(ob, "operator", None)
                is_desc = (op is operators.desc_op) or (getattr(op, "name", "").upper() == "DESC")

            cols.append(col)
            desc_flags.append(bool(is_desc))

        return OrderSpec(cols=tuple(cols), desc=tuple(desc_flags))
|
||
|
||
    def _key_predicate(self, spec: OrderSpec, key_vals: list[Any], backward: bool):
        """Build the keyset (seek) WHERE predicate for cursor pagination.

        For order columns (c0, c1, ...) and cursor values (v0, v1, ...) this
        produces the standard row-value expansion:
            (c0 > v0) OR (c0 = v0 AND c1 > v1) OR ...
        with each strict comparison flipped for DESC columns and again when
        paging backward. Returns None when no cursor values are given.
        """
        if not key_vals:
            return None

        conds = []
        for i, col in enumerate(spec.cols):
            # If NULLs possible, normalize for comparison. Example using coalesce to a sentinel:
            # sent_col = func.coalesce(col, literal("-∞"))
            sent_col = col
            # Equality on all higher-priority columns ("ties"), strict
            # comparison on the current one.
            ties = [spec.cols[j] == key_vals[j] for j in range(i)]
            is_desc = spec.desc[i]
            if not backward:
                op = (sent_col < key_vals[i]) if is_desc else (sent_col > key_vals[i])
            else:
                op = (sent_col > key_vals[i]) if is_desc else (sent_col < key_vals[i])
            conds.append(and_(*ties, op))
        return or_(*conds)
|
||
|
||
def _pluck_key(self, obj: Any, spec: OrderSpec) -> list[Any]:
|
||
out = []
|
||
for c in spec.cols:
|
||
# Only simple mapped columns supported for key pluck
|
||
key = getattr(c, "key", None) or getattr(c, "name", None)
|
||
if key is None or not hasattr(obj, key):
|
||
raise ValueError("Order includes non-mapped or related column; cannot pluck cursor key from row object.")
|
||
out.append(getattr(obj, key))
|
||
return out
|
||
|
||
def seek_window(
|
||
self,
|
||
params: dict | None = None,
|
||
*,
|
||
key: list[Any] | None = None,
|
||
backward: bool = False,
|
||
include_total: bool = True,
|
||
) -> "SeekWindow[T]":
|
||
"""
|
||
Transport-agnostic keyset pagination that preserves all the goodies from `list()`:
|
||
- filters, includes, joins, field projection, eager loading, soft-delete
|
||
- deterministic ordering (user sort + PK tiebreakers)
|
||
- forward/backward seek via `key` and `backward`
|
||
Returns a SeekWindow with items, first/last keys, order spec, limit, and optional total.
|
||
"""
|
||
session = self.session
|
||
query, root_alias = self.get_query()
|
||
|
||
# Normalize requested fields and compile projection (may skip later to avoid conflicts)
|
||
fields = _normalize_fields_param(params)
|
||
expanded_fields, proj_opts = compile_projection(self.model, fields) if fields else ([], [])
|
||
|
||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||
|
||
filters = spec.parse_filters()
|
||
order_by = spec.parse_sort()
|
||
|
||
# Field parsing for root load_only fallback
|
||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||
|
||
# Soft delete filter
|
||
query = self._apply_not_deleted(query, root_alias, params)
|
||
|
||
# Apply filters first
|
||
if filters:
|
||
query = query.filter(*filters)
|
||
|
||
# Includes + join paths (dotted fields etc.)
|
||
spec.parse_includes()
|
||
join_paths = tuple(spec.get_join_paths()) # iterable of (path, relationship_attr, target_alias)
|
||
|
||
# Relationship names required by ORDER BY / WHERE
|
||
sql_hops: set[str] = _paths_needed_for_sql(order_by, filters, join_paths)
|
||
# Also include relationships mentioned directly in the sort spec
|
||
sql_hops |= _hops_from_sort(params)
|
||
|
||
# First-hop relationship names implied by dotted projection fields
|
||
proj_hops: set[str] = _paths_from_fields(fields)
|
||
|
||
# Root column projection
|
||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||
if only_cols:
|
||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||
|
||
# Relationship handling per path (avoid loader strategy conflicts)
|
||
used_contains_eager = False
|
||
joined_names: set[str] = set()
|
||
|
||
for _path, relationship_attr, target_alias in join_paths:
|
||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||
name = relationship_attr.key
|
||
if name in sql_hops:
|
||
# Needed for WHERE/ORDER BY: join + hydrate from that join
|
||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||
used_contains_eager = True
|
||
joined_names.add(name)
|
||
elif name in proj_hops:
|
||
# Display-only: bulk-load efficiently, no join
|
||
query = query.options(selectinload(rel_attr))
|
||
joined_names.add(name)
|
||
|
||
# Force-join any SQL-needed relationships that weren't in join_paths
|
||
missing_sql = sql_hops - joined_names
|
||
for name in missing_sql:
|
||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||
query = query.join(rel_attr, isouter=True)
|
||
query = query.options(contains_eager(rel_attr))
|
||
used_contains_eager = True
|
||
joined_names.add(name)
|
||
|
||
# Apply projection loader options only if they won't conflict with contains_eager
|
||
if proj_opts and not used_contains_eager:
|
||
query = query.options(*proj_opts)
|
||
|
||
# Order + limit
|
||
order_spec = self._extract_order_spec(root_alias, order_by) # SA 2.x helper
|
||
limit, _ = spec.parse_pagination()
|
||
if limit is None:
|
||
effective_limit = 50
|
||
elif limit == 0:
|
||
effective_limit = None # unlimited
|
||
else:
|
||
effective_limit = limit
|
||
|
||
# Keyset predicate
|
||
if key:
|
||
pred = self._key_predicate(order_spec, key, backward)
|
||
if pred is not None:
|
||
query = query.filter(pred)
|
||
|
||
# Apply ordering. For backward, invert SQL order then reverse in-memory for display.
|
||
if not backward:
|
||
clauses = [(c.desc() if is_desc else c.asc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
|
||
query = query.order_by(*clauses)
|
||
if effective_limit is not None:
|
||
query = query.limit(effective_limit)
|
||
items = query.all()
|
||
else:
|
||
inv_clauses = [(c.asc() if is_desc else c.desc()) for c, is_desc in zip(order_spec.cols, order_spec.desc)]
|
||
query = query.order_by(*inv_clauses)
|
||
if effective_limit is not None:
|
||
query = query.limit(effective_limit)
|
||
items = list(reversed(query.all()))
|
||
|
||
# Tag projection so your renderer knows what fields were requested
|
||
if fields:
|
||
proj = list(dict.fromkeys(fields)) # dedupe, preserve order
|
||
if "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
else:
|
||
proj = []
|
||
if root_field_names:
|
||
proj.extend(root_field_names)
|
||
if root_fields:
|
||
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||
for path, names in (rel_field_names or {}).items():
|
||
prefix = ".".join(path)
|
||
for n in names:
|
||
proj.append(f"{prefix}.{n}")
|
||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
|
||
if proj:
|
||
for obj in items:
|
||
try:
|
||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||
except Exception:
|
||
pass
|
||
|
||
# Boundary keys for cursor encoding in the API layer
|
||
# When ORDER BY includes related columns (e.g., owner.first_name),
|
||
# pluck values from the related object we hydrated with contains_eager/selectinload.
|
||
def _pluck_key_from_obj(obj: Any) -> list[Any]:
|
||
vals: list[Any] = []
|
||
# Build a quick map: selectable -> relationship name
|
||
alias_to_rel: dict[Any, str] = {}
|
||
for _p, relationship_attr, target_alias in join_paths:
|
||
sel = getattr(target_alias, "selectable", None)
|
||
if sel is not None:
|
||
alias_to_rel[sel] = relationship_attr.key
|
||
|
||
for col in order_spec.cols:
|
||
key = getattr(col, "key", None) or getattr(col, "name", None)
|
||
# Try root attribute first
|
||
if key and hasattr(obj, key):
|
||
vals.append(getattr(obj, key))
|
||
continue
|
||
# Try relationship hop by matching the column's table/selectable
|
||
table = getattr(col, "table", None)
|
||
relname = alias_to_rel.get(table)
|
||
if relname and key:
|
||
relobj = getattr(obj, relname, None)
|
||
if relobj is not None and hasattr(relobj, key):
|
||
vals.append(getattr(relobj, key))
|
||
continue
|
||
# Give up: unsupported expression for cursor purposes
|
||
raise ValueError("unpluckable")
|
||
return vals
|
||
|
||
try:
|
||
first_key = _pluck_key_from_obj(items[0]) if items else None
|
||
last_key = _pluck_key_from_obj(items[-1]) if items else None
|
||
except Exception:
|
||
# If we can't derive cursor keys (e.g., ORDER BY expression/aggregate),
|
||
# disable cursors for this response rather than exploding.
|
||
first_key = None
|
||
last_key = None
|
||
|
||
# Optional total that’s safe under JOINs (COUNT DISTINCT ids)
|
||
total = None
|
||
if include_total:
|
||
base = session.query(getattr(root_alias, "id"))
|
||
base = self._apply_not_deleted(base, root_alias, params)
|
||
if filters:
|
||
base = base.filter(*filters)
|
||
# Mirror join structure for any SQL-needed relationships
|
||
for _path, relationship_attr, target_alias in join_paths:
|
||
if relationship_attr.key in sql_hops:
|
||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||
# Also mirror any forced joins
|
||
for name in (sql_hops - {ra.key for _p, ra, _a in join_paths}):
|
||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||
base = base.join(rel_attr, isouter=True)
|
||
|
||
total = session.query(func.count()).select_from(
|
||
base.order_by(None).distinct().subquery()
|
||
).scalar() or 0
|
||
|
||
window_limit_for_body = 0 if effective_limit is None and limit == 0 else (effective_limit or 50)
|
||
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug("QUERY: %s", str(query))
|
||
|
||
return SeekWindow(
|
||
items=items,
|
||
limit=window_limit_for_body,
|
||
first_key=first_key,
|
||
last_key=last_key,
|
||
order=order_spec,
|
||
total=total,
|
||
)
|
||
|
||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||
def _default_order_by(self, root_alias):
|
||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||
cols = []
|
||
for col in mapper.primary_key:
|
||
try:
|
||
cols.append(getattr(root_alias, col.key))
|
||
except AttributeError:
|
||
cols.append(col)
|
||
return cols or [text("1")]
|
||
|
||
def _stable_order_by(self, root_alias, given_order_by):
|
||
"""
|
||
Ensure deterministic ordering by appending PK columns as tiebreakers.
|
||
If no order is provided, fall back to default primary-key order.
|
||
"""
|
||
order_by = list(given_order_by or [])
|
||
if not order_by:
|
||
return self._default_order_by(root_alias)
|
||
|
||
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
|
||
pk_cols = []
|
||
for col in mapper.primary_key:
|
||
try:
|
||
pk_cols.append(getattr(root_alias, col.key))
|
||
except AttributeError:
|
||
pk_cols.append(col)
|
||
|
||
return [*order_by, *pk_cols]
|
||
|
||
def get(self, id: int, params=None) -> T | None:
|
||
"""Fetch a single row by id with conflict-free eager loading and clean projection."""
|
||
query, root_alias = self.get_query()
|
||
|
||
# Defaults so we can build a projection even if params is None
|
||
root_fields: list[Any] = []
|
||
root_field_names: dict[str, str] = {}
|
||
rel_field_names: dict[tuple[str, ...], list[str]] = {}
|
||
req_fields: list[str] = _normalize_fields_param(params)
|
||
|
||
# Soft-delete guard
|
||
query = self._apply_not_deleted(query, root_alias, params)
|
||
|
||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||
|
||
# Optional extra filters (in addition to id); keep parity with list()
|
||
filters = spec.parse_filters()
|
||
if filters:
|
||
query = query.filter(*filters)
|
||
|
||
# Always filter by id
|
||
query = query.filter(getattr(root_alias, "id") == id)
|
||
|
||
# Includes + join paths we may need
|
||
spec.parse_includes()
|
||
join_paths = tuple(spec.get_join_paths())
|
||
|
||
# Field parsing to enable root load_only
|
||
if params:
|
||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||
|
||
# Decide which relationship paths are needed for SQL vs display-only
|
||
# For get(), there is no ORDER BY; only filters might force SQL use.
|
||
sql_hops = _paths_needed_for_sql(order_by=None, filters=filters, join_paths=join_paths)
|
||
proj_hops = _paths_from_fields(req_fields)
|
||
|
||
# Root column projection
|
||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||
if only_cols:
|
||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||
|
||
# Relationship handling per path: avoid loader strategy conflicts
|
||
used_contains_eager = False
|
||
for _path, relationship_attr, target_alias in join_paths:
|
||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||
name = relationship_attr.key
|
||
if name in sql_hops:
|
||
# Needed in WHERE: join + hydrate from the join
|
||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||
used_contains_eager = True
|
||
elif name in proj_hops:
|
||
# Display-only: bulk-load efficiently
|
||
query = query.options(selectinload(rel_attr))
|
||
else:
|
||
pass
|
||
|
||
# Projection loader options compiled from requested fields.
|
||
# Skip if we used contains_eager to avoid strategy conflicts.
|
||
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||
if proj_opts and not used_contains_eager:
|
||
query = query.options(*proj_opts)
|
||
|
||
obj = query.first()
|
||
|
||
# Emit exactly what the client requested (plus id), or a reasonable fallback
|
||
if req_fields:
|
||
proj = list(dict.fromkeys(req_fields)) # dedupe, preserve order
|
||
if "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
else:
|
||
proj = []
|
||
if root_field_names:
|
||
proj.extend(root_field_names)
|
||
if root_fields:
|
||
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||
for path, names in (rel_field_names or {}).items():
|
||
prefix = ".".join(path)
|
||
for n in names:
|
||
proj.append(f"{prefix}.{n}")
|
||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
|
||
if proj and obj is not None:
|
||
try:
|
||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||
except Exception:
|
||
pass
|
||
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug("QUERY: %s", str(query))
|
||
|
||
return obj or None
|
||
|
||
def list(self, params=None) -> list[T]:
|
||
"""Offset/limit listing with smart relationship loading and clean projection."""
|
||
query, root_alias = self.get_query()
|
||
|
||
# Defaults so we can reference them later even if params is None
|
||
root_fields: list[Any] = []
|
||
root_field_names: dict[str, str] = {}
|
||
rel_field_names: dict[tuple[str, ...], list[str]] = {}
|
||
req_fields: list[str] = _normalize_fields_param(params)
|
||
|
||
if params:
|
||
query = self._apply_not_deleted(query, root_alias, params)
|
||
|
||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||
filters = spec.parse_filters()
|
||
order_by = spec.parse_sort()
|
||
limit, offset = spec.parse_pagination()
|
||
|
||
# Includes + join paths we might need
|
||
spec.parse_includes()
|
||
join_paths = tuple(spec.get_join_paths())
|
||
|
||
# Field parsing for load_only on root columns
|
||
root_fields, rel_field_names, root_field_names = spec.parse_fields()
|
||
|
||
if filters:
|
||
query = query.filter(*filters)
|
||
|
||
# Determine which relationship paths are needed for SQL vs display-only
|
||
sql_hops = _paths_needed_for_sql(order_by, filters, join_paths)
|
||
sql_hops |= _hops_from_sort(params) # ensure sort-driven joins exist
|
||
proj_hops = _paths_from_fields(req_fields)
|
||
|
||
# Root column projection
|
||
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
|
||
if only_cols:
|
||
query = query.options(Load(root_alias).load_only(*only_cols))
|
||
|
||
# Relationship handling per path
|
||
used_contains_eager = False
|
||
joined_names: set[str] = set()
|
||
|
||
for _path, relationship_attr, target_alias in join_paths:
|
||
rel_attr = cast(InstrumentedAttribute, relationship_attr)
|
||
name = relationship_attr.key
|
||
if name in sql_hops:
|
||
# Needed for WHERE/ORDER BY: join + hydrate from the join
|
||
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
|
||
query = query.options(contains_eager(rel_attr, alias=target_alias))
|
||
used_contains_eager = True
|
||
joined_names.add(name)
|
||
elif name in proj_hops:
|
||
# Display-only: no join, bulk-load efficiently
|
||
query = query.options(selectinload(rel_attr))
|
||
joined_names.add(name)
|
||
|
||
# Force-join any SQL-needed relationships that weren't in join_paths
|
||
missing_sql = sql_hops - joined_names
|
||
for name in missing_sql:
|
||
rel_attr = cast(InstrumentedAttribute, getattr(root_alias, name))
|
||
query = query.join(rel_attr, isouter=True)
|
||
query = query.options(contains_eager(rel_attr))
|
||
used_contains_eager = True
|
||
joined_names.add(name)
|
||
|
||
# MSSQL requires ORDER BY when OFFSET is used (SQLA uses OFFSET for limit/offset)
|
||
paginating = (limit is not None) or (offset is not None and offset != 0)
|
||
if paginating and not order_by and self.backend.requires_order_by_for_offset:
|
||
order_by = self._default_order_by(root_alias)
|
||
|
||
if order_by:
|
||
query = query.order_by(*order_by)
|
||
|
||
# Only apply offset/limit when not None and not zero
|
||
if offset is not None and offset != 0:
|
||
query = query.offset(offset)
|
||
if limit is not None and limit > 0:
|
||
query = query.limit(limit)
|
||
|
||
# Projection loader options compiled from requested fields.
|
||
# Skip if we used contains_eager to avoid loader-strategy conflicts.
|
||
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||
if proj_opts and not used_contains_eager:
|
||
query = query.options(*proj_opts)
|
||
|
||
else:
|
||
# No params means no filters/sorts/limits; still honor projection loaders if any
|
||
expanded_fields, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
|
||
if proj_opts:
|
||
query = query.options(*proj_opts)
|
||
|
||
rows = query.all()
|
||
|
||
# Emit exactly what the client requested (plus id), or a reasonable fallback
|
||
if req_fields:
|
||
proj = list(dict.fromkeys(req_fields)) # dedupe while preserving order
|
||
if "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
else:
|
||
proj = []
|
||
if root_field_names:
|
||
proj.extend(root_field_names)
|
||
if root_fields:
|
||
proj.extend(c.key for c in root_fields if hasattr(c, "key"))
|
||
for path, names in (rel_field_names or {}).items():
|
||
prefix = ".".join(path)
|
||
for n in names:
|
||
proj.append(f"{prefix}.{n}")
|
||
if proj and "id" not in proj and hasattr(self.model, "id"):
|
||
proj.insert(0, "id")
|
||
|
||
if proj:
|
||
for obj in rows:
|
||
try:
|
||
setattr(obj, "__crudkit_projection__", tuple(proj))
|
||
except Exception:
|
||
pass
|
||
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug("QUERY: %s", str(query))
|
||
|
||
return rows
|
||
|
||
|
||
    def create(self, data: dict, actor=None) -> T:
        """Insert a new row built from *data*, commit, and record a version entry.

        NOTE(review): unlike update(), keys are not validated here — they are
        passed straight to the model constructor.
        """
        session = self.session
        obj = self.model(**data)
        session.add(obj)
        session.commit()
        # Audit trail entry (performs its own commit).
        self._log_version("create", obj, actor)
        return obj
|
||
|
||
def update(self, id: int, data: dict, actor=None) -> T:
|
||
session = self.session
|
||
obj = session.get(self.model, id)
|
||
if not obj:
|
||
raise ValueError(f"{self.model.__name__} with ID {id} not found.")
|
||
valid_fields = {c.name for c in self.model.__table__.columns}
|
||
unknown = set(data) - valid_fields
|
||
if unknown:
|
||
raise ValueError(f"Unknown fields: {', '.join(sorted(unknown))}")
|
||
for k, v in data.items():
|
||
if k in valid_fields:
|
||
setattr(obj, k, v)
|
||
session.commit()
|
||
self._log_version("update", obj, actor)
|
||
return obj
|
||
|
||
def delete(self, id: int, hard: bool = False, actor = None):
|
||
session = self.session
|
||
obj = session.get(self.model, id)
|
||
if not obj:
|
||
return None
|
||
if hard or not self.supports_soft_delete:
|
||
session.delete(obj)
|
||
else:
|
||
soft = cast(_SoftDeletable, obj)
|
||
soft.is_deleted = True
|
||
session.commit()
|
||
self._log_version("delete", obj, actor)
|
||
return obj
|
||
|
||
    def _log_version(self, change_type: str, obj: T, actor=None, metadata: dict | None = None):
        """Append a Version audit row capturing *obj*'s state after *change_type*.

        Serialization is best-effort: if as_dict() fails for any reason, a
        placeholder payload is stored so auditing never blocks the operation.
        Commits the audit row in its own transaction step.
        """
        session = self.session
        try:
            data = obj.as_dict()
        except Exception:
            data = {"error": "Failed to serialize object."}
        version = Version(
            model_name=self.model.__name__,
            object_id=obj.id,
            change_type=change_type,
            data=data,
            actor=str(actor) if actor else None,
            meta=metadata
        )
        session.add(version)
        session.commit()
|