from dataclasses import dataclass
|
|
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
|
|
from sqlalchemy import and_, asc, desc, or_
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.orm import aliased, selectinload
|
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
|
|
@dataclass(frozen=True)
class CollPred:
    """Deferred predicate against a column of a collection (uselist=True)
    relationship.

    Built by CRUDSpec._build_predicate_for when the filter path ends in a
    collection; carried as a marker instead of a SQLAlchemy clause so the
    caller can resolve it (e.g. via EXISTS/subquery) later.
    """
    # Target table (the collection class's __table__) the predicate applies to.
    table: Any
    # Column key on that table to compare against.
    col_key: str
    # Operator name; expected to be a key of OPERATORS.
    op: str
    # Comparison value passed to the operator.
    value: Any
def _listify(val):
    """Wrap a scalar in a one-element list; pass list/tuple/set through."""
    return val if isinstance(val, (list, tuple, set)) else [val]


def _op_eq(col, val):
    return col == val


def _op_lt(col, val):
    return col < val


def _op_lte(col, val):
    return col <= val


def _op_gt(col, val):
    return col > val


def _op_gte(col, val):
    return col >= val


def _op_ne(col, val):
    return col != val


def _op_icontains(col, val):
    # Case-insensitive substring match via ILIKE.
    # NOTE(review): '%'/'_' inside val are not escaped — presumably
    # intentional; confirm if literal matching is required.
    return col.ilike(f"%{val}%")


def _op_in(col, val):
    return col.in_(_listify(val))


def _op_nin(col, val):
    return ~col.in_(_listify(val))


# Maps query-string operator names to SQLAlchemy expression builders.
OPERATORS = {
    'eq': _op_eq,
    'lt': _op_lt,
    'lte': _op_lte,
    'gt': _op_gt,
    'gte': _op_gte,
    'ne': _op_ne,
    'icontains': _op_icontains,
    'in': _op_in,
    'nin': _op_nin,
}
class CRUDSpec:
    """Translates REST-style query parameters into SQLAlchemy query parts.

    Recognized parameters:
      - filter keys   -- ``path__op`` (``op`` defaults to ``eq``); ``path``
        may be dotted to traverse relationships and pipe-delimited
        (``a|b.c``) to OR the same predicate across several columns.
        ``$or`` / ``$and`` keys hold lists of nested filter dicts.
      - ``sort``      -- comma-separated fields, ``-`` prefix = descending.
      - ``limit`` / ``offset`` -- pagination.
      - ``fields``    -- comma-separated columns to load (dotted for
        related columns).
      - ``include``   -- comma-separated relationship paths to eager-load.

    The instance accumulates join/eager-load bookkeeping as the ``parse_*``
    methods run; call ``get_join_paths`` / ``get_eager_loads`` afterwards
    to apply it to a query.
    """

    def __init__(self, model, params, root_alias):
        self.model = model
        self.params = params
        self.root_alias = root_alias
        # Relationship paths (tuples of attribute names) to eager-load.
        self.eager_paths: Set[Tuple[str, ...]] = set()
        # (parent_alias, relationship_attr, alias_for_target) triples, in
        # the order the joins must be applied.
        self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
        # join path -> alias, so the same path reuses a single alias.
        self.alias_map: Dict[Tuple[str, ...], object] = {}
        self._root_fields: List[InstrumentedAttribute] = []
        # dotted non-collection fields (MANYTOONE etc)
        self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        # dotted collection fields (ONETOMANY)
        self._collection_field_names: Dict[str, List[str]] = {}
        self.include_paths: Set[Tuple[str, ...]] = set()

    def _split_path_and_op(self, key: str) -> tuple[str, str]:
        """Split ``'path__op'`` into ``(path, op)``; bare keys get ``'eq'``."""
        if '__' in key:
            path, op = key.rsplit('__', 1)
        else:
            path, op = key, 'eq'
        return path, op

    def _resolve_many_columns(self, path: str) -> list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]]:
        """
        Accepts pipe-delimited paths like 'label|owner.label'.
        Returns a list of (column, join_path) pairs for every resolvable
        subpath; unresolvable subpaths are silently skipped.
        """
        cols: list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]] = []
        for sub in path.split('|'):
            sub = sub.strip()
            if not sub:
                continue
            col, join_path = self._resolve_column(sub)
            if col is not None:
                cols.append((col, join_path))
        return cols

    def _walk_relationship_path(self, join_path: Iterable[str]):
        """Walk ``join_path`` relationship-by-relationship from ``self.model``.

        Returns ``(target_cls, is_collection)`` where ``is_collection``
        reflects the ``uselist`` flag of the LAST hop. This fixes the old
        behavior, which looked the last name up on the root model and so
        misclassified multi-hop paths like 'owner.tags'.
        Returns ``(None, False)`` if any hop fails to resolve.
        """
        try:
            cur_cls = self.model
            is_collection = False
            for name in join_path:
                prop = getattr(cur_cls, name).property
                is_collection = bool(getattr(prop, "uselist", False))
                cur_cls = prop.mapper.class_
            return cur_cls, is_collection
        except Exception:
            return None, False

    def _build_predicate_for(self, path: str, op: str, value: Any):
        """
        Builds a SQLA BooleanClauseList or BinaryExpression for a single key.
        If multiple subpaths are provided via pipe, returns an OR of them.
        Predicates on collection (uselist) relationships are returned as
        ``CollPred`` markers for the caller to resolve.
        Returns None for unknown operators or unresolvable paths.
        """
        if op not in OPERATORS:
            return None

        pairs = self._resolve_many_columns(path)
        if not pairs:
            return None

        exprs = []
        for col, join_path in pairs:
            if join_path:
                self.eager_paths.add(join_path)
                # BUGFIX: classify by the LAST hop's uselist flag (see
                # _walk_relationship_path) instead of probing the root model.
                target_cls, is_collection = self._walk_relationship_path(join_path)
                if is_collection and target_cls is not None:
                    key = getattr(col, "key", None) or getattr(col, "name", None)
                    if key and hasattr(target_cls, key):
                        target_tbl = getattr(target_cls, "__table__", None)
                        if target_tbl is not None:
                            exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
                            continue

            exprs.append(OPERATORS[op](col, value))

        if not exprs:
            return None

        # CollPreds cannot be OR-ed with SQLA clauses; keep a single one.
        if any(isinstance(x, CollPred) for x in exprs):
            # A pipe like 'relA.col|relB.col' may have produced multiple
            # CollPreds; OR across collection paths is unsupported, so the
            # first CollPred wins.
            if len(exprs) > 1:
                exprs = [next(x for x in exprs if isinstance(x, CollPred))]
            return exprs[0]

        # Otherwise, standard SQLA clause(s)
        return exprs[0] if len(exprs) == 1 else or_(*exprs)

    def _collect_filters(self, params: dict) -> list:
        """
        Recursively parse filters from ``params`` into a flat list of SQLA
        expressions. Supports $or / $and groups; any other keys are parsed
        as normal filters. Reserved paging/shaping keys are skipped.
        """
        filters: list = []

        for key, value in (params or {}).items():
            if key in ('sort', 'limit', 'offset', 'fields', 'include'):
                continue

            if key == '$or':
                # value should be a list of dicts
                groups = []
                for group in value if isinstance(value, (list, tuple)) else []:
                    sub = self._collect_filters(group)
                    if not sub:
                        continue
                    groups.append(and_(*sub) if len(sub) > 1 else sub[0])
                if groups:
                    filters.append(or_(*groups))
                continue

            if key == '$and':
                # value should be a list of dicts
                parts = []
                for group in value if isinstance(value, (list, tuple)) else []:
                    sub = self._collect_filters(group)
                    if not sub:
                        continue
                    parts.append(and_(*sub) if len(sub) > 1 else sub[0])
                if parts:
                    filters.append(and_(*parts))
                continue

            # Normal key
            path, op = self._split_path_and_op(key)
            pred = self._build_predicate_for(path, op, value)
            if pred is not None:
                filters.append(pred)

        return filters

    def _resolve_column(self, path: str):
        """Resolve a dotted path into (column_attr, join_path_tuple_or_None).

        Relationship hops create (and cache in alias_map) aliases and record
        the joins needed; returns (None, None) when any segment fails.
        """
        current_alias = self.root_alias
        parts = path.split('.')
        join_path: list[str] = []

        for attr in parts:
            try:
                attr_obj = getattr(current_alias, attr)
            except AttributeError:
                return None, None

            prop = getattr(attr_obj, "property", None)
            if prop is not None and hasattr(prop, "direction"):
                # Relationship hop: alias the target class and record the join.
                join_path.append(attr)
                path_key = tuple(join_path)
                alias = self.alias_map.get(path_key)
                if alias is None:
                    alias = aliased(prop.mapper.class_)
                    self.alias_map[path_key] = alias
                    self.join_paths.append((current_alias, attr_obj, alias))
                current_alias = alias
                continue

            # Terminal attribute: plain column, hybrid (has comparator), or
            # composite (has clauses).
            if isinstance(attr_obj, InstrumentedAttribute) or getattr(attr_obj, "comparator", None) is not None or hasattr(attr_obj, "clauses"):
                return attr_obj, tuple(join_path) if join_path else None

        return None, None

    def parse_includes(self):
        """Record ?include=rel1,rel2.rel3 paths into ``eager_paths``."""
        raw = self.params.get('include')
        if not raw:
            return
        tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
        for token in tokens:
            _, join_path = self._resolve_column(token)
            if join_path:
                self.eager_paths.add(join_path)
            else:
                # A bare relationship name resolves to no column; probe its
                # 'id' column so the relationship hop gets registered.
                _, maybe = self._resolve_column(token + '.id')
                if maybe:
                    self.eager_paths.add(maybe)

    def parse_filters(self, params: dict | None = None):
        """
        Public entry: parse filters from given params or self.params.
        Returns a list of SQLAlchemy filter expressions.
        """
        return self._collect_filters(params if params is not None else self.params)

    def parse_sort(self):
        """Parse ?sort=a,-b into [asc(colA), desc(colB)] order_by clauses."""
        # Guard against sort=None so .split never raises.
        sort_args = str(self.params.get('sort') or '')
        result = []
        for part in sort_args.split(','):
            part = part.strip()
            if not part:
                continue
            if part.startswith('-'):
                field, order = part[1:], desc
            else:
                field, order = part, asc
            col, join_path = self._resolve_column(field)
            # BUGFIX: explicit None check -- truth-testing SQLA attributes
            # is unreliable (expression objects may raise on bool()).
            if col is not None:
                result.append(order(col))
                if join_path:
                    self.eager_paths.add(join_path)
        return result

    def parse_pagination(self):
        """Return (limit, offset) as ints, defaulting to (100, 0).

        Raises ValueError/TypeError if the params are not int-coercible.
        """
        limit = int(self.params.get('limit', 100))
        offset = int(self.params.get('offset', 0))
        return limit, offset

    @staticmethod
    def _dedupe_by(items, key=None):
        """Order-preserving de-dup; ``key`` extracts the identity (default:
        the item itself)."""
        seen = set()
        out = []
        for item in items:
            k = item if key is None else key(item)
            if k not in seen:
                seen.add(k)
                out.append(item)
        return out

    def parse_fields(self):
        """
        Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
        - Root fields become InstrumentedAttributes bound to root_alias.
        - Related fields store attribute NAMES; we'll resolve them on the
          target class when building loader options.
        - Collection (uselist=True) relationships record child names by
          relationship key.
        Returns (root_fields, rel_field_names, root_field_names,
        collection_field_names_by_rel).
        """
        raw = self.params.get('fields')
        if not raw:
            return [], {}, {}, {}

        if isinstance(raw, list):
            tokens = []
            for chunk in raw:
                tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip())
        else:
            tokens = [p.strip() for p in str(raw).split(',') if p.strip()]

        root_fields: List[InstrumentedAttribute] = []
        root_field_names: list[str] = []
        rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        collection_field_names: Dict[str, List[str]] = {}

        for token in tokens:
            col, join_path = self._resolve_column(token)
            if col is None:
                continue
            if join_path:
                target_cls, is_collection = self._walk_relationship_path(join_path)
                if target_cls is None:
                    # Fallback: inspect the InstrumentedAttribute recorded on
                    # join_paths for this relationship.
                    is_collection = False
                    for _parent, rel_attr, _alias in self.join_paths:
                        if rel_attr.key == join_path[-1]:
                            is_collection = bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
                            break

                if is_collection:
                    collection_field_names.setdefault(join_path[-1], []).append(col.key)
                else:
                    rel_field_names.setdefault(join_path, []).append(col.key)
                    self.eager_paths.add(join_path)
            else:
                root_fields.append(col)
                root_field_names.append(getattr(col, "key", token))

        # De-duplicate everything while preserving first-seen order.
        root_fields = self._dedupe_by(root_fields, key=lambda c: c.key)
        rel_field_names = {k: self._dedupe_by(v) for k, v in rel_field_names.items()}
        collection_field_names = {k: self._dedupe_by(v) for k, v in collection_field_names.items()}

        self._root_fields = root_fields
        self._rel_field_names = rel_field_names
        # Previously built but never stored despite the attribute existing.
        self._collection_field_names = collection_field_names

        # BUGFIX: first element is the InstrumentedAttribute list, matching
        # the docstring (root_field_names used to be returned twice).
        return root_fields, rel_field_names, root_field_names, collection_field_names

    def get_eager_loads(self, root_alias, *, fields_map=None):
        """Build selectinload() chains for every recorded eager path.

        ``fields_map`` (path tuple -> attribute names) optionally narrows
        the columns loaded on the final hop via load_only().
        """
        loads = []
        for path in self.eager_paths:
            current = root_alias
            loader = None
            for idx, name in enumerate(path):
                rel_attr = getattr(current, name)
                loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)

                # step to target class for the next hop
                target_cls = rel_attr.property.mapper.class_
                current = target_cls

                # if final hop and we have a fields map, narrow columns
                if fields_map and idx == len(path) - 1 and path in fields_map:
                    cols = []
                    for n in fields_map[path]:
                        attr = getattr(target_cls, n, None)
                        # Only include real column attributes; skip hybrids/expressions
                        if isinstance(attr, InstrumentedAttribute):
                            cols.append(attr)
                    # Only apply load_only if we have at least one real column
                    if cols:
                        loader = loader.load_only(*cols)

            if loader is not None:
                loads.append(loader)
        return loads

    def get_join_paths(self):
        """Return the recorded (parent_alias, relationship_attr, alias) joins."""
        return self.join_paths