Redesign1 #1

Merged
yaro merged 36 commits from Redesign1 into main 2025-09-22 14:12:39 -05:00
3 changed files with 81 additions and 36 deletions
Showing only changes of commit 637e873ccf - Show all commits

View file

@ -10,7 +10,7 @@ def generate_crud_blueprint(model, service):
@bp.get('/<int:id>') @bp.get('/<int:id>')
def get_item(id): def get_item(id):
item = service.get(id) item = service.get(id, request.args)
return jsonify(item.as_dict()) return jsonify(item.as_dict())
@bp.post('/') @bp.post('/')

View file

@ -1,5 +1,6 @@
from typing import Type, TypeVar, Generic, Optional from typing import Type, TypeVar, Generic, Optional
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy import inspect, text from sqlalchemy import inspect, text
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec
@ -43,18 +44,70 @@ class CRUDService(Generic[T]):
cols.append(col) cols.append(col)
return cols or [text("1")] return cols or [text("1")]
def get(self, id: int, include_deleted: bool = False) -> T | None: def get(self, id: int, params=None) -> T | None:
print(f"I AM GETTING A THING! A THINGS! {params}")
query, root_alias = self.get_query() query, root_alias = self.get_query()
include_deleted = False
root_fields = []
root_field_names = {}
rel_field_names = {}
spec = CRUDSpec(self.model, params or {}, root_alias)
if params:
if self.supports_soft_delete:
include_deleted = _is_truthy(params.get('include_deleted'))
if self.supports_soft_delete and not include_deleted: if self.supports_soft_delete and not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False) query = query.filter(getattr(root_alias, "is_deleted") == False)
query = query.filter(getattr(root_alias, "id") == id) query = query.filter(getattr(root_alias, "id") == id)
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
query = query.join(
target_alias,
relationship_attr.of_type(target_alias),
isouter=True
)
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
obj = query.first() obj = query.first()
proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields:
proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items():
prefix = ".".join(path)
for n in names:
proj.append(f"{prefix}.{n}")
if proj and "id" not in proj and hasattr(self.model, "id"):
proj.insert(0, "id")
if proj:
try:
setattr(obj, "__crudkit_projection__", tuple(proj))
except Exception:
pass
return obj or None return obj or None
def list(self, params=None) -> list[T]: def list(self, params=None) -> list[T]:
query, root_alias = self.get_query() query, root_alias = self.get_query()
root_fields = [] root_fields = []
root_field_names = {}
rel_field_names = {} rel_field_names = {}
if params: if params:
@ -77,17 +130,15 @@ class CRUDService(Generic[T]):
) )
if params: if params:
root_fields, rel_field_names = spec.parse_fields() root_fields, rel_field_names, root_field_names = spec.parse_fields()
if root_fields: only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
query = query.options(Load(root_alias).load_only(*root_fields)) if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names): for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager) query = query.options(eager)
# if root_fields or rel_field_names:
# query = query.options(Load(root_alias).raiseload("*"))
if filters: if filters:
query = query.filter(*filters) query = query.filter(*filters)
@ -108,6 +159,8 @@ class CRUDService(Generic[T]):
rows = query.all() rows = query.all()
proj = [] proj = []
if root_field_names:
proj.extend(root_field_names)
if root_fields: if root_fields:
proj.extend(c.key for c in root_fields) proj.extend(c.key for c in root_fields)
for path, names in (rel_field_names or {}).items(): for path, names in (rel_field_names or {}).items():
@ -125,20 +178,6 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
# try:
# rf_names = [c.key for c in (root_fields or [])]
# except NameError:
# rf_names = []
# if rf_names:
# allow = set(rf_names)
# if "id" not in allow and hasattr(self.model, "id"):
# allow.add("id")
# for obj in rows:
# try:
# setattr(obj, "__crudkit_root_fields__", allow)
# except Exception:
# pass
return rows return rows
def create(self, data: dict, actor=None) -> T: def create(self, data: dict, actor=None) -> T:

View file

@ -1,5 +1,6 @@
from typing import List, Tuple, Set, Dict, Optional from typing import List, Tuple, Set, Dict, Optional
from sqlalchemy import asc, desc from sqlalchemy import asc, desc
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
@ -48,7 +49,7 @@ class CRUDSpec:
current_alias = alias current_alias = alias
continue continue
if isinstance(attr_obj, InstrumentedAttribute) or hasattr(attr_obj, "clauses"): if isinstance(attr_obj, InstrumentedAttribute) or getattr(attr_obj, "comparator", None) is not None or hasattr(attr_obj, "clauses"):
return attr_obj, tuple(join_path) if join_path else None return attr_obj, tuple(join_path) if join_path else None
return None, None return None, None
@ -120,7 +121,7 @@ class CRUDSpec:
""" """
raw = self.params.get('fields') raw = self.params.get('fields')
if not raw: if not raw:
return [], {} return [], {}, {}
if isinstance(raw, list): if isinstance(raw, list):
tokens = [] tokens = []
@ -130,6 +131,7 @@ class CRUDSpec:
tokens = [p.strip() for p in str(raw).split(',') if p.strip()] tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
root_fields: List[InstrumentedAttribute] = [] root_fields: List[InstrumentedAttribute] = []
root_field_names: list[str] = []
rel_field_names: Dict[Tuple[str, ...], List[str]] = {} rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
for token in tokens: for token in tokens:
@ -141,6 +143,7 @@ class CRUDSpec:
self.eager_paths.add(join_path) self.eager_paths.add(join_path)
else: else:
root_fields.append(col) root_fields.append(col)
root_field_names.append(getattr(col, "key", token))
seen = set() seen = set()
root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))] root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))]
@ -150,33 +153,36 @@ class CRUDSpec:
self._root_fields = root_fields self._root_fields = root_fields
self._rel_field_names = rel_field_names self._rel_field_names = rel_field_names
return root_fields, rel_field_names return root_fields, rel_field_names, root_field_names
def get_eager_loads(self, root_alias, *, fields_map: Optional[Dict[Tuple[str, ...], List[str]]] = None): def get_eager_loads(self, root_alias, *, fields_map=None):
loads = [] loads = []
for path in self.eager_paths: for path in self.eager_paths:
current = root_alias current = root_alias
loader = None loader = None
for idx, name in enumerate(path): for idx, name in enumerate(path):
rel_attr = getattr(current, name) rel_attr = getattr(current, name)
loader = selectinload(rel_attr) if loader is None else loader.selectinload(rel_attr)
if loader is None: # step to target class for the next hop
loader = selectinload(rel_attr) target_cls = rel_attr.property.mapper.class_
else: current = target_cls
loader = loader.selectinload(rel_attr)
current = rel_attr.property.mapper.class_
# if final hop and we have a fields map, narrow columns
if fields_map and idx == len(path) - 1 and path in fields_map: if fields_map and idx == len(path) - 1 and path in fields_map:
target_cls = current cols = []
cols = [getattr(target_cls, n) for n in fields_map[path] if hasattr(target_cls, n)] for n in fields_map[path]:
attr = getattr(target_cls, n, None)
# Only include real column attributes; skip hybrids/expressions
if isinstance(attr, InstrumentedAttribute):
cols.append(attr)
# Only apply load_only if we have at least one real column
if cols: if cols:
loader = loader.load_only(*cols) loader = loader.load_only(*cols)
if loader is not None: if loader is not None:
loads.append(loader) loads.append(loader)
return loads return loads
def get_join_paths(self): def get_join_paths(self):