176 lines
5.9 KiB
Python
176 lines
5.9 KiB
Python
import logging

import sqlalchemy as sa
from sqlalchemy import func, asc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session, aliased
from sqlalchemy.sql.elements import UnaryExpression

from .dsl import QuerySpec, build_query, split_sort_tokens
from .eager import default_eager_policy
|
|
|
|
def _dedup_order_by(ordering):
    """Return *ordering* with repeated ORDER BY clauses removed.

    The first occurrence of each clause is kept and relative order is
    preserved.  Two clauses count as duplicates when all of these match:

    * the SQL text of the underlying expression;
    * the identity of the table/alias object the column belongs to, if any;
    * the clause's ``modifier`` and ``operator`` (e.g. ASC vs DESC).

    The table identity is part of the key because distinct anonymous
    ``aliased()`` entities of one table (e.g. two relationships to the same
    model) each stringify as ``<table>_1.<col>`` when rendered on their own.
    Comparing text alone would wrongly drop the second of them.
    """
    seen = set()
    result = []
    for ob in ordering:
        col = ob.element if isinstance(ob, UnaryExpression) else ob
        # ``table`` is the Table or Alias owning a column; non-column
        # expressions simply have none.
        table = getattr(col, "table", None)
        key = (
            str(col),
            None if table is None else id(table),
            getattr(ob, "modifier", None),
            getattr(ob, "operator", None),
        )
        if key in seen:
            continue
        seen.add(key)
        result.append(ob)
    return result
|
|
|
|
def _parse_sort_token(token: str):
|
|
token = token.strip()
|
|
direction = "asc"
|
|
if token.startswith('-'):
|
|
direction = "desc"
|
|
token = token[1:]
|
|
if ":" in token:
|
|
key, dirpart = token.rsplit(":", 1)
|
|
direction = "desc" if dirpart.lower().startswith("d") else "asc"
|
|
return key, direction
|
|
return token, direction
|
|
|
|
def _apply_dotted_ordering(stmt, Model, sort_tokens):
    """Add ORDER BY clauses for relationship-path sort tokens.

    Each token is a dotted path such as ``"owner.identifier"`` or
    ``"-brand.name"``; direction syntax is whatever ``_parse_sort_token``
    accepts.  Every relationship segment before the last is LEFT OUTER JOINed
    once, under an ``aliased()`` entity.  That alias is reused by later tokens
    sharing the same path prefix.  The final segment names the attribute to
    order by on the last entity reached.

    Tokens that cannot be resolved are skipped with a logged warning:

    * an unknown relationship segment;
    * a to-many relationship segment (joining a collection would duplicate
      parent rows and break pagination);
    * a final segment that is not an orderable attribute (unknown name,
      relationship, dunder, or a plain Python attribute without
      ``asc``/``desc``).

    Args:
        stmt: a ``select(Model)`` statement.
        Model: the mapped class selected by *stmt*.
        sort_tokens: list of dotted sort tokens, e.g.
            ``["owner.identifier", "-brand.name"]``.

    Returns:
        The statement with the extra outer joins and ORDER BY clauses
        appended.
    """
    log = logging.getLogger(__name__)
    root_mapper = inspect(Model)
    alias_cache = {}  # dotted relationship path (e.g. "owner.brand") -> its alias

    for tok in sort_tokens:
        path, direction = _parse_sort_token(tok)
        parts = [p for p in path.split(".") if p]
        if not parts:
            continue

        entity = Model
        current_mapper = root_mapper
        alias_path = []
        resolved = True

        # Walk (and join) every relationship segment except the last.
        for rel_name in parts[:-1]:
            rel = current_mapper.relationships.get(rel_name)
            if rel is None:
                log.warning(
                    "Unknown relationship %s.%s",
                    current_mapper.class_.__name__,
                    rel_name,
                )
                resolved = False
                break
            if rel.uselist:
                log.warning(
                    "Cannot sort through to-many relationship %s.%s",
                    current_mapper.class_.__name__,
                    rel_name,
                )
                resolved = False
                break

            alias_path.append(rel_name)
            key = ".".join(alias_path)
            entity_alias = alias_cache.get(key)
            if entity_alias is None:
                # First time this path is seen: alias the target and join it
                # using the relationship attribute as the ON clause.
                entity_alias = aliased(rel.mapper.class_)
                stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
                alias_cache[key] = entity_alias

            entity = entity_alias
            current_mapper = rel.mapper

        if not resolved:
            continue

        col_name = parts[-1]
        attr = None
        # Relationships and dunder names can never be ORDER BY targets.
        if col_name not in current_mapper.relationships and not col_name.startswith("__"):
            attr = getattr(entity, col_name, None)
        if attr is None or not callable(getattr(attr, "desc", None)):
            log.warning(
                "Unknown column %s.%s", current_mapper.class_.__name__, col_name
            )
            continue

        stmt = stmt.order_by(attr.desc() if direction == "desc" else attr.asc())

    return stmt
|
|
|
|
class CrudService:
    """Generic create/read/update/soft-delete helpers for mapped models.

    Read queries are built with ``build_query`` using the service's eager
    loading policy.  No method here commits: ``create`` flushes, the other
    mutators only change attributes.  Transaction control is left to the
    caller.

    Attributes:
        s: the ``Session`` every statement is executed on.
        eager_policy: passed through to ``build_query``.
    """

    def __init__(self, session: Session, eager_policy=default_eager_policy):
        self.s = session
        self.eager_policy = eager_policy

    def create(self, Model, data, *, before=None, after=None):
        """Create, add and flush a new ``Model`` instance built from *data*.

        Args:
            Model: mapped class, instantiated as ``Model(**data)``.
            data: keyword arguments for the constructor.
            before: optional ``before(data)`` hook; a truthy return value
                replaces *data*.
            after: optional ``after(obj)`` hook, called after the flush.

        Returns:
            The new instance (flushed, not committed).
        """
        if before:
            data = before(data) or data
        obj = Model(**data)
        self.s.add(obj)
        # Emit the INSERT now so database errors surface here; presumably
        # also so server-generated keys are populated before ``after`` runs.
        self.s.flush()
        if after:
            after(obj)
        return obj

    def get(self, Model, id, spec: QuerySpec | None = None):
        """Return the first ``Model`` matching ``Model.id == id``, or ``None``.

        *spec* (default: an empty ``QuerySpec()``) is passed to
        ``build_query``, so whatever that produces applies here as well.
        """
        spec = spec or QuerySpec()
        stmt = build_query(Model, spec, self.eager_policy).where(Model.id == id)
        return self.s.execute(stmt).scalars().first()

    def list(self, Model, spec: QuerySpec):
        """Return ``(rows, total)`` for the query described by *spec*.

        ``total`` is the number of distinct primary keys matching the query,
        before pagination.  Rows are ordered by the query's ORDER BY clauses
        followed by ascending primary-key tie-breakers, so pages are
        deterministic.  Pagination applies only when both ``spec.page`` and
        ``spec.per_page`` are truthy.  Pages are 1-based; a negative page is
        treated as page 1.
        """
        stmt = build_query(Model, spec, self.eager_policy)

        # Only dotted (relationship-path) sorts are applied here.
        # NOTE(review): presumably build_query applies the simple ones — confirm.
        _simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
        if dotted_sorts:
            stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)

        mapper = inspect(Model)
        pk_cols = mapper.primary_key
        # ORM attributes for the PK columns, so counting works for any
        # primary-key name and for composite keys (not just ``Model.id``).
        pk_attrs = [
            getattr(Model, mapper.get_property_by_column(c).key) for c in pk_cols
        ]

        total = self._count_distinct(stmt, pk_attrs)

        stmt = stmt.order_by(None).order_by(*self._stable_ordering(stmt, pk_cols))

        if spec.page and spec.per_page:
            page = max(spec.page, 1)  # a negative OFFSET is rejected by databases
            stmt = stmt.limit(spec.per_page).offset((page - 1) * spec.per_page)

        rows = self.s.execute(stmt).scalars().all()
        return rows, total

    def _count_distinct(self, stmt, pk_attrs):
        """Count distinct primary-key values selected by *stmt*, ignoring ORDER BY."""
        if len(pk_attrs) == 1:
            count_base = stmt.with_only_columns(sa.distinct(pk_attrs[0]))
        else:
            count_base = stmt.with_only_columns(*pk_attrs).distinct()
        count_base = count_base.order_by(None)
        return self.s.execute(
            sa.select(sa.func.count()).select_from(count_base.subquery())
        ).scalar_one()

    @staticmethod
    def _stable_ordering(stmt, pk_cols):
        """Return *stmt*'s ORDER BY clauses plus missing ascending PK tie-breakers, deduplicated."""
        # NOTE(review): relies on the private ``Select._order_by_clauses``.
        ordering = list(stmt._order_by_clauses)
        existing_cols = {
            str(ob.element if isinstance(ob, UnaryExpression) else ob)
            for ob in ordering
        }
        ordering.extend(asc(c) for c in pk_cols if str(c) not in existing_cols)
        return _dedup_order_by(ordering)

    def update(self, obj, data, *, before=None, after=None):
        """Assign each ``data`` item as an attribute on *obj* and bump ``obj.version``.

        Args:
            obj: the instance to modify (not flushed or committed here).
            data: mapping of attribute name to new value.
            before: optional ``before(obj, data)`` hook; a truthy return value
                replaces *data*.
            after: optional ``after(obj)`` hook, called after the changes.

        Returns:
            *obj*.

        Raises:
            ValueError: if ``obj.is_deleted`` is truthy.
        """
        if obj.is_deleted:
            raise ValueError("Cannot update a deleted record")
        if before:
            data = before(obj, data) or data
        # NOTE(review): keys are not whitelisted; callers must keep fields
        # such as the primary key or ``version`` out of *data*.
        for k, v in data.items():
            setattr(obj, k, v)
        obj.version += 1
        if after:
            after(obj)
        return obj

    def soft_delete(self, obj, *, cascade=False, guard=None):
        """Mark *obj* deleted via ``obj.mark_deleted()`` unless *guard* vetoes it.

        ``cascade`` is accepted but currently unused.

        Returns:
            *obj*.

        Raises:
            ValueError: if *guard* is given and ``guard(obj)`` is falsy.
        """
        if guard and not guard(obj):
            raise ValueError("Delete blocked by guard")
        # Optional FK hygiene checks go here.
        obj.mark_deleted()
        return obj

    def undelete(self, obj):
        """Set ``obj.deleted`` to ``False`` and bump ``obj.version``; returns *obj*."""
        # NOTE(review): ``update`` checks ``is_deleted`` and ``soft_delete``
        # calls ``mark_deleted()``.  Confirm that ``deleted`` is what backs
        # ``is_deleted``; otherwise this does not actually restore the record.
        obj.deleted = False
        obj.version += 1
        return obj
|