from typing import Dict, List, Optional, Set, Tuple

from sqlalchemy import asc, desc
from sqlalchemy.orm import aliased, joinedload
from sqlalchemy.orm.attributes import InstrumentedAttribute


# Escape character used for LIKE patterns built from user input. A forward
# slash avoids backend differences in how a backslash literal is quoted.
_LIKE_ESCAPE = '/'


def _escape_like(value) -> str:
    """Return ``str(value)`` with LIKE wildcards escaped so they match literally.

    The escape character itself is doubled first, then ``%`` and ``_`` are
    prefixed with it. The result must be used with ``escape=_LIKE_ESCAPE``.
    """
    text = str(value)
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace('%', _LIKE_ESCAPE + '%')
        .replace('_', _LIKE_ESCAPE + '_')
    )


# Filter operator name -> builder taking (column, value) and returning a
# SQLAlchemy criterion. Filter keys select an operator with a ``__op`` suffix.
OPERATORS = {
    'eq': lambda col, val: col == val,
    'lt': lambda col, val: col < val,
    'lte': lambda col, val: col <= val,
    'gt': lambda col, val: col > val,
    'gte': lambda col, val: col >= val,
    'ne': lambda col, val: col != val,
    # Case-insensitive substring match; the value is escaped so that '%' and
    # '_' supplied by the caller are matched literally, not as wildcards.
    'icontains': lambda col, val: col.ilike(
        f"%{_escape_like(val)}%", escape=_LIKE_ESCAPE
    ),
}


class CRUDSpec:
    """Translate request-style parameters into SQLAlchemy filter/sort clauses.

    ``params`` is read as a mapping. The keys ``sort``, ``limit`` and
    ``offset`` are reserved; every other key is a filter written as ``path``
    or ``path__op``. Here ``path`` is a dotted attribute path (zero or more
    relationship names followed by a column name, e.g. ``author.name``) and
    ``op`` is a key of ``OPERATORS`` (default ``eq``).

    Each relationship hop used by a successfully resolved path is given one
    ``aliased()`` entity. The hop is recorded in ``join_paths`` as a
    ``(left_entity, relationship_attribute, alias)`` triple, in first-use
    order. Relationship-name paths used by filters/sorts are collected in
    ``eager_paths`` for ``get_eager_loads``.
    """

    # Parameter keys that are never interpreted as filters.
    RESERVED_PARAMS = ('sort', 'limit', 'offset')

    def __init__(self, model, params):
        """Store the mapped *model* class and the *params* mapping."""
        self.model = model
        self.params = params
        # Relationship-name paths touched by parse_filters/parse_sort.
        self.eager_paths: Set[Tuple[str, ...]] = set()
        # (left entity, relationship attribute, aliased target), first-use order.
        self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
        # Relationship-name path -> aliased entity, so each path is joined once.
        self.alias_map: Dict[Tuple[str, ...], object] = {}

    def _resolve_column(self, path: str):
        """Resolve dotted *path* to a column on the model or a joined alias.

        Returns ``(column, join_path)``. ``join_path`` is the tuple of
        relationship names traversed, or ``None`` when the column lives on the
        root model. Returns ``(None, None)`` in any of these cases:

        - a segment is not a mapped attribute;
        - a column appears before the last segment;
        - the path ends on a relationship.

        New aliases are recorded in ``alias_map`` / ``join_paths`` only when
        the whole path resolves, so invalid input leaves no stray joins.
        """
        current_model = self.model
        current_alias = self.model
        rel_names: List[str] = []
        # Joins discovered during this call, committed only on success.
        pending: List[Tuple[Tuple[str, ...], Tuple[object, InstrumentedAttribute, object]]] = []
        parts = path.split('.')
        last_index = len(parts) - 1

        for index, attr in enumerate(parts):
            attr_obj = getattr(current_model, attr, None)
            if not isinstance(attr_obj, InstrumentedAttribute):
                # Missing, or not a mapped attribute (method, plain property...).
                return None, None
            prop = attr_obj.property
            if hasattr(prop, 'direction'):
                # Relationship: step into the related entity via its alias.
                rel_names.append(attr)
                key = tuple(rel_names)
                target = prop.mapper.class_
                alias = self.alias_map.get(key)
                if alias is None:
                    alias = aliased(target)
                    pending.append((key, (current_alias, attr_obj, alias)))
                current_model = target
                current_alias = alias
                continue
            if index != last_index:
                # A column must be the final segment ("name.foo" is invalid).
                return None, None
            for key, join in pending:
                self.alias_map[key] = join[2]
                self.join_paths.append(join)
            return getattr(current_alias, attr), (tuple(rel_names) or None)

        # The path ended on a relationship rather than a column.
        return None, None

    def parse_filters(self):
        """Return a list of filter criteria built from non-reserved params.

        Keys whose operator is not in ``OPERATORS``, or whose path does not
        resolve to a column, are ignored. Relationship paths used by applied
        filters are added to ``eager_paths``.
        """
        filters = []
        for key, value in self.params.items():
            if key in self.RESERVED_PARAMS:
                continue
            path, sep, op = key.rpartition('__')
            if not sep:
                path, op = key, 'eq'
            builder = OPERATORS.get(op)
            if builder is None:
                # Check the operator first so no joins get recorded for a
                # filter that will be discarded anyway.
                continue
            col, join_path = self._resolve_column(path)
            if col is None:
                continue
            filters.append(builder(col, value))
            if join_path:
                self.eager_paths.add(join_path)
        return filters

    def parse_sort(self):
        """Return ORDER BY clauses from the comma-separated ``sort`` param.

        A leading ``-`` sorts descending; otherwise ascending. Unresolvable
        fields are ignored. A missing or ``None`` ``sort`` yields ``[]``.
        """
        sort_args = self.params.get('sort') or ''
        result = []
        for part in sort_args.split(','):
            part = part.strip()
            if not part:
                continue
            if part.startswith('-'):
                field, order = part[1:], desc
            else:
                field, order = part, asc
            col, join_path = self._resolve_column(field)
            if col is None:
                continue
            result.append(order(col))
            if join_path:
                self.eager_paths.add(join_path)
        return result

    def parse_pagination(self):
        """Return ``(limit, offset)`` from params (defaults 100 and 0).

        Negative values are clamped to 0. Raises ``ValueError`` if a supplied
        value is not an integer.
        """
        limit = max(0, int(self.params.get('limit', 100)))
        offset = max(0, int(self.params.get('offset', 0)))
        return limit, offset

    def get_eager_loads(self):
        """Return one chained ``joinedload`` option per path in ``eager_paths``.

        NOTE(review): ``joinedload`` emits its own anonymous JOIN, independent
        of the aliases in ``join_paths``. If callers also join ``join_paths``,
        the relationship is joined twice; ``contains_eager`` with those
        aliases would reuse them — confirm how callers combine the two.
        """
        loads = []
        # Sorted so the option order is deterministic (set order is not).
        for path in sorted(self.eager_paths):
            current = self.model
            loader = None
            for attr in path:
                attr_obj = getattr(current, attr)
                if loader is None:
                    loader = joinedload(attr_obj)
                else:
                    loader = loader.joinedload(attr_obj)
                current = attr_obj.property.mapper.class_
            if loader is not None:
                loads.append(loader)
        return loads

    def get_join_paths(self):
        """Return the recorded ``(left_entity, relationship_attr, alias)`` joins."""
        return self.join_paths