Makign corrections on field selection.
This commit is contained in:
parent
7e915423a3
commit
9d36d600bb
9 changed files with 1504 additions and 39 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Type, TypeVar, Generic, Optional
|
||||
from sqlalchemy.orm import Session, with_polymorphic
|
||||
from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
|
||||
from sqlalchemy import inspect, text
|
||||
from crudkit.core.base import Version
|
||||
from crudkit.core.spec import CRUDSpec
|
||||
|
|
@ -28,11 +28,9 @@ class CRUDService(Generic[T]):
|
|||
|
||||
def get_query(self):
|
||||
if self.polymorphic:
|
||||
poly_model = with_polymorphic(self.model, '*')
|
||||
return self.session.query(poly_model), poly_model
|
||||
else:
|
||||
base_only = with_polymorphic(self.model, [], flat=True)
|
||||
return self.session.query(base_only), base_only
|
||||
poly = with_polymorphic(self.model, "*")
|
||||
return self.session.query(poly), poly
|
||||
return self.session.query(self.model), self.model
|
||||
|
||||
# Helper: default ORDER BY for MSSQL when paginating without explicit order
|
||||
def _default_order_by(self, root_alias):
|
||||
|
|
@ -62,10 +60,11 @@ class CRUDService(Generic[T]):
|
|||
if not include_deleted:
|
||||
query = query.filter(getattr(root_alias, "is_deleted") == False)
|
||||
|
||||
spec = CRUDSpec(self.model, params, root_alias)
|
||||
spec = CRUDSpec(self.model, params or {}, root_alias)
|
||||
filters = spec.parse_filters()
|
||||
order_by = spec.parse_sort()
|
||||
limit, offset = spec.parse_pagination()
|
||||
spec.parse_includes()
|
||||
|
||||
for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
|
||||
query = query.join(
|
||||
|
|
@ -74,9 +73,17 @@ class CRUDService(Generic[T]):
|
|||
isouter=True
|
||||
)
|
||||
|
||||
for eager in spec.get_eager_loads(root_alias):
|
||||
root_fields, rel_field_names = spec.parse_fields()
|
||||
|
||||
if root_fields:
|
||||
query = query.options(Load(root_alias).load_only(*root_fields))
|
||||
|
||||
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
|
||||
query = query.options(eager)
|
||||
|
||||
if root_fields or rel_field_names:
|
||||
query = query.options(Load(root_alias).raiseload("*"))
|
||||
|
||||
if filters:
|
||||
query = query.filter(*filters)
|
||||
|
||||
|
|
@ -94,7 +101,24 @@ class CRUDService(Generic[T]):
|
|||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
return query.all()
|
||||
# return query.all()
|
||||
rows = query.all()
|
||||
|
||||
try:
|
||||
rf_names = [c.key for c in (root_fields or [])]
|
||||
except NameError:
|
||||
rf_names = []
|
||||
if rf_names:
|
||||
allow = set(rf_names)
|
||||
if "id" not in allow and hasattr(self.model, "id"):
|
||||
allow.add("id")
|
||||
for obj in rows:
|
||||
try:
|
||||
setattr(obj, "__crudkit_root_fields__", allow)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return rows
|
||||
|
||||
def create(self, data: dict, actor=None) -> T:
|
||||
obj = self.model(**data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue