Downstream fixes from inventory app.

Yaro Kasear 2025-09-10 08:14:50 -05:00
parent 4cdbc44a13
commit e09cee0c79
3 changed files with 120 additions and 20 deletions

View file

@@ -9,13 +9,26 @@ class CRUDMixin:
     created_at = Column(DateTime, default=func.now(), nullable=False)
     updated_at = Column(DateTime, default=func.now(), nullable=False, onupdate=func.now())
 
-    def as_dict(self):
-        # Combine all columns from all inherited tables
+    def as_dict(self, fields: list[str] | None = None):
+        """
+        Serialize mapped columns. Honors projection if either:
+        - 'fields' is passed explicitly, or
+        - the instance carries '__crudkit_root_fields__' (attached by the service layer).
+        """
+        allowed = None
+        if fields:
+            allowed = set(fields)
+        else:
+            allowed = getattr(self, "__crudkit_root_fields__", None)
         result = {}
         for cls in self.__class__.__mro__:
-            if hasattr(cls, "__table__"):
-                for column in cls.__table__.columns:
-                    result[column.name] = getattr(self, column.name)
+            if not hasattr(cls, "__table__"):
+                continue
+            for column in cls.__table__.columns:
+                name = column.name
+                if allowed is not None and name not in allowed and name != "id":
+                    continue
+                result[name] = getattr(self, name)
         return result
 
 class Version(Base):
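A minimal usage sketch of the projection, assuming a hypothetical Part model built on CRUDMixin (id, name, cost columns) and an open session:

# Part is a hypothetical model used only for illustration.
part = session.query(Part).first()

part.as_dict()                      # every mapped column, as before
part.as_dict(fields=["name"])       # {"id": ..., "name": ...}; "id" is always kept

# The service layer can pre-attach the projection instead of the caller passing it:
part.__crudkit_root_fields__ = {"name", "cost"}
part.as_dict()                      # {"id": ..., "name": ..., "cost": ...}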

View file

@@ -1,5 +1,5 @@
 from typing import Type, TypeVar, Generic, Optional
-from sqlalchemy.orm import Session, with_polymorphic
+from sqlalchemy.orm import Load, Session, raiseload, with_polymorphic
 from sqlalchemy import inspect, text
 from crudkit.core.base import Version
 from crudkit.core.spec import CRUDSpec
@@ -28,11 +28,9 @@ class CRUDService(Generic[T]):
     def get_query(self):
         if self.polymorphic:
-            poly_model = with_polymorphic(self.model, '*')
-            return self.session.query(poly_model), poly_model
-        else:
-            base_only = with_polymorphic(self.model, [], flat=True)
-            return self.session.query(base_only), base_only
+            poly = with_polymorphic(self.model, "*")
+            return self.session.query(poly), poly
+        return self.session.query(self.model), self.model
 
     # Helper: default ORDER BY for MSSQL when paginating without explicit order
     def _default_order_by(self, root_alias):
@@ -62,10 +60,11 @@ class CRUDService(Generic[T]):
         if not include_deleted:
             query = query.filter(getattr(root_alias, "is_deleted") == False)
 
-        spec = CRUDSpec(self.model, params, root_alias)
+        spec = CRUDSpec(self.model, params or {}, root_alias)
         filters = spec.parse_filters()
         order_by = spec.parse_sort()
         limit, offset = spec.parse_pagination()
+        spec.parse_includes()
 
         for parent_alias, relationship_attr, target_alias in spec.get_join_paths():
             query = query.join(
@@ -74,9 +73,17 @@ class CRUDService(Generic[T]):
                 isouter=True
             )
 
-        for eager in spec.get_eager_loads(root_alias):
+        root_fields, rel_field_names = spec.parse_fields()
+        if root_fields:
+            query = query.options(Load(root_alias).load_only(*root_fields))
+
+        for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
             query = query.options(eager)
 
+        # if root_fields or rel_field_names:
+        #     query = query.options(Load(root_alias).raiseload("*"))
+
         if filters:
             query = query.filter(*filters)
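For ?fields=name,vendor.title the loader options built above are roughly equivalent to the following sketch, assuming a hypothetical Part root with a vendor relationship to a Vendor model:

from sqlalchemy.orm import Load, selectinload

# Root projection: only the named root columns (plus the PK) are selected.
query = query.options(Load(root_alias).load_only(root_alias.name))
# Relationship projection: eager-load vendor, restricted to its title column.
query = query.options(selectinload(root_alias.vendor).load_only(Vendor.title))
# The commented-out raiseload("*") guard would make access to any
# non-requested relationship raise instead of triggering a lazy load.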
@@ -91,10 +98,27 @@ class CRUDService(Generic[T]):
         # Only apply offset/limit when not None.
         if offset is not None and offset != 0:
             query = query.offset(offset)
-        if limit is not None:
+        if limit is not None and limit > 0:
             query = query.limit(limit)
 
-        return query.all()
+        # return query.all()
+        rows = query.all()
+        try:
+            rf_names = [c.key for c in (root_fields or [])]
+        except NameError:
+            rf_names = []
+        if rf_names:
+            allow = set(rf_names)
+            if "id" not in allow and hasattr(self.model, "id"):
+                allow.add("id")
+            for obj in rows:
+                try:
+                    setattr(obj, "__crudkit_root_fields__", allow)
+                except Exception:
+                    pass
+        return rows
 
     def create(self, data: dict, actor=None) -> T:
         obj = self.model(**data)
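Sketch of the resulting read-path behavior, assuming this method is exposed as list(params=...) (the signature is not shown in the hunk) and the hypothetical Part model from above:

rows = service.list(params={"fields": "name,cost"})
# Each returned object carries __crudkit_root_fields__ = {"name", "cost", "id"},
# so serialization downstream honors the projection without re-passing fields:
rows[0].as_dict()   # -> {"id": ..., "name": ..., "cost": ...}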

View file

@@ -1,6 +1,6 @@
-from typing import List, Tuple, Set, Dict
+from typing import List, Tuple, Set, Dict, Optional
 from sqlalchemy import asc, desc
-from sqlalchemy.orm import joinedload, aliased
+from sqlalchemy.orm import aliased, selectinload
 from sqlalchemy.orm.attributes import InstrumentedAttribute
 
 OPERATORS = {
@@ -21,6 +21,9 @@ class CRUDSpec:
         self.eager_paths: Set[Tuple[str, ...]] = set()
         self.join_paths: List[Tuple[object, InstrumentedAttribute, object]] = []
         self.alias_map: Dict[Tuple[str, ...], object] = {}
+        self._root_fields: List[InstrumentedAttribute] = []
+        self._rel_field_names: Dict[Tuple[str, ...], object] = {}
+        self.include_paths: Set[Tuple[str, ...]] = set()
 
     def _resolve_column(self, path: str):
         current_alias = self.root_alias
@@ -50,6 +53,20 @@ class CRUDSpec:
 
         return None, None
 
+    def parse_includes(self):
+        raw = self.params.get('include')
+        if not raw:
+            return
+        tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
+        for token in tokens:
+            _, join_path = self._resolve_column(token)
+            if join_path:
+                self.eager_paths.add(join_path)
+            else:
+                col, maybe = self._resolve_column(token + '.id')
+                if maybe:
+                    self.eager_paths.add(maybe)
+
     def parse_filters(self):
         filters = []
         for key, value in self.params.items():
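A short sketch of how parse_includes feeds eager_paths, assuming a hypothetical Part root with vendor and category relationships (the '.id' fallback is how a bare relationship name is expected to resolve):

# Part is hypothetical; the constructor shape matches CRUDSpec(model, params, root_alias).
spec = CRUDSpec(Part, {"include": "vendor,category"}, root_alias)
spec.parse_includes()
# eager_paths now contains ("vendor",) and ("category",); tokens that do not
# resolve directly are retried as "<name>.id" so plain relationship names work.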
@@ -94,14 +111,60 @@ class CRUDSpec:
         offset = int(self.params.get('offset', 0))
         return limit, offset
 
-    def get_eager_loads(self, root_alias):
+    def parse_fields(self):
+        """
+        Parse ?fields=colA,colB,rel1.colC,rel1.rel2.colD
+        - Root fields become InstrumentedAttributes bound to root_alias.
+        - Related fields store attribute NAMES; we'll resolve them on the target class when building loader options.
+        Returns (root_fields, rel_field_names).
+        """
+        raw = self.params.get('fields')
+        if not raw:
+            return [], {}
+        if isinstance(raw, list):
+            tokens = []
+            for chunk in raw:
+                tokens.extend(p.strip() for p in str(chunk).split(',') if p.strip())
+        else:
+            tokens = [p.strip() for p in str(raw).split(',') if p.strip()]
+
+        root_fields: List[InstrumentedAttribute] = []
+        rel_field_names: Dict[Tuple[str, ...], List[str]] = {}
+
+        for token in tokens:
+            col, join_path = self._resolve_column(token)
+            if not col:
+                continue
+            if join_path:
+                rel_field_names.setdefault(join_path, []).append(col.key)
+                self.eager_paths.add(join_path)
+            else:
+                root_fields.append(col)
+
+        seen = set()
+        root_fields = [c for c in root_fields if not (c.key in seen or seen.add(c.key))]
+        for k, names in rel_field_names.items():
+            seen2 = set()
+            rel_field_names[k] = [n for n in names if not (n in seen2 or seen2.add(n))]
+
+        self._root_fields = root_fields
+        self._rel_field_names = rel_field_names
+        return root_fields, rel_field_names
+
+    def get_eager_loads(self, root_alias, *, fields_map: Optional[Dict[Tuple[str, ...], List[str]]] = None):
         loads = []
         for path in self.eager_paths:
             current = root_alias
             loader = None
-            for name in path:
+            for idx, name in enumerate(path):
                 rel_attr = getattr(current, name)
-                loader = (joinedload(rel_attr) if loader is None else loader.joinedload(name))
+                loader = (selectinload(rel_attr) if loader is None else loader.selectinload(name))
+                if fields_map and idx == len(path) - 1 and path in fields_map:
+                    target_cls = rel_attr.property.mapper.class_
+                    cols = [getattr(target_cls, n) for n in fields_map[path] if hasattr(target_cls, n)]
+                    if cols:
+                        loader = loader.load_only(*cols)
                 current = rel_attr.property.mapper.class_
             if loader is not None:
                 loads.append(loader)
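End-to-end sketch of the new field projection, assuming hypothetical Part/Vendor/Category models where Part has vendor and category relationships:

# Part, Vendor, Category, root_alias and query are assumed to exist for illustration.
params = {"fields": "name,vendor.title", "include": "category"}
spec = CRUDSpec(Part, params, root_alias)
spec.parse_includes()                          # eager_paths: {("category",)}
root_fields, rel_fields = spec.parse_fields()  # -> [root_alias.name], {("vendor",): ["title"]}
for opt in spec.get_eager_loads(root_alias, fields_map=rel_fields):
    query = query.options(opt)
# roughly: selectinload(vendor).load_only(Vendor.title) and selectinload(category)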