# Source: inventory/crudkit/service.py (file-viewer header: 169 lines, 5.6 KiB, Python)

import sqlalchemy as sa
from sqlalchemy import func, asc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, aliased
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.elements import UnaryExpression
from .dsl import QuerySpec, build_query, split_sort_tokens
from .eager import default_eager_policy
def _dedup_order_by(ordering):
    """Drop repeated ORDER BY clauses, keeping the first occurrence of each.

    Two clauses are considered duplicates when their signatures match. A
    signature is the string rendering of the clause's target (the wrapped
    ``element`` for a ``UnaryExpression``, otherwise the clause itself)
    combined with the clause's ``modifier`` and ``operator`` attributes, so
    clauses that differ only in modifier/operator are treated as distinct.

    Args:
        ordering: Iterable of ORDER BY clause objects.

    Returns:
        A new list holding the surviving clauses in their original order.
    """
    # NOTE(review): signatures are built from ``str()`` of each expression; if
    # two distinct expressions (e.g. columns on different anonymous aliases)
    # render to the same text they would be merged -- confirm for callers.
    first_by_signature = {}
    for clause in ordering:
        target = clause.element if isinstance(clause, UnaryExpression) else clause
        modifier = getattr(clause, "modifier", "")
        operator = getattr(clause, "operator", "")
        # dicts preserve insertion order, and setdefault keeps the first clause
        # seen for a given signature.
        first_by_signature.setdefault(f"{target}-{modifier}-{operator}", clause)
    return list(first_by_signature.values())
def _parse_sort_token(token: str) -> tuple[str, str]:
    """Split one sort token into ``(key, direction)``.

    Recognised spellings:

    * ``"name"``      -> ``("name", "asc")``
    * ``"-name"``     -> ``("name", "desc")``
    * ``"name:desc"`` -> ``("name", "desc")`` (any suffix starting with
      ``d``/``D`` means descending; every other suffix means ascending)

    When both a leading ``-`` and a ``:direction`` suffix are present, the
    suffix decides the direction. Whitespace around the token, the key and
    the direction suffix is ignored, so ``"name: desc"`` is descending.

    Args:
        token: Raw sort token.

    Returns:
        The sort key and either ``"asc"`` or ``"desc"``.
    """
    token = token.strip()
    direction = "asc"
    if token.startswith("-"):
        direction = "desc"
        token = token[1:].lstrip()

    key, sep, suffix = token.rpartition(":")
    if not sep:
        # No explicit ":direction" suffix; the prefix (if any) decides.
        return token, direction

    # Strip before inspecting the suffix: without it " desc" would not start
    # with "d" and would silently sort ascending.
    direction = "desc" if suffix.strip().lower().startswith("d") else "asc"
    return key.strip(), direction
def _apply_dotted_ordering(stmt, Model, sort_tokens):
    """Add ORDER BY clauses for sort keys that traverse relationships.

    Each token is parsed with ``_parse_sort_token``. Every path segment except
    the last is looked up as a relationship, starting from ``Model``; each
    relationship hop is OUTER JOINed through an alias (one alias per distinct
    relationship path, reused across tokens). The last segment is looked up
    as a column on the entity reached, and an ascending or descending
    ORDER BY on it is appended to the statement.

    Tokens that name an unknown relationship or an unknown column are
    skipped silently.

    Args:
        stmt: A ``select(Model)`` statement.
        Model: The mapped class the statement selects from.
        sort_tokens: Sort tokens such as ``["owner.identifier", "-brand.name"]``.

    Returns:
        The statement with the extra joins and ORDER BY clauses applied. When
        ``sort_tokens`` is empty or ``None`` the statement is returned as is.
    """
    if not sort_tokens:
        return stmt

    mapper = inspect(Model)
    alias_cache = {}  # relationship path ("owner", "owner.team") -> alias

    for tok in sort_tokens:
        path, direction = _parse_sort_token(tok)
        parts = [p for p in path.split(".") if p]
        if not parts:
            continue

        entity = Model
        current_mapper = mapper
        alias_path = []

        # Walk relationships for all but the last part.
        for rel_name in parts[:-1]:
            rel = current_mapper.relationships.get(rel_name)
            if rel is None:
                # Invalid sort key: flag it so the token is skipped below.
                entity = None
                break

            alias_path.append(rel_name)
            key = ".".join(alias_path)
            if key in alias_cache:
                entity_alias = alias_cache[key]
            else:
                # First time this path is seen: alias the target and join it.
                entity_alias = aliased(rel.mapper.class_)
                stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
                alias_cache[key] = entity_alias

            entity = entity_alias
            current_mapper = inspect(rel.mapper.class_)

        if entity is None:
            continue

        col_name = parts[-1]
        # Validate the final column against the mapper reached by the walk.
        if col_name not in current_mapper.columns:
            continue

        col = getattr(entity, col_name)
        stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())

    return stmt
class CrudService:
    """Generic create/read/list/update/soft-delete operations on one session.

    The service never commits; it only adds objects to, flushes, and executes
    statements on the session it was given.
    """

    def __init__(self, session: Session, eager_policy=default_eager_policy):
        """Keep the session and the eager policy that is passed to ``build_query``."""
        self.s = session
        self.eager_policy = eager_policy

    def create(self, Model, data, *, before=None, after=None):
        """Instantiate ``Model(**data)``, add it to the session and flush.

        Args:
            Model: Class to instantiate.
            data: Mapping of constructor keyword arguments.
            before: Optional hook ``before(data)``. If it returns anything
                other than ``None``, that value replaces ``data`` (an empty
                dict is honoured, not ignored).
            after: Optional hook ``after(obj)`` called once the flush is done.

        Returns:
            The newly created object.
        """
        if before:
            replaced = before(data)
            # ``is not None`` rather than truthiness: a hook that filters the
            # payload down to ``{}`` must not fall back to the raw input.
            if replaced is not None:
                data = replaced
        obj = Model(**data)
        self.s.add(obj)
        self.s.flush()
        if after:
            after(obj)
        return obj

    def get(self, Model, id, spec: QuerySpec | None = None):
        """Return the first row of ``Model`` whose ``id`` column equals ``id``.

        Args:
            Model: Mapped class; must expose an ``id`` attribute.
            id: Value compared against ``Model.id``.
            spec: Optional query spec forwarded to ``build_query``; a default
                ``QuerySpec()`` is used when omitted.

        Returns:
            The matching object, or ``None`` when nothing matches.
        """
        spec = spec or QuerySpec()
        stmt = build_query(Model, spec, self.eager_policy).where(Model.id == id)
        return self.s.execute(stmt).scalars().first()

    def list(self, Model, spec: QuerySpec):
        """Run the query described by ``spec`` and return ``(rows, total)``.

        ``total`` counts the distinct ``Model.id`` values matched by the query
        before pagination. Pagination is applied only when both ``spec.page``
        and ``spec.per_page`` are truthy. Primary-key columns are appended as
        ascending tie-breakers unless they are already part of the ordering.

        Args:
            Model: Mapped class; must expose an ``id`` attribute.
            spec: Query spec; ``order_by``, ``page`` and ``per_page`` are read
                here, the rest is consumed by ``build_query``.

        Returns:
            A tuple of the result rows and the unpaginated row count.
        """
        stmt = build_query(Model, spec, self.eager_policy)

        # Only dotted (relationship-traversing) sorts are handled here.
        # NOTE(review): simple sorts are presumably applied by build_query --
        # confirm in the dsl module.
        _simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
        if dotted_sorts:
            stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)

        # Count query: distinct primary keys, no ordering, wrapped as subquery.
        pk = Model.id  # adjust if the primary key is not named 'id'
        count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None)
        total = self.s.execute(
            sa.select(sa.func.count()).select_from(count_base.subquery())
        ).scalar_one()

        if spec.page and spec.per_page:
            offset = (spec.page - 1) * spec.per_page
            stmt = stmt.limit(spec.per_page).offset(offset)

        # ---- ORDER BY handling ----
        pk_cols = inspect(Model).primary_key

        # Gather all clauses added so far (reads a private Select attribute).
        ordering = [*stmt._order_by_clauses]

        # Append pk tie-breakers if not already present.
        existing_cols = {
            str(ob.element if isinstance(ob, UnaryExpression) else ob)
            for ob in ordering
        }
        for c in pk_cols:
            if str(c) not in existing_cols:
                ordering.append(asc(c))

        # Dedup *before* applying, then wipe old ORDER BYs and set them once.
        ordering = _dedup_order_by(ordering)
        stmt = stmt.order_by(None).order_by(*ordering)

        rows = self.s.execute(stmt).scalars().all()
        return rows, total

    def update(self, obj, data, *, before=None, after=None):
        """Assign ``data`` onto ``obj`` and bump its ``version``.

        Args:
            obj: Object to modify; must expose ``is_deleted`` and ``version``.
            data: Mapping of attribute name to new value.
            before: Optional hook ``before(obj, data)``. If it returns anything
                other than ``None``, that value replaces ``data`` (an empty
                dict is honoured, not ignored).
            after: Optional hook ``after(obj)`` called after the assignment.

        Returns:
            The same object.

        Raises:
            ValueError: If ``obj.is_deleted`` is truthy.
        """
        if obj.is_deleted:
            raise ValueError("Cannot update a deleted record")
        if before:
            replaced = before(obj, data)
            # See create(): an intentionally empty result must be respected.
            if replaced is not None:
                data = replaced
        for k, v in data.items():
            setattr(obj, k, v)
        obj.version += 1
        if after:
            after(obj)
        return obj

    def soft_delete(self, obj, *, cascade=False, guard=None):
        """Mark ``obj`` as deleted by calling ``obj.mark_deleted()``.

        Args:
            obj: Object to delete; must provide ``mark_deleted()``.
            cascade: Accepted but not used by the current implementation.
            guard: Optional predicate ``guard(obj)``; a falsy result blocks
                the delete.

        Returns:
            The same object.

        Raises:
            ValueError: If ``guard`` is given and rejects the object.
        """
        if guard and not guard(obj):
            raise ValueError("Delete blocked by guard")
        # Optional FK hygiene checks go here.
        obj.mark_deleted()
        return obj

    def undelete(self, obj):
        """Clear the ``deleted`` flag on ``obj`` and bump its ``version``.

        Returns:
            The same object.
        """
        # NOTE(review): update() reads ``is_deleted`` while this writes
        # ``deleted`` -- presumably ``is_deleted`` derives from ``deleted``;
        # confirm against the model mixin.
        obj.deleted = False
        obj.version += 1
        return obj