Makign corrections on field selection.

This commit is contained in:
Yaro Kasear 2025-09-09 16:27:41 -05:00
parent 7e915423a3
commit 9d36d600bb
9 changed files with 1504 additions and 39 deletions

View file

@ -1,6 +1,6 @@
from typing import List, Tuple, Set, Dict
from typing import List, Tuple, Set, Dict, Optional
from sqlalchemy import asc, desc
from sqlalchemy.orm import joinedload, aliased
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
OPERATORS = {
@ -21,6 +21,9 @@ class CRUDSpec:
self.eager_paths: Set[Tuple[str, ...]] = set()
self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
self.alias_map: Dict[Tuple[str, ...], object] = {}
self._root_fields: List[InstrumentedAttribute] = []
self._rel_field_names: Dict[Tuple[str, ...], object] = {}
self.include_paths: Set[Tuple[str, ...]] = set()
def _resolve_column(self, path: str):
current_alias = self.root_alias
@ -50,6 +53,20 @@ class CRUDSpec:
return None, None
def parse_includes(self):
raw = self.params.get('include')
if not raw:
return
tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
for token in tokens:
_, join_path = self._resolve_column(token)
if join_path:
self.eager_paths.add(join_path)
else:
col, maybe = self._resolve_column(token + '.id')
if maybe:
self.eager_paths.add(maybe)
def parse_filters(self):
filters = []
for key, value in self.params.items():
@ -94,14 +111,60 @@ class CRUDSpec:
offset = int(self.params.get('offset', 0))
return limit, offset
def get_eager_loads(self, root_alias):
def parse_fields(self):
"""
Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
- Root fields become InstrumentedAttributes bound to root_alias.
- Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
Returns (root_fields, rel_field_names).
"""
raw = self.params.get('fields')
if not raw:
return [], {}
if isinstance(raw, list):
tokens = []
for chunk in raw:
tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip())
else:
tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
root_fields: List[InstrumentedAttribute] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
for token in tokens:
col, join_path = self._resolve_column(token)
if not col:
continue
if join_path:
rel_field_names.setdefault(join_path, []).append(col.key)
self.eager_paths.add(join_path)
else:
root_fields.append(col)
seen = set()
root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))]
for k, names in rel_field_names.items():
seen2 = set()
rel_field_names[k] = [n for n in names if not (n in seen2 or seen2.add(n))]
self._root_fields = root_fields
self._rel_field_names = rel_field_names
return root_fields, rel_field_names
def get_eager_loads(self, root_alias, *, fields_map: Optional[Dict[Tuple[str, ...], List[str]]] = None):
loads = []
for path in self.eager_paths:
current = root_alias
loader = None
for name in path:
for idx, name in enumerate(path):
rel_attr = getattr(current, name)
loader = (joinedload(rel_attr) if loader is None else loader.joinedload(name))
loader = (selectinload(rel_attr) if loader is None else loader.selectinload(name))
if fields_map and idx == len(path) - 1 and path in fields_map:
target_cls = rel_attr.property.mapper.class_
cols = [getattr(target_cls, n) for n in fields_map[path] if hasattr(target_cls, n)]
if cols:
loader = loader.load_only(*cols)
current = rel_attr.property.mapper.class_
if loader is not None:
loads.append(loader)