from dataclasses import dataclass
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable

from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute


@dataclass(frozen=True)
class CollPred:
    """A deferred predicate against a collection (uselist=True) relationship.

    Returned instead of a raw SQLA expression so the caller can decide how to
    apply it (e.g. via EXISTS / .any()) without multiplying root rows.
    """
    table: Any
    col_key: str
    op: str
    value: Any


OPERATORS = {
    'eq': lambda col, val: col == val,
    'lt': lambda col, val: col < val,
    'lte': lambda col, val: col <= val,
    'gt': lambda col, val: col > val,
    'gte': lambda col, val: col >= val,
    'ne': lambda col, val: col != val,
    'icontains': lambda col, val: col.ilike(f"%{val}%"),
    'in': lambda col, val: col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
    'nin': lambda col, val: ~col.in_(val if isinstance(val, (list, tuple, set)) else [val]),
}


class CRUDSpec:
    def __init__(self, model, params, root_alias):
        self.model = model
        self.params = params
        self.root_alias = root_alias
        self.eager_paths: Set[Tuple[str, ...]] = set()
        # (parent_alias, relationship_attr, alias_for_target)
        self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
        self.alias_map: Dict[Tuple[str, ...], object] = {}
        self._root_fields: List[InstrumentedAttribute] = []
        # dotted non-collection fields (MANYTOONE etc.)
        self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        # dotted collection fields (ONETOMANY)
        self._collection_field_names: Dict[str, List[str]] = {}
        self.include_paths: Set[Tuple[str, ...]] = set()

    def _split_path_and_op(self, key: str) -> tuple[str, str]:
        """Split 'path__op' into (path, op); the operator defaults to 'eq'."""
        if '__' in key:
            path, op = key.rsplit('__', 1)
        else:
            path, op = key, 'eq'
        return path, op

    def _resolve_many_columns(self, path: str) -> list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]]:
        """
        Accepts pipe-delimited paths like 'label|owner.label'.
        Returns a list of (column, join_path) pairs for every resolvable subpath.
        """
        cols: list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]] = []
        for sub in path.split('|'):
            sub = sub.strip()
            if not sub:
                continue
            col, join_path = self._resolve_column(sub)
            if col is not None:
                cols.append((col, join_path))
        return cols

    def _build_predicate_for(self, path: str, op: str, value: Any):
        """
        Builds a SQLA BooleanClauseList or BinaryExpression for a single key.
        If multiple subpaths are provided via pipe, returns an OR of them.
        Predicates targeting a collection relationship come back as CollPred
        markers instead of raw expressions.
        """
        if op not in OPERATORS:
            return None
        pairs = self._resolve_many_columns(path)
        if not pairs:
            return None
        exprs = []
        for col, join_path in pairs:
            if join_path:
                self.eager_paths.add(join_path)
                try:
                    # Walk the relationship chain; the final hop tells us
                    # whether the target is a collection (uselist=True).
                    cur_cls = self.model
                    prop = None
                    is_collection = False
                    for nm in join_path:
                        rel_attr = getattr(cur_cls, nm)
                        prop = rel_attr.property
                        cur_cls = prop.mapper.class_
                    is_collection = bool(prop is not None and getattr(prop, "uselist", False))
                except Exception:
                    is_collection = False
                if is_collection:
                    target_cls = cur_cls
                    key = getattr(col, "key", None) or getattr(col, "name", None)
                    if key and hasattr(target_cls, key):
                        target_tbl = getattr(target_cls, "__table__", None)
                        if target_tbl is not None:
                            exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
                            continue
            exprs.append(OPERATORS[op](col, value))
        if not exprs:
            return None
        # If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
        if any(isinstance(x, CollPred) for x in exprs):
            # If someone used a pipe 'relA.col|relB.col' that produced multiple
            # CollPreds, keep the first or raise for now (your choice).
            if len(exprs) > 1:
                # raise NotImplementedError("OR across collection paths not supported yet")
                exprs = [next(x for x in exprs if isinstance(x, CollPred))]
            return exprs[0]
        # Otherwise, standard SQLA clause(s)
        return exprs[0] if len(exprs) == 1 else or_(*exprs)

    def _collect_filters(self, params: dict) -> list:
        """
        Recursively parse filters from 'params' into a flat list of SQLA expressions.
        Supports $or / $and groups; any other keys are parsed as normal filters.
        """
        filters: list = []
        for key, value in (params or {}).items():
            if key in ('sort', 'limit', 'offset', 'fields', 'include'):
                continue
            if key == '$or':
                # value should be a list of dicts
                groups = []
                for group in value if isinstance(value, (list, tuple)) else []:
                    sub = self._collect_filters(group)
                    if not sub:
                        continue
                    groups.append(and_(*sub) if len(sub) > 1 else sub[0])
                if groups:
                    filters.append(or_(*groups))
                continue
            if key == '$and':
                # value should be a list of dicts
                parts = []
                for group in value if isinstance(value, (list, tuple)) else []:
                    sub = self._collect_filters(group)
                    if not sub:
                        continue
                    parts.append(and_(*sub) if len(sub) > 1 else sub[0])
                if parts:
                    filters.append(and_(*parts))
                continue
            # Normal key
            path, op = self._split_path_and_op(key)
            pred = self._build_predicate_for(path, op, value)
            if pred is not None:
                filters.append(pred)
        return filters

    def _resolve_column(self, path: str):
        """
        Resolve a dotted path (e.g. 'owner.label') starting from the root alias.
        Relationship hops are aliased and recorded in alias_map / join_paths.
        Returns (column, join_path) or (None, None) if the path cannot be resolved.
        """
        current_alias = self.root_alias
        parts = path.split('.')
        join_path: list[str] = []
        for attr in parts:
            try:
                attr_obj = getattr(current_alias, attr)
            except AttributeError:
                return None, None
            prop = getattr(attr_obj, "property", None)
            if prop is not None and hasattr(prop, "direction"):
                # Relationship hop: alias the target and keep walking.
                join_path.append(attr)
                path_key = tuple(join_path)
                alias = self.alias_map.get(path_key)
                if not alias:
                    alias = aliased(prop.mapper.class_)
                    self.alias_map[path_key] = alias
                    self.join_paths.append((current_alias, attr_obj, alias))
                current_alias = alias
                continue
            if isinstance(attr_obj, InstrumentedAttribute) or getattr(attr_obj, "comparator", None) is not None or hasattr(attr_obj, "clauses"):
                return attr_obj, tuple(join_path) if join_path else None
        return None, None

    def parse_includes(self):
        """Parse ?include=relA,relB.relC into eager-load paths."""
        raw = self.params.get('include')
        if not raw:
            return
        tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
        for token in tokens:
            _, join_path = self._resolve_column(token)
            if join_path:
                self.eager_paths.add(join_path)
            else:
                # A bare relationship name resolves no column; retry with '.id'
                # appended just to recover the join path.
                col, maybe = self._resolve_column(token + '.id')
                if maybe:
                    self.eager_paths.add(maybe)

    def parse_filters(self, params: dict | None = None):
        """
        Public entry: parse filters from given params or self.params.
        Returns a list of SQLAlchemy filter expressions (and possibly CollPreds).
        """
        return self._collect_filters(params if params is not None else self.params)

    def parse_sort(self):
        """Parse ?sort=colA,-relB.colC into a list of asc()/desc() expressions."""
        sort_args = self.params.get('sort', '')
        result = []
        for part in sort_args.split(','):
            part = part.strip()
            if not part:
                continue
            if part.startswith('-'):
                field = part[1:]
                order = desc
            else:
                field = part
                order = asc
            col, join_path = self._resolve_column(field)
            if col is not None:
                result.append(order(col))
                if join_path:
                    self.eager_paths.add(join_path)
        return result

    def parse_pagination(self):
        """Parse ?limit= and ?offset= with defaults of 100 and 0."""
        limit = int(self.params.get('limit', 100))
        offset = int(self.params.get('offset', 0))
        return limit, offset

    def parse_fields(self):
        """
        Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
        - Root fields become InstrumentedAttributes bound to root_alias.
        - Related fields store attribute NAMES; we'll resolve them on the
          target class when building loader options.
        - Collection (uselist=True) relationships record child names by
          relationship key.
        Returns (root_fields, rel_field_names, root_field_names,
        collection_field_names_by_rel).
        """
        raw = self.params.get('fields')
        if not raw:
            return [], {}, {}, {}
        if isinstance(raw, list):
            tokens = []
            for chunk in raw:
                tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip())
        else:
            tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
        root_fields: List[InstrumentedAttribute] = []
        root_field_names: list[str] = []
        rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        collection_field_names: Dict[str, List[str]] = {}
        for token in tokens:
            col, join_path = self._resolve_column(token)
            if col is None:
                continue
            if join_path:
                try:
                    # Walk the relationship chain to see whether the final hop
                    # is a collection (uselist=True).
                    cur_cls = self.model
                    prop = None
                    for nm in join_path:
                        rel_attr = getattr(cur_cls, nm)
                        prop = rel_attr.property
                        cur_cls = prop.mapper.class_
                    is_collection = bool(prop is not None and getattr(prop, "uselist", False))
                except Exception:
                    # Fallback: inspect the InstrumentedAttribute we recorded on join_paths
                    is_collection = False
                    for _pa, rel_attr, _al in self.join_paths:
                        if rel_attr.key == (join_path[-1] if join_path else ""):
                            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
                            break
                if is_collection:
                    collection_field_names.setdefault(join_path[-1], []).append(col.key)
                else:
                    rel_field_names.setdefault(join_path, []).append(col.key)
                    self.eager_paths.add(join_path)
            else:
                root_fields.append(col)
                root_field_names.append(getattr(col, "key", token))
        # De-duplicate while preserving order.
        seen = set()
        root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))]
        for k, names in rel_field_names.items():
            seen2 = set()
            rel_field_names[k] = [n for n in names if not (n in seen2 or seen2.add(n))]
        self._root_fields = root_fields
        self._rel_field_names = rel_field_names
        for r, names in collection_field_names.items():
            seen3 = set()
            collection_field_names[r] = [n for n in names if not (n in seen3 or seen3.add(n))]
        return root_fields, rel_field_names, root_field_names, collection_field_names

    def get_eager_loads(self, root_alias, *, fields_map=None):
        """Build chained selectinload() options for every recorded eager path,
        optionally narrowed to specific columns via load_only()."""
        loads = []
        for path in self.eager_paths:
            current = root_alias
            loader = None
            for idx, name in enumerate(path):
                rel_attr = getattr(current, name)
                loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)
                # step to target class for the next hop
                target_cls = rel_attr.property.mapper.class_
                current = target_cls
                # if final hop and we have a fields map, narrow columns
                if fields_map and idx == len(path) - 1 and path in fields_map:
                    cols = []
                    for n in fields_map[path]:
                        attr = getattr(target_cls, n, None)
                        # Only include real column attributes; skip hybrids/expressions
                        if isinstance(attr, InstrumentedAttribute):
                            cols.append(attr)
                    # Only apply load_only if we have at least one real column
                    if cols:
                        loader = loader.load_only(*cols)
            if loader is not None:
                loads.append(loader)
        return loads

    def get_join_paths(self):
        """Return the recorded (parent_alias, relationship_attr, target_alias) join triples."""
        return self.join_paths
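

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the spec class). The Author and
# Book models, the params dict, and the assembly order below are assumptions
# made purely for demonstration; they show one way the parsed pieces might be
# composed into a SELECT. CollPred results are filtered out here because the
# class intentionally leaves collection predicates to the caller. Assumes
# SQLAlchemy 2.0-style declarative mapping.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sqlalchemy import ForeignKey, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

    class Base(DeclarativeBase):
        pass

    class Author(Base):
        __tablename__ = "author"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(100))
        books: Mapped[list["Book"]] = relationship(back_populates="author")

    class Book(Base):
        __tablename__ = "book"
        id: Mapped[int] = mapped_column(primary_key=True)
        title: Mapped[str] = mapped_column(String(200))
        author_id: Mapped[int] = mapped_column(ForeignKey("author.id"))
        author: Mapped["Author"] = relationship(back_populates="books")

    # Hypothetical query-string parameters, already parsed into a dict.
    params = {
        "title__icontains": "sql",
        "author.name": "Ada",
        "sort": "-title",
        "limit": "10",
        "include": "author",
        "fields": "title,author.name",
    }

    root = aliased(Book)
    spec = CRUDSpec(Book, params, root)

    spec.parse_includes()
    # CollPred entries are skipped here; a real caller would translate them
    # separately, e.g. into relationship .any() / EXISTS clauses.
    filters = [f for f in spec.parse_filters() if not isinstance(f, CollPred)]
    order_by = spec.parse_sort()
    limit, offset = spec.parse_pagination()
    _root_fields, rel_fields, _root_names, _coll_fields = spec.parse_fields()

    stmt = select(root)
    for _parent, rel_attr, alias in spec.get_join_paths():
        stmt = stmt.join(alias, rel_attr)
    if filters:
        stmt = stmt.where(and_(*filters))
    stmt = stmt.options(*spec.get_eager_loads(root, fields_map=rel_fields))
    stmt = stmt.order_by(*order_by).limit(limit).offset(offset)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        print(session.scalars(stmt).all())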