Lots of form code done!

This commit is contained in:
Conrad Nelson 2025-09-18 16:03:12 -05:00
parent 25589a79d3
commit 2ae96e5c80
6 changed files with 400 additions and 60 deletions

View file

@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty
from sqlalchemy.orm import Load, Session, raiseload, selectinload, with_polymorphic, Mapper, RelationshipProperty, class_mapper
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import operators
@ -12,6 +12,35 @@ from crudkit.core.spec import CRUDSpec
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
def _loader_options_for_fields(root_alias, model_cls, fields: list[str]) -> list[Load]:
"""
For bare MANYTOONE names in fields (e.g. "location"), selectinload the relationship
and only fetch the related PK. This is enough for preselecting <select> inputs
without projecting the FK column on the root model.
"""
opts: list[Load] = []
if not fields:
return opts
mapper = class_mapper(model_cls)
for name in fields:
prop = mapper.relationships.get(name)
if not isinstance(prop, RelationshipProperty):
continue
if prop.direction.name != "MANYTOONE":
continue
rel_attr = getattr(root_alias, name)
target_cls = prop.mapper.class_
# load_only PK if present; else just selectinload
id_attr = getattr(target_cls, "id", None)
if id_attr is not None:
opts.append(selectinload(rel_attr).load_only(id_attr))
else:
opts.append(selectinload(rel_attr))
return opts
@runtime_checkable
class _HasID(Protocol):
id: int
@ -358,6 +387,11 @@ class CRUDService(Generic[T]):
if params:
root_fields, rel_field_names, root_field_names = spec.parse_fields()
fields = (params or {}).get("fields") if isinstance(params, dict) else None
if fields:
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
only_cols = [c for c in root_fields if isinstance(c, InstrumentedAttribute)]
if only_cols:
query = query.options(Load(root_alias).load_only(*only_cols))
@ -365,6 +399,11 @@ class CRUDService(Generic[T]):
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
obj = query.first()
proj = []
@ -422,6 +461,11 @@ class CRUDService(Generic[T]):
for eager in spec.get_eager_loads(root_alias, fields_map=rel_field_names):
query = query.options(eager)
if params:
fields = params.get("fields") or []
for opt in _loader_options_for_fields(root_alias, self.model, fields):
query = query.options(opt)
if filters:
query = query.filter(*filters)