from typing import List, Tuple, Set, Dict, Optional
from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute

# Escape character for LIKE patterns ('/' is also what SQLAlchemy's own autoescape uses).
_LIKE_ESCAPE = '/'


def _escape_like(value) -> str:
    """Return ``str(value)`` with LIKE wildcards escaped so it matches literally."""
    text = str(value)
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace('%', _LIKE_ESCAPE + '%')
        .replace('_', _LIKE_ESCAPE + '_')
    )


# Maps the ``__<op>`` suffix of a filter key to a builder that takes
# (column, raw param value) and returns a SQL boolean expression.
OPERATORS = {
    'eq': lambda col, val: col == val,
    'lt': lambda col, val: col < val,
    'lte': lambda col, val: col <= val,
    'gt': lambda col, val: col > val,
    'gte': lambda col, val: col >= val,
    'ne': lambda col, val: col != val,
    # User-supplied '%' / '_' are escaped so this is a literal substring match.
    'icontains': lambda col, val: col.ilike(f"%{_escape_like(val)}%", escape=_LIKE_ESCAPE),
}


class CRUDSpec:
    """Translate request query parameters into SQLAlchemy query pieces.

    Recognised parameters (any other key is treated as a filter):

    * ``path`` / ``path__<op>`` -- filters; ``<op>`` is a key of ``OPERATORS``
      (default ``eq``) and ``path`` may be dotted through relationships.
    * ``sort=colA,-rel.colB`` -- ordering; a leading ``-`` means descending.
    * ``limit`` / ``offset`` -- pagination.
    * ``include=rel1,rel1.rel2`` -- relationships to eager-load.
    * ``fields=colA,rel1.colB`` -- column projection.

    Each distinct relationship path walked gets one ``aliased()`` target,
    recorded in ``alias_map`` and exposed via ``get_join_paths()``
    (presumably so the caller can join them -- confirm against callers).
    """

    # Keys that configure the request rather than filter rows.
    _RESERVED_PARAMS = frozenset({'sort', 'limit', 'offset', 'include', 'fields'})

    def __init__(self, model, params, root_alias):
        self.model = model
        # Query-param mapping; only ``.get()`` and ``.items()`` are used on it.
        self.params = params
        # Entity that root-level attribute names are looked up on
        # (presumably ``aliased(model)`` -- TODO confirm with callers).
        self.root_alias = root_alias
        # Relationship-name paths to eager-load with ``selectinload``.
        self.eager_paths: Set[Tuple[str, ...]] = set()
        # (parent entity, relationship attribute, aliased target), in discovery order.
        self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
        # Relationship-name path -> aliased target entity.
        self.alias_map: Dict[Tuple[str, ...], object] = {}
        self._root_fields: List[InstrumentedAttribute] = []
        self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        # Not written by any method of this class; kept for external callers.
        self.include_paths: Set[Tuple[str, ...]] = set()

    @staticmethod
    def _split_csv(raw) -> List[str]:
        """Split a comma-separated param (a string or a list/tuple of string
        chunks) into stripped, non-empty tokens. Falsy input yields ``[]``."""
        if not raw:
            return []
        chunks = raw if isinstance(raw, (list, tuple)) else [raw]
        return [tok.strip() for chunk in chunks for tok in str(chunk).split(',') if tok.strip()]

    def _resolve_column(self, path: str):
        """Resolve a dotted attribute path such as ``col`` or ``rel1.rel2.col``.

        Every segment except the last must be a relationship. The last must be
        column-like: an ``InstrumentedAttribute``, an attribute exposing a
        ``comparator`` (e.g. a hybrid), or an object with ``clauses``.
        Relationship hops are aliased once per distinct path and recorded in
        ``alias_map`` / ``join_paths``.

        Returns ``(column, join_path)``. ``join_path`` is the tuple of
        relationship names walked, or ``None`` for a root-level column.
        Returns ``(None, None)`` when the path does not resolve. In that case
        no new aliases or join paths are registered.
        """
        parts = path.split('.')
        last_index = len(parts) - 1
        current_alias = self.root_alias
        join_path: List[str] = []
        # New hops found during this call. They are committed only once the
        # whole path resolves, so typos or a trailing relationship add no joins.
        new_hops = []

        for index, attr in enumerate(parts):
            try:
                attr_obj = getattr(current_alias, attr)
            except AttributeError:
                return None, None

            prop = getattr(attr_obj, "property", None)
            if prop is not None and hasattr(prop, "direction"):
                # Relationship: step into its (aliased) target entity.
                join_path.append(attr)
                path_key = tuple(join_path)
                alias = self.alias_map.get(path_key)
                if alias is None:
                    alias = aliased(prop.mapper.class_)
                    new_hops.append((path_key, current_alias, attr_obj, alias))
                current_alias = alias
                continue

            is_column = (
                isinstance(attr_obj, InstrumentedAttribute)
                or getattr(attr_obj, "comparator", None) is not None
                or hasattr(attr_obj, "clauses")
            )
            # A column must be the final segment: "title.foo" is not "title".
            if not is_column or index != last_index:
                return None, None

            for path_key, parent, rel_attr, alias in new_hops:
                self.alias_map[path_key] = alias
                self.join_paths.append((parent, rel_attr, alias))
            return attr_obj, (tuple(join_path) if join_path else None)

        # The path ended on a relationship rather than a column.
        return None, None

    def parse_includes(self):
        """Add eager-load paths for the ``include`` parameter.

        A token that ends in a column (``rel.col``) eager-loads the
        relationships leading to it. A token naming only relationships
        (``rel`` or ``rel1.rel2``) is resolved through its target's ``id``
        attribute, which assumes the target model has one. Resolving a token
        also registers its join path (see ``_resolve_column``). Unresolvable
        tokens are ignored.
        """
        for token in self._split_csv(self.params.get('include')):
            _, join_path = self._resolve_column(token)
            if join_path:
                self.eager_paths.add(join_path)
                continue
            _, join_path = self._resolve_column(token + '.id')
            if join_path:
                self.eager_paths.add(join_path)

    def parse_filters(self):
        """Build filter expressions from every non-reserved parameter.

        A key is ``path`` (equality) or ``path__op`` with ``op`` in
        ``OPERATORS``; the split happens at the last ``__``. The raw parameter
        value is passed to the operator unchanged. Keys with an unknown op or
        an unresolvable path are skipped. Filtering on a related column also
        adds that relationship path to ``eager_paths``.

        Returns a list of SQL boolean expressions.
        """
        filters = []
        for key, value in self.params.items():
            if key in self._RESERVED_PARAMS:
                continue
            path, sep, op = key.rpartition('__')
            if not sep:
                path, op = key, 'eq'
            # Check the op first so unknown ops never register joins.
            if op not in OPERATORS:
                continue
            col, join_path = self._resolve_column(path)
            # ``is None``: truth-testing SQL expressions can raise TypeError.
            if col is None:
                continue
            filters.append(OPERATORS[op](col, value))
            if join_path:
                self.eager_paths.add(join_path)
        return filters

    def parse_sort(self):
        """Return ``asc``/``desc`` clauses for the ``sort`` parameter.

        Tokens are comma-separated paths; a leading ``-`` selects descending
        order. Unresolvable paths are skipped. Sorting on a related column also
        adds that relationship path to ``eager_paths``.
        """
        result = []
        for token in self._split_csv(self.params.get('sort')):
            if token.startswith('-'):
                order, field = desc, token[1:]
            else:
                order, field = asc, token
            col, join_path = self._resolve_column(field)
            if col is None:
                continue
            result.append(order(col))
            if join_path:
                self.eager_paths.add(join_path)
        return result

    def parse_pagination(self, default_limit: int = 100, max_limit: Optional[int] = None):
        """Return ``(limit, offset)`` from the ``limit`` / ``offset`` parameters.

        Missing values default to ``default_limit`` and ``0``. Negative values
        are clamped to ``0``. When ``max_limit`` is given, ``limit`` is capped
        at it.

        Raises:
            ValueError: if either parameter is not an integer string.
        """
        limit = max(0, int(self.params.get('limit', default_limit)))
        offset = max(0, int(self.params.get('offset', 0)))
        if max_limit is not None:
            limit = min(limit, max_limit)
        return limit, offset

    def parse_fields(self):
        """
        Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD

        - Root fields become attributes bound to root_alias (first occurrence of
          each key kept, in request order).
        - Related fields store attribute NAMES per relationship path
          (deduplicated, order kept) for ``get_eager_loads(fields_map=...)``;
          each such path is also added to ``eager_paths``.
        Unresolvable tokens are skipped. Both results are also stored on
        ``self._root_fields`` / ``self._rel_field_names``.
        Returns (root_fields, rel_field_names).
        """
        raw = self.params.get('fields')
        if not raw:
            return [], {}

        root_by_key: Dict[str, InstrumentedAttribute] = {}
        rel_names: Dict[Tuple[str, ...], Dict[str, None]] = {}
        for token in self._split_csv(raw):
            col, join_path = self._resolve_column(token)
            if col is None:
                continue
            if join_path:
                # A dict serves as an insertion-ordered set of names.
                rel_names.setdefault(join_path, {})[col.key] = None
                self.eager_paths.add(join_path)
            else:
                root_by_key.setdefault(col.key, col)

        root_fields = list(root_by_key.values())
        rel_field_names = {path: list(names) for path, names in rel_names.items()}
        self._root_fields = root_fields
        self._rel_field_names = rel_field_names
        return root_fields, rel_field_names

    @staticmethod
    def _is_plain_column(attr) -> bool:
        """True for an instrumented attribute that is not a relationship."""
        return isinstance(attr, InstrumentedAttribute) and not hasattr(attr.property, "direction")

    def get_eager_loads(self, root_alias, *, fields_map=None):
        """Build chained ``selectinload`` options for every path in ``eager_paths``.

        The first hop is looked up on ``root_alias``; later hops on the target
        mapped class of the previous relationship. Paths are processed in
        sorted order, so the output is deterministic.

        fields_map: optional mapping ``path -> [attribute names]`` (the shape of
            ``parse_fields()``'s second result). When a path appears in it, the
            final hop is narrowed with ``load_only`` to the named plain column
            attributes. Hybrids, expressions, relationships, and unknown names
            are skipped. ``load_only`` is applied only if at least one column
            remains.
        """
        loads = []
        for path in sorted(self.eager_paths):
            current = root_alias
            loader = None
            for name in path:
                rel_attr = getattr(current, name)
                loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)
                # Step to the target class for the next hop.
                current = rel_attr.property.mapper.class_
            if loader is None:
                continue
            if fields_map and path in fields_map:
                cols = [
                    attr
                    for attr in (getattr(current, n, None) for n in fields_map[path])
                    if self._is_plain_column(attr)
                ]
                if cols:
                    loader = loader.load_only(*cols)
            loads.append(loader)
        return loads

    def get_join_paths(self):
        """Return the list of ``(parent, relationship attr, alias)`` hops
        registered so far (the live internal list, not a copy)."""
        return self.join_paths