From 637e873ccfad7ecb7b0711b7c2ef7d36b57816e0 Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Thu, 11 Sep 2025 11:30:07 -0500 Subject: [PATCH] Add params support to get() and improve hybrid support. I also want a taco and will not get one. --- crudkit/api/flask_api.py | 2 +- crudkit/core/service.py | 81 +++++++++++++++++++++++++++++----------- crudkit/core/spec.py | 34 ++++++++++------- 3 files changed, 81 insertions(+), 36 deletions(-) diff --git a/crudkit/api/flask_api.py b/crudkit/api/flask_api.py index ddb77a9..46832c2 100644 --- a/crudkit/api/flask_api.py +++ b/crudkit/api/flask_api.py @@ -10,7 +10,7 @@ def generate_crud_blueprint(model, service): @bp.get('/') def get_item(id): - item = service.get(id) + item = service.get(id, request.args) return jsonify(item.as_dict()) @bp.post('/') diff --git a/crudkit/core/service.py b/crudkit/core/service.py index 07f51ca..2e1c400 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -1,5 +1,6 @@ from typing import Type, TypeVar, Generic, Optional from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic +from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy import inspect, text from crudkit.core.base import Version from crudkit.core.spec import CRUDSpec @@ -43,18 +44,70 @@ class CRUDService(Generic[T]): cols.append(col) return cols or [text("1")] - def get(self, id: int, include_deleted: bool = False) -> T | None: + def get(self, id: int, params=None) -> T | None: + print(f"I AM GETTING A THING! A THINGS! {params}") query, root_alias = self.get_query() + + include_deleted = False + root_fields = [] + root_field_names = {} + rel_field_names = {} + + spec = CRUDSpec(self.model, params or {}, root_alias) + if params: + if self.supports_soft_delete: + include_deleted = _is_truthy(params.get('include_deleted')) if self.supports_soft_delete and not include_deleted: query = query.filter(getattr(root_alias, "is_deleted") == False) query = query.filter(getattr(root_alias, "id") == id) + + spec.parse_includes() + + for parent_alias, relationship_attr, target_alias in spec.get_join_paths(): + query = query.join( + target_alias, + relationship_attr.of_type(target_alias), + isouter=True + ) + + if params: + root_fields, rel_field_names, root_field_names = spec.parse_fields() + + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) + + for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names): + query = query.options(eager) + obj = query.first() + + proj = [] + if root_field_names: + proj.extend(root_field_names) + if root_fields: + proj.extend(c.key for c in root_fields) + for path, names in (rel_field_names or {}).items(): + prefix = ".".join(path) + for n in names: + proj.append(f"{prefix}.{n}") + + if proj and "id" not in proj and hasattr(self.model, "id"): + proj.insert(0, "id") + + if proj: + try: + setattr(obj, "__crudkit_projection__", tuple(proj)) + except Exception: + pass + return obj or None def list(self, params=None) -> list[T]: query, root_alias = self.get_query() root_fields = [] + root_field_names = {} rel_field_names = {} if params: @@ -77,17 +130,15 @@ class CRUDService(Generic[T]): ) if params: - root_fields, rel_field_names = spec.parse_fields() + root_fields, rel_field_names, root_field_names = spec.parse_fields() - if root_fields: - query = query.options(Load(root_alias).load_only(*root_fields)) + only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)] + if only_cols: + query = query.options(Load(root_alias).load_only(*only_cols)) for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names): query = query.options(eager) - # if root_fields or rel_field_names: - # query = query.options(Load(root_alias).raiseload("*")) - if filters: query = query.filter(*filters) @@ -108,6 +159,8 @@ class CRUDService(Generic[T]): rows = query.all() proj = [] + if root_field_names: + proj.extend(root_field_names) if root_fields: proj.extend(c.key for c in root_fields) for path, names in (rel_field_names or {}).items(): @@ -125,20 +178,6 @@ class CRUDService(Generic[T]): except Exception: pass - # try: - # rf_names = [c.key for c in (root_fields or [])] - # except NameError: - # rf_names = [] - # if rf_names: - # allow = set(rf_names) - # if "id" not in allow and hasattr(self.model, "id"): - # allow.add("id") - # for obj in rows: - # try: - # setattr(obj, "__crudkit_root_fields__", allow) - # except Exception: - # pass - return rows def create(self, data: dict, actor=None) -> T: diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 1b6ea31..9c0e53b 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,5 +1,6 @@ from typing import List, Tuple, Set, Dict, Optional from sqlalchemy import asc, desc +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -48,7 +49,7 @@ class CRUDSpec: current_alias = alias continue - if isinstance(attr_obj, InstrumentedAttribute) or hasattr(attr_obj, "clauses"): + if isinstance(attr_obj, InstrumentedAttribute) or getattr(attr_obj, "comparator", None) is not None or hasattr(attr_obj, "clauses"): return attr_obj, tuple(join_path) if join_path else None return None, None @@ -120,7 +121,7 @@ class CRUDSpec: """ raw = self.params.get('fields') if not raw: - return [], {} + return [], {}, {} if isinstance(raw, list): tokens = [] @@ -130,6 +131,7 @@ class CRUDSpec: tokens = [p.strip() for p in str(raw).split(',') if p.strip()] root_fields: List[InstrumentedAttribute] = [] + root_field_names: list[str] = [] rel_field_names: Dict[Tuple[str, ...], List[str]] = {} for token in tokens: @@ -141,6 +143,7 @@ class CRUDSpec: self.eager_paths.add(join_path) else: root_fields.append(col) + root_field_names.append(getattr(col, "key", token)) seen = set() root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))] @@ -150,33 +153,36 @@ class CRUDSpec: self._root_fields = root_fields self._rel_field_names = rel_field_names - return root_fields, rel_field_names + return root_fields, rel_field_names, root_field_names - def get_eager_loads(self, root_alias, *, fields_map: Optional[Dict[Tuple[str, ...], List[str]]] = None): + def get_eager_loads(self, root_alias, *, fields_map=None): loads = [] for path in self.eager_paths: current = root_alias loader = None - for idx, name in enumerate(path): rel_attr = getattr(current, name) + loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr) - if loader is None: - loader = selectinload(rel_attr) - else: - loader = loader.selectinload(rel_attr) - - current = rel_attr.property.mapper.class_ + # step to target class for the next hop + target_cls = rel_attr.property.mapper.class_ + current = target_cls + # if final hop and we have a fields map, narrow columns if fields_map and idx == len(path) - 1 and path in fields_map: - target_cls = current - cols = [getattr(target_cls, n) for n in fields_map[path] if hasattr(target_cls, n)] + cols = [] + for n in fields_map[path]: + attr = getattr(target_cls, n, None) + # Only include real column attributes; skip hybrids/expressions + if isinstance(attr, InstrumentedAttribute): + cols.append(attr) + + # Only apply load_only if we have at least one real column if cols: loader = loader.load_only(*cols) if loader is not None: loads.append(loader) - return loads def get_join_paths(self):