# Source: crudkit/crudkit/core/spec.py
# Snapshot: 2025-10-10 09:23:45 -05:00 (340 lines, 14 KiB, Python)

from dataclasses import dataclass
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
@dataclass(eq=True, frozen=True)
class CollPred:
    """Filter on a column of a collection relationship's target table.

    ``CRUDSpec._build_predicate_for`` returns this instead of a ready SQL
    clause when a filter path goes through a collection (``uselist=True``)
    relationship. Turning it into SQL is left to code outside this module
    (presumably an EXISTS-style subquery — TODO confirm with the consumer).

    Instances are immutable and compare equal field-by-field.
    """

    table: Any  # the target mapped class's ``__table__``
    col_key: str  # attribute key of the filtered column on the target class
    op: str  # operator name, a key of ``OPERATORS`` (e.g. 'eq', 'icontains')
    value: Any  # raw filter value taken from the request params
# Escape character for LIKE patterns built below. '/' mirrors SQLAlchemy's own
# autoescape choice and avoids backslash-in-literal differences between backends.
_LIKE_ESCAPE = "/"


def _like_literal(val: Any) -> str:
    """Return ``str(val)`` with LIKE wildcards escaped so they match literally."""
    text = str(val)
    # Escape the escape character first so the later replacements are not doubled.
    return (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _as_collection(val: Any):
    """Return ``val`` unchanged if it is a list/tuple/set/frozenset, else ``[val]``."""
    return val if isinstance(val, (list, tuple, set, frozenset)) else [val]


# Filter operator name (the ``__op`` suffix of a filter key) -> builder that
# takes (column, value) and returns a SQLAlchemy clause.
OPERATORS = {
    'eq': lambda col, val: col == val,
    'lt': lambda col, val: col < val,
    'lte': lambda col, val: col <= val,
    'gt': lambda col, val: col > val,
    'gte': lambda col, val: col >= val,
    'ne': lambda col, val: col != val,
    # Case-insensitive substring match; '%' and '_' in the value are literal text,
    # not wildcards.
    'icontains': lambda col, val: col.ilike(f"%{_like_literal(val)}%", escape=_LIKE_ESCAPE),
    # A scalar value is wrapped into a one-element list.
    'in': lambda col, val: col.in_(_as_collection(val)),
    'nin': lambda col, val: ~col.in_(_as_collection(val)),
}
class CRUDSpec:
    """Translate request query parameters into SQLAlchemy query building blocks.

    Parameter keys understood by the methods below:

    * filters: ``<path>[__<op>]=<value>``. ``<path>`` is a dotted attribute
      path (``owner.name``); several paths joined with ``|`` are OR-ed.
      ``<op>`` is a key of ``OPERATORS`` (default ``eq``). ``$or`` / ``$and``
      keys hold lists of nested filter mappings.
    * ``sort``: comma-separated paths, ``-`` prefix for descending.
    * ``limit`` / ``offset``: integer pagination.
    * ``fields``: comma-separated (or a list of) paths to project.
    * ``include``: comma-separated relationship paths to eager-load.

    Resolving a dotted path gives each relationship prefix one ``aliased()``
    target, cached in ``alias_map`` and recorded once in ``join_paths``.
    Relationship paths that should be eager-loaded accumulate in
    ``eager_paths``.
    """

    # Parameter keys that are never treated as filters.
    _RESERVED_KEYS = ('sort', 'limit', 'offset', 'fields', 'include')

    def __init__(self, model, params, root_alias):
        # model: the mapped class. params: mapping of query parameters (anything
        # with .get/.items). root_alias: the entity root-level paths resolve
        # against — presumably the model itself or aliased(model); TODO confirm.
        self.model = model
        self.params = params
        self.root_alias = root_alias
        # Relationship-name paths that get_eager_loads() turns into selectinload chains.
        self.eager_paths: Set[Tuple[str, ...]] = set()
        # (parent_alias, relationship_attr, alias_for_target), one per relationship path.
        self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
        # Relationship-name path -> aliased() target used for that path.
        self.alias_map: Dict[Tuple[str, ...], object] = {}
        # Root-level attributes selected via ?fields= (set by parse_fields).
        self._root_fields: List[InstrumentedAttribute] = []
        # Dotted non-collection fields (MANYTOONE etc.), keyed by relationship path.
        self._rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        # Dotted collection fields (ONETOMANY), keyed by the final relationship name.
        self._collection_field_names: Dict[str, List[str]] = {}
        # Declared here but not populated anywhere in this module.
        self.include_paths: Set[Tuple[str, ...]] = set()

    @staticmethod
    def _split_csv(raw) -> list[str]:
        """Split a comma-separated value, or a list/tuple of them, into stripped, non-empty tokens."""
        if not raw:
            return []
        chunks = raw if isinstance(raw, (list, tuple)) else [raw]
        return [piece.strip() for chunk in chunks for piece in str(chunk).split(',') if piece.strip()]

    def _split_path_and_op(self, key: str) -> tuple[str, str]:
        """Split ``'path__op'`` into ``('path', 'op')``; a key without ``__`` means ``eq``.

        Only the last ``__`` separates the operator, so ``a__b__lt`` yields
        ``('a__b', 'lt')``.
        """
        if '__' in key:
            path, op = key.rsplit('__', 1)
        else:
            path, op = key, 'eq'
        return path, op

    def _resolve_many_columns(self, path: str) -> list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]]:
        """Resolve every non-empty ``|``-separated subpath of ``path``.

        For example ``'label|owner.label'``. Returns the ``(column, join_path)``
        pair of each subpath that resolves; the others are skipped.
        """
        cols: list[tuple[InstrumentedAttribute, Optional[tuple[str, ...]]]] = []
        for sub in path.split('|'):
            sub = sub.strip()
            if not sub:
                continue
            col, join_path = self._resolve_column(sub)
            if col is not None:
                cols.append((col, join_path))
        return cols

    def _recorded_uselist(self, join_path) -> bool:
        """Return the ``uselist`` flag of the join ``_resolve_column`` recorded for exactly ``join_path``."""
        alias = self.alias_map.get(tuple(join_path))
        if alias is None:
            return False
        for _parent, rel_attr, target in self.join_paths:
            if target is alias:
                return bool(getattr(getattr(rel_attr, "property", None), "uselist", False))
        return False

    def _relationship_target(self, join_path) -> tuple[Any, bool]:
        """Walk ``join_path`` from ``self.model`` and describe its final relationship.

        Returns ``(target_class, is_collection)``. ``is_collection`` is the
        ``uselist`` flag of the *last* relationship, looked up on the class that
        actually owns it. If the walk fails, ``target_class`` is None and
        ``is_collection`` falls back to the join recorded for this exact path.
        """
        cur_cls = self.model
        prop = None
        try:
            for name in join_path:
                prop = getattr(cur_cls, name).property
                cur_cls = prop.mapper.class_
        except Exception:
            # Best effort: an unexpected attribute shape just means "unknown"; it
            # must not abort parsing of the whole request.
            return None, self._recorded_uselist(join_path)
        return cur_cls, bool(getattr(prop, "uselist", False))

    def _collection_predicate(self, col, join_path, op: str, value: Any) -> Optional[CollPred]:
        """Return a CollPred when ``join_path`` is a single collection hop, else None.

        A CollPred carries only the target table, not the relationships leading
        to it. Deeper paths therefore keep filtering through the joined alias
        column. None is also returned when the column key or table cannot be
        found on the target class.
        """
        if len(join_path) != 1:
            return None
        target_cls, is_collection = self._relationship_target(join_path)
        if not is_collection or target_cls is None:
            return None
        key = getattr(col, "key", None) or getattr(col, "name", None)
        if not key or not hasattr(target_cls, key):
            return None
        target_tbl = getattr(target_cls, "__table__", None)
        if target_tbl is None:
            return None
        return CollPred(table=target_tbl, col_key=key, op=op, value=value)

    def _build_predicate_for(self, path: str, op: str, value: Any):
        """Build the predicate for one filter key, or None if it does not apply.

        Returns None for an unknown ``op`` or when no subpath resolves.
        Otherwise returns a SQLAlchemy clause; several ``|`` subpaths are
        combined with OR. A single-hop collection path yields a ``CollPred``.
        Every relationship path used is also added to ``eager_paths``.
        """
        if op not in OPERATORS:
            return None
        pairs = self._resolve_many_columns(path)
        if not pairs:
            return None
        exprs: list = []
        for col, join_path in pairs:
            if join_path:
                self.eager_paths.add(join_path)
                coll_pred = self._collection_predicate(col, join_path, op, value)
                if coll_pred is not None:
                    exprs.append(coll_pred)
                    continue
            exprs.append(OPERATORS[op](col, value))
        coll_preds = [x for x in exprs if isinstance(x, CollPred)]
        if coll_preds:
            # A CollPred cannot be OR-ed with SQL clauses, so only the first one
            # is kept. Any other subpaths of a piped key are dropped; OR across
            # collection paths is not supported yet.
            return coll_preds[0]
        return exprs[0] if len(exprs) == 1 else or_(*exprs)

    def _collect_group_clauses(self, value) -> list:
        """Turn a ``$or``/``$and`` value into one clause per non-empty group.

        ``value`` is a list/tuple of filter mappings; a single mapping is
        treated as a one-element list. Anything else is ignored, as is any
        group that is not a mapping. Each group's own filters are AND-ed.
        """
        if isinstance(value, dict):
            groups = [value]
        elif isinstance(value, (list, tuple)):
            groups = value
        else:
            groups = []
        clauses = []
        for group in groups:
            if not hasattr(group, "items"):
                continue
            sub = self._collect_filters(group)
            if sub:
                # NOTE(review): a CollPred in `sub` is a plain dataclass, not a SQL
                # clause; and_/or_ will most likely reject it. Collection filters
                # inside groups are presumably unsupported — confirm.
                clauses.append(and_(*sub) if len(sub) > 1 else sub[0])
        return clauses

    def _collect_filters(self, params: dict) -> list:
        """Recursively parse filters from ``params`` into a flat list of predicates.

        ``$or`` / ``$and`` keys become grouped clauses; reserved keys (sort,
        pagination, fields, include) are skipped; every other key is parsed as
        a ``path__op`` filter.
        """
        filters: list = []
        for key, value in (params or {}).items():
            if key in self._RESERVED_KEYS:
                continue
            if key == '$or':
                groups = self._collect_group_clauses(value)
                if groups:
                    filters.append(or_(*groups))
                continue
            if key == '$and':
                parts = self._collect_group_clauses(value)
                if parts:
                    filters.append(and_(*parts))
                continue
            path, op = self._split_path_and_op(key)
            pred = self._build_predicate_for(path, op, value)
            if pred is not None:
                filters.append(pred)
        return filters

    def _resolve_column(self, path: str):
        """Resolve a dotted attribute path to ``(column, join_path)``.

        Every segment but the last must be a relationship. The first time a
        relationship prefix is seen, it gets an ``aliased()`` target, cached in
        ``alias_map`` and recorded in ``join_paths``. The last segment must be
        column-like (an instrumented attribute or an expression such as a
        hybrid) on the current alias.

        Returns ``(attr, join_path)``, where ``join_path`` is the tuple of
        relationship names or None for a root-level column. Returns
        ``(None, None)`` when the path does not resolve. Joins registered for
        relationship segments before the failure point are kept.
        """
        current_alias = self.root_alias
        parts = path.split('.')
        last_index = len(parts) - 1
        join_path: list[str] = []
        for index, attr in enumerate(parts):
            try:
                attr_obj = getattr(current_alias, attr)
            except AttributeError:
                return None, None
            prop = getattr(attr_obj, "property", None)
            if prop is not None and hasattr(prop, "direction"):
                join_path.append(attr)
                path_key = tuple(join_path)
                alias = self.alias_map.get(path_key)
                if alias is None:
                    alias = aliased(prop.mapper.class_)
                    self.alias_map[path_key] = alias
                    self.join_paths.append((current_alias, attr_obj, alias))
                current_alias = alias
                continue
            is_column_like = (
                isinstance(attr_obj, InstrumentedAttribute)
                or getattr(attr_obj, "comparator", None) is not None
                or hasattr(attr_obj, "clauses")
            )
            # Anything that is not a relationship must be the final segment, so
            # 'name.extra' and 'not_a_column.label' do not resolve.
            if is_column_like and index == last_index:
                return attr_obj, tuple(join_path) if join_path else None
            return None, None
        # The path ended on a relationship (e.g. 'owner'): no column to return.
        return None, None

    def parse_includes(self):
        """Add the relationship path of each ``include`` token to ``eager_paths``."""
        for token in self._split_csv(self.params.get('include')):
            _, join_path = self._resolve_column(token)
            if join_path:
                self.eager_paths.add(join_path)
                continue
            # A bare relationship path ('owner') resolves to no column, so probe
            # '<path>.id' to discover its join path (assumes the target has `id`).
            _, join_path = self._resolve_column(token + '.id')
            if join_path:
                self.eager_paths.add(join_path)

    def parse_filters(self, params: dict | None = None):
        """Public entry: parse filters from ``params``, or from ``self.params`` when None.

        Returns a list of SQLAlchemy filter expressions and/or ``CollPred``
        objects.
        """
        return self._collect_filters(params if params is not None else self.params)

    def parse_sort(self):
        """Return ``asc``/``desc`` clauses for each resolvable ``sort`` token.

        A leading ``-`` sorts descending; unresolvable tokens are ignored.
        Sorting on a related path adds it to ``eager_paths``.
        """
        result = []
        for part in self._split_csv(self.params.get('sort')):
            descending = part.startswith('-')
            field = part[1:] if descending else part
            col, join_path = self._resolve_column(field)
            # `is None`, not truthiness: SQL expressions refuse bool().
            if col is None:
                continue
            result.append((desc if descending else asc)(col))
            if join_path:
                self.eager_paths.add(join_path)
        return result

    def parse_pagination(self):
        """Return ``(limit, offset)`` from params, defaulting to ``(100, 0)``.

        Values go through ``int()``, so non-numeric input raises ValueError. No
        range check is applied, so negative values pass through.
        """
        limit = int(self.params.get('limit', 100))
        offset = int(self.params.get('offset', 0))
        return limit, offset

    def parse_fields(self):
        """Parse ``?fields=colA,colB,rel1.colC,rel1.rel2.colD`` (string or list of strings).

        - Root fields resolve to attributes bound to ``root_alias``. The
          attributes are stored on ``self._root_fields``; their names are
          returned.
        - Related fields are recorded by attribute NAME under their
          relationship path, to be resolved on the target class when building
          loader options.
        - Fields behind a collection (``uselist=True``) relationship are
          recorded under that relationship's key.

        Unresolvable tokens are ignored and duplicates are dropped. Every
        related path is added to ``eager_paths``.

        Returns ``(root_field_names, rel_field_names, root_field_names,
        collection_field_names_by_rel)``.

        NOTE(review): slots 0 and 2 are the same list of names; the attribute
        objects are only available via ``self._root_fields``. Confirm that
        callers expect names in slot 0.
        """
        raw = self.params.get('fields')
        if not raw:
            return [], {}, [], {}
        root_fields: List[InstrumentedAttribute] = []
        root_field_names: list[str] = []
        rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
        collection_field_names: Dict[str, List[str]] = {}
        for token in self._split_csv(raw):
            col, join_path = self._resolve_column(token)
            # `is None`, not truthiness: SQL expressions (e.g. hybrids) refuse bool().
            if col is None:
                continue
            # The final segment is the attribute name on its own class; this works
            # even for expressions that carry no usable `.key`.
            name = token.rsplit('.', 1)[-1]
            if not join_path:
                if name not in root_field_names:
                    root_fields.append(col)
                    root_field_names.append(name)
                continue
            _target, is_collection = self._relationship_target(join_path)
            if is_collection:
                bucket = collection_field_names.setdefault(join_path[-1], [])
            else:
                bucket = rel_field_names.setdefault(join_path, [])
            if name not in bucket:
                bucket.append(name)
            self.eager_paths.add(join_path)
        self._root_fields = root_fields
        self._rel_field_names = rel_field_names
        self._collection_field_names = collection_field_names
        return root_field_names, rel_field_names, root_field_names, collection_field_names

    def get_eager_loads(self, root_alias, *, fields_map=None):
        """Build one ``selectinload`` chain per path in ``eager_paths``.

        Each chain starts from ``root_alias``. When ``fields_map`` maps a path
        to attribute names, the last hop of that path is narrowed with
        ``load_only``. Only real ``InstrumentedAttribute`` columns are used;
        narrowing is skipped if none remain.
        """
        loads = []
        for path in self.eager_paths:
            current = root_alias
            loader = None
            last_index = len(path) - 1
            for idx, name in enumerate(path):
                rel_attr = getattr(current, name)
                loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)
                # Step to the target class for the next hop.
                target_cls = rel_attr.property.mapper.class_
                current = target_cls
                if fields_map and idx == last_index and path in fields_map:
                    # Hybrids/expressions cannot go into load_only; keep real columns only.
                    cols = [
                        attr
                        for attr in (getattr(target_cls, n, None) for n in fields_map[path])
                        if isinstance(attr, InstrumentedAttribute)
                    ]
                    if cols:
                        loader = loader.load_only(*cols)
            if loader is not None:
                loads.append(loader)
        return loads

    def get_join_paths(self):
        """Return the recorded ``(parent_alias, relationship_attr, target_alias)`` triples (the live list)."""
        return self.join_paths