Makign corrections on field selection.

This commit is contained in:
Yaro Kasear 2025-09-09 16:27:41 -05:00
parent 7e915423a3
commit 9d36d600bb
9 changed files with 1504 additions and 39 deletions

View file

@ -1,5 +1,5 @@
from typing import Type, TypeVar, Generic, Optional
from sqlalchemy.orm import Session, with_polymorphic
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
from sqlalchemy import inspect, text
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
@ -28,11 +28,9 @@ class CRUDService(Generic[T]):
def get_query(self):
if self.polymorphic:
poly_model = with_polymorphic(self.model, '*')
return self.session.query(poly_model), poly_model
else:
base_only = with_polymorphic(self.model, [], flat=True)
return self.session.query(base_only), base_only
poly = with_polymorphic(self.model, "*")
return self.session.query(poly), poly
return self.session.query(self.model), self.model
# Helper: default ORDER BY for MSSQL when paginating without explicit order
def _default_order_by(self, root_alias):
@ -62,10 +60,11 @@ class CRUDService(Generic[T]):
if not include_deleted:
query = query.filter(getattr(root_alias, "is_deleted") == False)
spec = CRUDSpec(self.model, params, root_alias)
spec = CRUDSpec(self.model, params or {}, root_alias)
filters = spec.parse_filters()
order_by = spec.parse_sort()
limit, offset = spec.parse_pagination()
spec.parse_includes()
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
query = query.join(
@ -74,9 +73,17 @@ class CRUDService(Generic[T]):
isouter=True
)
for eager in spec.get_eager_loads(root_alias):
root_fields, rel_field_names = spec.parse_fields()
if root_fields:
query = query.options(Load(root_alias).load_only(*root_fields))
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
if root_fields or rel_field_names:
query = query.options(Load(root_alias).raiseload("*"))
if filters:
query = query.filter(*filters)
@ -94,7 +101,24 @@ class CRUDService(Generic[T]):
if limit is not None:
query = query.limit(limit)
return query.all()
# return query.all()
rows = query.all()
try:
rf_names = [c.key for c in (root_fields or [])]
except NameError:
rf_names = []
if rf_names:
allow = set(rf_names)
if "id" not in allow and hasattr(self.model, "id"):
allow.add("id")
for obj in rows:
try:
setattr(obj, "__crudkit_root_fields__", allow)
except Exception:
pass
return rows
def create(self, data: dict, actor=None) -> T:
obj = self.model(**data)