import sqlalchemy as sa
from sqlalchemy import func, asc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, aliased
from sqlalchemy.inspection import inspect
from sqlalchemy.sql.elements import UnaryExpression

from .dsl import QuerySpec, build_query, split_sort_tokens
from .eager import default_eager_policy


def _unwrap_order_clause(ob):
    """Return the expression underneath an ORDER BY clause.

    ``col.asc()`` / ``col.desc()`` produce a ``UnaryExpression`` whose
    ``element`` is the column; any other clause is returned unchanged.
    """
    return ob.element if isinstance(ob, UnaryExpression) else ob


def _dedup_order_by(ordering):
    """Drop repeated ORDER BY clauses, keeping the first occurrence of each.

    Clauses are considered equal when their rendered column text and their
    ``modifier`` / ``operator`` attributes (if any) all match, so the same
    column sorted in different directions is kept as separate clauses.
    """
    seen = set()
    result = []
    for ob in ordering:
        col = _unwrap_order_clause(ob)
        key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}"
        if key in seen:
            continue
        seen.add(key)
        result.append(ob)
    return result


def _parse_sort_token(token: str):
    """Split a sort token into ``(path, direction)``.

    Accepted forms: ``"name"``, ``"-name"`` (descending), and
    ``"name:asc"`` / ``"name:desc"``. An explicit ``:suffix`` overrides a
    leading ``-``; a suffix starting with ``d``/``D`` means ``"desc"``,
    anything else means ``"asc"``.
    """
    token = token.strip()
    direction = "asc"
    if token.startswith("-"):
        direction = "desc"
        token = token[1:]
    if ":" in token:
        key, dirpart = token.rsplit(":", 1)
        direction = "desc" if dirpart.lower().startswith("d") else "asc"
        return key, direction
    return token, direction


def _apply_dotted_ordering(stmt, Model, sort_tokens):
    """Append ORDER BY clauses for sort keys that traverse relationships.

    Args:
        stmt: a ``select(Model)`` statement.
        Model: the mapped class selected by ``stmt``.
        sort_tokens: tokens such as ``["owner.identifier", "-brand.name"]``,
            each parsed with ``_parse_sort_token``.

    Returns:
        The new statement (only the statement; the alias cache is internal).
        Each distinct relationship path is LEFT OUTER JOINed once, to a
        fresh ``aliased()`` entity so it cannot clash with joins already on
        ``stmt``; each valid token adds one ORDER BY clause.

    Tokens that name an unknown relationship or column are skipped silently.

    NOTE(review): joining through a one-to-many relationship can return one
    row per child; confirm callers only sort through many-to-one paths.
    """
    mapper = inspect(Model)
    alias_cache = {}  # relationship path ("owner", "owner.company") -> alias

    for tok in sort_tokens:
        path, direction = _parse_sort_token(tok)
        parts = [p for p in path.split(".") if p]
        if not parts:
            continue

        entity = Model
        current_mapper = mapper
        alias_path = []

        # Walk the relationships named by every part except the last.
        for rel_name in parts[:-1]:
            rel = current_mapper.relationships.get(rel_name)
            if rel is None:
                entity = None  # unknown relationship: skip this token
                break
            alias_path.append(rel_name)
            key = ".".join(alias_path)
            entity_alias = alias_cache.get(key)
            if entity_alias is None:
                entity_alias = aliased(rel.mapper.class_)
                stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
                alias_cache[key] = entity_alias
            entity = entity_alias
            current_mapper = inspect(rel.mapper.class_)

        if entity is None:
            continue

        # The final part must be a mapped column on the reached entity.
        col_name = parts[-1]
        if col_name not in current_mapper.columns:
            continue

        col = getattr(entity, col_name)  # works for Model and for aliases
        stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())

    return stmt


class CrudService:
    """Generic create / read / list / update / soft-delete over mapped models.

    Every operation uses the injected ``Session``. ``create`` flushes but
    nothing here commits; presumably the caller owns the transaction
    (TODO confirm).
    """

    def __init__(self, session: Session, eager_policy=default_eager_policy):
        self.s = session
        self.eager_policy = eager_policy

    def create(self, Model, data, *, before=None, after=None):
        """Build ``Model(**data)``, add it to the session and flush.

        Args:
            Model: mapped class to instantiate.
            data: constructor keyword arguments.
            before: optional hook ``before(data)``; a non-None return value
                replaces ``data`` (an empty dict is honored, not ignored).
            after: optional hook ``after(obj)`` called after the flush.

        Returns:
            The newly flushed instance.
        """
        if before:
            result = before(data)
            if result is not None:
                data = result
        obj = Model(**data)
        self.s.add(obj)
        # Flush so database errors surface here and DB-generated values
        # (e.g. autoincrement keys) are populated before ``after`` runs.
        self.s.flush()
        if after:
            after(obj)
        return obj

    def get(self, Model, id, spec: QuerySpec | None = None):
        """Return the ``Model`` row whose ``id`` equals ``id``, or ``None``.

        ``spec`` (default: an empty ``QuerySpec``) is passed to
        ``build_query`` together with this service's eager policy.
        """
        spec = spec or QuerySpec()
        stmt = build_query(Model, spec, self.eager_policy).where(Model.id == id)
        return self.s.execute(stmt).scalars().first()

    def list(self, Model, spec: QuerySpec):
        """Return ``(rows, total)`` for ``Model`` filtered/sorted/paged by ``spec``.

        ``total`` is the number of distinct primary keys matching the query
        before pagination. Rows are ordered by ``spec.order_by`` followed by
        the primary-key columns as a tie-breaker, so paging is deterministic.
        """
        stmt = build_query(Model, spec, self.eager_policy)

        # Presumably build_query already applies the plain (non-dotted) sort
        # keys; dotted keys need joins and are applied here.
        _simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
        if dotted_sorts:
            stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)

        mapper = inspect(Model)
        pk_cols = mapper.primary_key

        # ---- total count ----
        # Count distinct primary keys taken from the mapper (not a hard-coded
        # ``Model.id``) so composite / differently named keys work. ORM
        # attributes are used so entity-level criteria still apply, and
        # ``.distinct()`` avoids rendering "DISTINCT DISTINCT" when ``stmt``
        # is already DISTINCT.
        pk_attrs = [
            getattr(Model, mapper.get_property_by_column(c).key) for c in pk_cols
        ]
        count_base = stmt.with_only_columns(*pk_attrs).distinct().order_by(None)
        total = self.s.execute(
            sa.select(sa.func.count()).select_from(count_base.subquery())
        ).scalar_one()

        # ---- pagination ----
        if spec.page and spec.per_page:
            page = max(spec.page, 1)  # a negative page would mean a negative OFFSET
            stmt = stmt.limit(spec.per_page).offset((page - 1) * spec.per_page)

        # ---- ORDER BY: requested sorts + PK tie-breakers, de-duplicated ----
        # NOTE: ``_order_by_clauses`` is a private SQLAlchemy attribute and may
        # change between versions.
        ordering = list(stmt._order_by_clauses)
        existing_cols = {str(_unwrap_order_clause(ob)) for ob in ordering}
        ordering.extend(asc(c) for c in pk_cols if str(c) not in existing_cols)
        ordering = _dedup_order_by(ordering)
        # Replace any previous ORDER BY in one go.
        stmt = stmt.order_by(None).order_by(*ordering)

        # NOTE(review): if the eager policy joined-loads collections,
        # SQLAlchemy 2.x requires ``.unique()`` here; confirm the policy.
        rows = self.s.execute(stmt).scalars().all()
        return rows, total

    def update(self, obj, data, *, before=None, after=None):
        """Set each ``data`` item on ``obj`` and increment ``obj.version``.

        Args:
            obj: the instance to modify.
            data: mapping of attribute name -> new value.
            before: optional hook ``before(obj, data)``; a non-None return
                value replaces ``data`` (an empty dict is honored).
            after: optional hook ``after(obj)`` called after the changes.

        Returns:
            ``obj``.

        Raises:
            ValueError: if ``obj.is_deleted`` is truthy.
        """
        if obj.is_deleted:
            raise ValueError("Cannot update a deleted record")
        if before:
            result = before(obj, data)
            if result is not None:
                data = result
        # NOTE(review): every key is assigned verbatim (including e.g. ``id``
        # or ``version``); filter upstream if ``data`` is untrusted input.
        for k, v in data.items():
            setattr(obj, k, v)
        obj.version += 1
        if after:
            after(obj)
        return obj

    def soft_delete(self, obj, *, cascade=False, guard=None):
        """Soft-delete ``obj`` by calling ``obj.mark_deleted()``.

        Args:
            obj: the instance to delete.
            cascade: accepted for API compatibility; currently unused.
            guard: optional predicate ``guard(obj)``; a falsy result blocks
                the delete.

        Returns:
            ``obj``.

        Raises:
            ValueError: if ``guard`` rejects the delete.
        """
        if guard and not guard(obj):
            raise ValueError("Delete blocked by guard")
        # Optional FK hygiene checks go here.
        obj.mark_deleted()
        return obj

    def undelete(self, obj):
        """Set ``obj.deleted`` to ``False``, increment ``obj.version`` and return ``obj``."""
        # NOTE(review): soft_delete goes through obj.mark_deleted() and update
        # checks obj.is_deleted, but this writes obj.deleted directly; confirm
        # that resetting ``deleted`` fully reverses mark_deleted().
        obj.deleted = False
        obj.version += 1
        return obj