diff --git a/crudkit/core/base.py b/crudkit/core/base.py index dae5615..c66aaf3 100644 --- a/crudkit/core/base.py +++ b/crudkit/core/base.py @@ -9,13 +9,26 @@ class CRUDMixin: created_at = Column(DateTime, default=func.now(), nullable=False) updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now()) - def as_dict(self): - # Combine all columns from all inherited tables + def as_dict(self, fields: list[str] | None = None): + """ + Serialize mapped columns. Honors projection if either: + - 'fields' is passed explicitly, or + - + """ + allowed = None + if fields: + allowed = set(fields) + else: + allowed = getattr(self, "__crudkit_root_fields__", None) result = {} for cls in self.__class__.__mro__: - if hasattr(cls, "__table__"): - for column in cls.__table__.columns: - result[column.name] = getattr(self, column.name) + if not hasattr(cls, "__table__"): + continue + for column in cls.__table__.columns: + name = column.name + if allowed is not None and name not in allowed and name != "id": + continue + result[name] = getattr(self, name) return result class Version(Base): diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 12e5f4e..ab1be78 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,5 +1,5 @@ from typing import Type, TypeVar, Generic, Optional -from sqlalchemy.orm import Session, with_polymorphic +from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic from sqlalchemy import inspect, text from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec @@ -28,11 +28,9 @@ class CRUDService(Generic[T]): def get_query(self): if self.polymorphic: - poly_model = with_polymorphic(self.model, '*') - return self.session.query(poly_model), poly_model - else: - base_only = with_polymorphic(self.model, [], flat=True) - return self.session.query(base_only), base_only + poly = with_polymorphic(self.model, "*") + return self.session.query(poly), poly + return self.session.query(self.model), self.model # Helper: default ORDER BY for MSSQL when paginating without explicit order def _default_order_by(self, root_alias): @@ -62,10 +60,11 @@ class CRUDService(Generic[T]): if not include_deleted: query = query.filter(getattr(root_alias, "is_deleted") == False) - spec = CRUDSpec(self.model, params, root_alias) + spec = CRUDSpec(self.model, params or {}, root_alias) filters = spec.parse_filters() order_by = spec.parse_sort() limit, offset = spec.parse_pagination() + spec.parse_includes() for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): query = query.join( @@ -74,9 +73,17 @@ class CRUDService(Generic[T]): isouter=True ) - for eager in spec.get_eager_loads(root_alias): + root_fields, rel_field_names = spec.parse_fields() + + if root_fields: + query = query.options(Load(root_alias).load_only(*root_fields)) + + for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names): query = query.options(eager) + # if root_fields or rel_field_names: + # query = query.options(Load(root_alias).raiseload("*")) + if filters: query = query.filter(*filters) @@ -91,10 +98,27 @@ class CRUDService(Generic[T]): # Only apply offset/limit when not None. if offset is not None and offset != 0: query = query.offset(offset) - if limit is not None: + if limit is not None and limit > 0: query = query.limit(limit) - return query.all() + # return query.all() + rows = query.all() + + try: + rf_names = [c.key for c in (root_fields or [])] + except NameError: + rf_names = [] + if rf_names: + allow = set(rf_names) + if "id" not in allow and hasattr(self.model, "id"): + allow.add("id") + for obj in rows: + try: + setattr(obj, "__crudkit_root_fields__", allow) + except Exception: + pass + + return rows def create(self, data: dict, actor=None) -> T: obj = self.model(**data) diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 26348ea..f895e6e 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,6 +1,6 @@ -from typing import List, Tuple, Set, Dict +from typing import List, Tuple, Set, Dict, Optional from sqlalchemy import asc, desc -from sqlalchemy.orm import joinedload, aliased +from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute OPERATORS = { @@ -21,6 +21,9 @@ class CRUDSpec: self.eager_paths: Set[Tuple[str, ...]] = set() self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = [] self.alias_map: Dict[Tuple[str, ...], object] = {} + self._root_fields: List[InstrumentedAttribute] = [] + self._rel_field_names: Dict[Tuple[str, ...], object] = {} + self.include_paths: Set[Tuple[str, ...]] = set() def _resolve_column(self, path: str): current_alias = self.root_alias @@ -50,6 +53,20 @@ class CRUDSpec: return None, None + def parse_includes(self): + raw = self.params.get('include') + if not raw: + return + tokens = [p.strip() for p in str(raw).split(',') if p.strip()] + for token in tokens: + _, join_path = self._resolve_column(token) + if join_path: + self.eager_paths.add(join_path) + else: + col, maybe = self._resolve_column(token + '.id') + if maybe: + self.eager_paths.add(maybe) + def parse_filters(self): filters = [] for key, value in self.params.items(): @@ -94,14 +111,60 @@ class CRUDSpec: offset = int(self.params.get('offset', 0)) return limit, offset - def get_eager_loads(self, root_alias): + def parse_fields(self): + """ + Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD + - Root fields become InstrumentedAttributes bound to root_alias. + - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options. + Returns (root_fields, rel_field_names). + """ + raw = self.params.get('fields') + if not raw: + return [], {} + + if isinstance(raw, list): + tokens = [] + for chunk in raw: + tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip()) + else: + tokens = [p.strip() for p in str(raw).split(',') if p.strip()] + + root_fields: List[InstrumentedAttribute] = [] + rel_field_names: Dict[Tuple[str, ...], List[str]] = {} + + for token in tokens: + col, join_path = self._resolve_column(token) + if not col: + continue + if join_path: + rel_field_names.setdefault(join_path, []).append(col.key) + self.eager_paths.add(join_path) + else: + root_fields.append(col) + + seen = set() + root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))] + for k, names in rel_field_names.items(): + seen2 = set() + rel_field_names[k] = [n for n in names if not (n in seen2 or seen2.add(n))] + + self._root_fields = root_fields + self._rel_field_names = rel_field_names + return root_fields, rel_field_names + + def get_eager_loads(self, root_alias, *, fields_map: Optional[Dict[Tuple[str, ...], List[str]]] = None): loads = [] for path in self.eager_paths: current = root_alias loader = None - for name in path: + for idx, name in enumerate(path): rel_attr = getattr(current, name) - loader = (joinedload(rel_attr) if loader is None else loader.joinedload(name)) + loader = (selectinload(rel_attr) if loader is None else loader.selectinload(name)) + if fields_map and idx == len(path) - 1 and path in fields_map: + target_cls = rel_attr.property.mapper.class_ + cols = [getattr(target_cls, n) for n in fields_map[path] if hasattr(target_cls, n)] + if cols: + loader = loader.load_only(*cols) current = rel_attr.property.mapper.class_ if loader is not None: loads.append(loader)