diff --git a/crudkit/core/service.py b/crudkit/core/service.py index b9d10ac..e82bec1 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -20,6 +20,12 @@ class CRUDService(Generic[T]): order_by = spec.parse_sort() limit, offset = spec.parse_pagination() + for parent, relationship_attr, alias in spec.get_join_paths(): + query = query.join(alias, relationship_attr.of_type(alias), isouter=True) + + for eager in spec.get_eager_loads(): + query = query.options(eager) + if filters: query = query.filter(*filters) if order_by: diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index a1f8292..5840071 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,5 +1,7 @@ -from typing import List, Tuple -from sqlalchemy import asc, desc, or_, and_ +from typing import List, Tuple, Set, Dict +from sqlalchemy import asc, desc +from sqlalchemy.orm import joinedload, aliased +from sqlalchemy.orm.attributes import InstrumentedAttribute OPERATORS = { 'eq': lambda col, val: col == val, @@ -15,6 +17,34 @@ class CRUDSpec: def __init__(self, model, params): self.model = model self.params = params + self.eager_paths: Set[Tuple[str, ...]] = set() + self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] + self.alias_map: Dict[Tuple[str, ...], object] = {} + + def _resolve_column(self, path: str): + current_model = self.model + current_alias = self.model + parts = path.split('.') + join_path = [] + + for i, attr in enumerate(parts): + if not hasattr(current_model, attr): + return None, None + attr_obj = getattr(current_model, attr) + if isinstance(attr_obj, InstrumentedAttribute): + if hasattr(attr_obj.property, 'direction'): + join_path.append(attr) + path_key = tuple(join_path) + alias = self.alias_map.get(path_key) + if not alias: + alias = aliased(attr_obj.property.mapper.class_) + self.alias_map[path_key] = alias + self.join_paths.append((current_alias, attr_obj, alias)) + current_model = attr_obj.property.mapper.class_ + current_alias = alias + else: + return getattr(current_alias, attr), tuple(join_path) if join_path else None + return None, None def parse_filters(self): filters = [] @@ -22,12 +52,17 @@ class CRUDSpec: if key in ('sort', 'limit', 'offset'): continue if '__' in key: - field, op = key.split('__', 1) + path_op = key.rsplit('__', 1) + if len(path_op) != 2: + continue + path, op = path_op else: - field, op = key, 'eq' - if hasattr(self.model, field): - col = getattr(self.model, field) + path, op = key, 'eq' + col, join_path = self._resolve_column(path) + if col and op in OPERATORS: filters.append(OPERATORS[op](col, value)) + if join_path: + self.eager_paths.add(join_path) return filters def parse_sort(self): @@ -43,11 +78,33 @@ class CRUDSpec: else: field = part order = asc - if hasattr(self.model, field): - result.append(order(getattr(self.model, field))) + col, join_path = self._resolve_column(field) + if col: + result.append(order(col)) + if join_path: + self.eager_paths.add(join_path) return result def parse_pagination(self): limit = int(self.params.get('limit', 100)) offset = int(self.params.get('offset', 0)) return limit, offset + + def get_eager_loads(self): + loads = [] + for path in self.eager_paths: + current = self.model + loader = None + for attr in path: + attr_obj = getattr(current, attr) + if loader is None: + loader = joinedload(attr_obj) + else: + loader = loader.joinedload(attr_obj) + current = attr_obj.property.mapper.class_ + if loader: + loads.append(loader) + return loads + + def get_join_paths(self): + return self.join_paths