Lots of downstream fixes.

This commit is contained in:
Yaro Kasear 2025-10-10 09:23:45 -05:00
parent 90dd16baf4
commit f956e09e2b
6 changed files with 292 additions and 31 deletions

View file

@ -187,6 +187,8 @@ class Config:
"synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"),
}
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
@classmethod
def engine_kwargs(cls) -> Dict[str, Any]:
url = cls.DATABASE_URL
@ -221,15 +223,18 @@ class Config:
class DevConfig(Config):
DEBUG = True
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class TestConfig(Config):
TESTING = True
DATABASE_URL = build_database_url(backend="sqlite", database=":memory:")
SQLALCHEMY_ECHO = False
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
class ProdConfig(Config):
DEBUG = False
SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0")))
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0")))
def get_config(name: str | None) -> Type[Config]:
"""

View file

@ -4,22 +4,34 @@ from collections.abc import Iterable
from dataclasses import dataclass
from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text
from sqlalchemy import and_, func, inspect, or_, text, select, literal
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators
from sqlalchemy.sql import operators, visitors
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec
from crudkit.core.spec import CRUDSpec, CollPred
from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection
import logging
log = logging.getLogger("crudkit.service")
# logging.getLogger("crudkit.service").setLevel(logging.DEBUG)
# Ensure our debug actually prints even if the app/root logger is WARNING+
# if not log.handlers:
# _h = logging.StreamHandler()
# _h.setLevel(logging.DEBUG)
# _h.setFormatter(logging.Formatter(
# "%(asctime)s %(levelname)s %(name)s: %(message)s"
# ))
# log.addHandler(_h)
#
# log.setLevel(logging.DEBUG)
# log.propagate = False
@runtime_checkable
class _HasID(Protocol):
@ -230,7 +242,9 @@ class CRUDService(Generic[T]):
# Make sure joins/filters match the real query
query = self._apply_firsthop_strategies(query, root_alias, plan)
if plan.filters:
query = query.filter(*plan.filters)
filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_spec = self._extract_order_spec(root_alias, plan.order_by)
@ -358,10 +372,11 @@ class CRUDService(Generic[T]):
spec.parse_includes()
join_paths = tuple(spec.get_join_paths())
filter_tables = _collect_tables_from_filters(filters)
fkeys = set()
_, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
filter_tables = ()
fkeys = set()
# filter_tables = ()
# fkeys = set()
return self._Plan(
spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset,
@ -377,6 +392,9 @@ class CRUDService(Generic[T]):
def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 }
joined_rel_keys = set()
# Existing behavior: join everything in join_paths (to-one), selectinload collections
for base_alias, rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias:
continue
@ -385,17 +403,50 @@ class CRUDService(Generic[T]):
if not is_collection:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
joined_rel_keys.add(prop.key if prop is not None else rel_attr.key)
else:
opt = selectinload(rel_attr)
if is_collection:
child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = prop.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
if child_names:
target_cls = prop.mapper.class_
cols = [getattr(target_cls, n, None) for n in child_names]
cols = [c for c in cols if isinstance(c, InstrumentedAttribute)]
if cols:
opt = opt.load_only(*cols)
query = query.options(opt)
# NEW: if a first-hop to-one relationships target table is present in filter expressions,
# make sure we actually JOIN it (outer) so filters dont create a cartesian product.
if plan.filter_tables:
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for rel in mapper.relationships:
if rel.uselist:
continue # only first-hop to-one here
target_tbl = getattr(rel.mapper.class_, "__table__", None)
if target_tbl is None:
continue
if target_tbl in plan.filter_tables:
if rel.key in joined_rel_keys:
continue # already joined via join_paths
query = query.join(getattr(root_alias, rel.key), isouter=True)
joined_rel_keys.add(rel.key)
if log.isEnabledFor(logging.DEBUG):
info = []
for base_alias, rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias:
continue
prop = getattr(rel_attr, "property", None)
sel = getattr(target_alias, "selectable", None)
info.append({
"rel": (getattr(prop, "key", getattr(rel_attr, "key", "?"))),
"collection": bool(getattr(prop, "uselist", False)),
"target_keys": list(_selectable_keys(sel)) if sel is not None else [],
"joined": (getattr(prop, "key", None) in joined_rel_keys),
})
log.debug("FIRSTHOP: %s.%s first-hop paths: %s",
self.model.__name__, getattr(root_alias, "__table__", type(root_alias)).key,
info)
return query
def _apply_proj_opts(self, query, plan: _Plan):
@ -428,6 +479,127 @@ class CRUDService(Generic[T]):
except Exception:
pass
def _rebind_filters_to_firsthop_aliases(self, filters, root_alias, plan):
"""Make filter expressions use the exact same alias objects as our JOINs."""
if not filters:
return filters
# Map first-hop target selectable keysets -> the exact selectable object we JOINed with
alias_map = {}
for base_alias, _rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias:
continue
sel = getattr(target_alias, "selectable", None)
if sel is not None:
alias_map[frozenset(_selectable_keys(sel))] = sel
if not alias_map:
return filters
def replace(elem):
tbl = getattr(elem, "table", None)
if tbl is None:
return elem
keyset = frozenset(_selectable_keys(tbl))
new_sel = alias_map.get(keyset)
if new_sel is None or new_sel is tbl:
return elem
colkey = getattr(elem, "key", None) or getattr(elem, "name", None)
if not colkey:
return elem
try:
return getattr(new_sel.c, colkey)
except Exception:
return elem
return [visitors.replacement_traverse(f, {}, replace) for f in filters]
def _final_filters(self, root_alias, plan):
"""
Return filters where:
- root/to-one predicates are kept as SQLAlchemy expressions.
- first-hop collection predicates (CollPred) are rebuilt into a single
EXISTS via rel.any(...) with one alias per collection table.
"""
filters = list(plan.filters or [])
if not filters:
return []
# 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls)
coll_map = {}
for base_alias, rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias:
continue
prop = getattr(rel_attr, "property", None)
if not prop or not getattr(prop, "uselist", False):
continue
target_cls = prop.mapper.class_
tbl = getattr(target_cls, "__table__", None)
if tbl is not None:
coll_map[tbl] = (rel_attr, target_cls)
# 2) Split raw filters into normal SQLA and CollPreds (by target table)
normal_filters = []
by_table: dict[Any, list[CollPred]] = {}
for f in filters:
if isinstance(f, CollPred):
by_table.setdefault(f.table, []).append(f)
else:
normal_filters.append(f)
# 3) Rebuild each table group into ONE .any(...) using one alias
from sqlalchemy.orm import aliased
from sqlalchemy import and_
exists_filters = []
for tbl, preds in by_table.items():
if tbl not in coll_map:
# Safety: if it's not a first-hop collection, ignore or raise
continue
rel_attr, target_cls = coll_map[tbl]
ta = aliased(target_cls)
built = []
for p in preds:
col = getattr(ta, p.col_key)
op = p.op
val = p.value
if op == 'icontains':
built.append(col.ilike(f"%{val}%"))
elif op == 'eq':
built.append(col == val)
elif op == 'ne':
built.append(col != val)
elif op == 'in':
vs = val if isinstance(val, (list, tuple, set)) else [val]
built.append(col.in_(vs))
elif op == 'nin':
vs = val if isinstance(val, (list, tuple, set)) else [val]
built.append(~col.in_(vs))
elif op == 'lt':
built.append(col < val)
elif op == 'lte':
built.append(col <= val)
elif op == 'gt':
built.append(col > val)
elif op == 'gte':
built.append(col >= val)
else:
# unknown op — skip or raise
continue
# enforce child soft delete inside the EXISTS
if hasattr(target_cls, "is_deleted"):
built.append(ta.is_deleted == False)
crit = and_(*built) if built else None
exists_filters.append(rel_attr.of_type(ta).any(crit) if crit is not None
else rel_attr.of_type(ta).any())
# 4) Final filter list = normal SQLA filters + all EXISTS filters
return normal_filters + exists_filters
# ---- public read ops
def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True):
@ -469,7 +641,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters:
query = query.filter(*plan.filters)
filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_spec = self._extract_order_spec(root_alias, plan.order_by)
limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)
@ -529,7 +703,9 @@ class CRUDService(Generic[T]):
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
if plan.filters:
base = base.filter(*plan.filters)
filters = self._final_filters(root_alias, plan)
if filters:
base = base.filter(*filters) # <-- use base, not query
total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery()
).scalar() or 0
@ -556,7 +732,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters:
query = query.filter(*plan.filters)
filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
query = query.filter(getattr(root_alias, "id") == id)
query = self._apply_proj_opts(query, plan)
@ -577,7 +755,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters:
query = query.filter(*plan.filters)
filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_by = plan.order_by
paginating = (plan.limit is not None) or (plan.offset not in (None, 0))

View file

@ -1,9 +1,17 @@
from dataclasses import dataclass
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
@dataclass(frozen=True)
class CollPred:
table: Any
col_key: str
op: str
value: Any
OPERATORS = {
'eq': lambda col, val: col == val,
'lt': lambda col, val: col < val,
@ -68,16 +76,48 @@ class CRUDSpec:
exprs = []
for col, join_path in pairs:
# Track eager path for each involved relationship chain
if join_path:
self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
is_collection = False
for nm in names:
rel_attr = getattr(cur_cls, nm)
prop = rel_attr.property
cur_cls = prop.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None)
and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
is_collection = False
if is_collection:
target_cls = cur_cls
key = getattr(col, "key", None) or getattr(col, "name", None)
if key and hasattr(target_cls, key):
target_tbl = getattr(target_cls, "__table__", None)
if target_tbl is not None:
exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
continue
exprs.append(OPERATORS[op](col, value))
if not exprs:
return None
if len(exprs) == 1:
# If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
if any(isinstance(x, CollPred) for x in exprs):
# If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds,
# keep the first or raise for now (your choice).
if len(exprs) > 1:
# raise NotImplementedError("OR across collection paths not supported yet")
exprs = [next(x for x in exprs if isinstance(x, CollPred))]
return exprs[0]
return or_(*exprs)
# Otherwise, standard SQLA clause(s)
return exprs[0] if len(exprs) == 1 else or_(*exprs)
def _collect_filters(self, params: dict) -> list:
"""

View file

@ -1,7 +1,8 @@
# engines.py
from __future__ import annotations
from typing import Type, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker, raiseload, Mapper, RelationshipProperty
from .backend import make_backend_info, BackendInfo
from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas
@ -12,15 +13,31 @@ def build_engine(config_cls: Type[Config] | None = None):
apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS)
return engine
def _install_nplus1_guards(SessionMaker, *, strict: bool):
if not strict:
return
@event.listens_for(SessionMaker, "do_orm_execute")
def _add_global_raiseload(execute_state):
stmt = execute_state.statement
# Only touch ORM statements (have column_descriptions)
if getattr(stmt, "column_descriptions", None):
execute_state.statement = stmt.options(raiseload("*"))
def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None):
config_cls = config_cls or get_config(None)
engine = engine or build_engine(config_cls)
return sessionmaker(bind=engine, **config_cls.session_kwargs())
SessionMaker = sessionmaker(bind=engine, **config_cls.session_kwargs())
# Toggle with a config flag; default off so you can turn it on when ready
strict = bool(getattr(config_cls, "STRICT_NPLUS1", False))
_install_nplus1_guards(SessionMaker, strict=strict)
return SessionMaker
class CRUDKitRuntime:
"""
Lightweight container so CRUDKit can be given either:
- prebuild engine/sessionmaker, or
- prebuilt engine/sessionmaker, or
- a Config to build them lazily
"""
def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None):

View file

@ -413,7 +413,7 @@ def _value_label_for_field(field: dict, mapper, values_map: dict, instance, sess
if not rel_prop:
return None
rid = _coerce_fk_value(values_map, instance, base)
rid = _coerce_fk_value(values_map, instance, base, rel_prop)
rel_obj = _resolve_rel_obj(values_map, instance, base)
label_spec = (
@ -493,7 +493,7 @@ class _SafeObj:
val = _get_loaded_attr(self._obj, name)
return "" if val is None else _SafeObj(val)
def _coerce_fk_value(values: dict | None, instance: Any, base: str):
def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Optional[RelationshipProperty] = None):
"""
Resolve current selection for relationship 'base':
1) values['<base>_id']
@ -540,6 +540,25 @@ def _coerce_fk_value(values: dict | None, instance: Any, base: str):
except Exception:
pass
# Fallback: if we know the relationship, try its local FK column names
if rel_prop is not None:
try:
st = inspect(instance) if instance is not None else None
except Exception:
st = None
# Try values[...] first
for col in getattr(rel_prop, "local_columns", []) or []:
key = getattr(col, "key", None) or getattr(col, "name", None)
if not key:
continue
if isinstance(values, dict) and key in values and values[key] not in (None, ""):
return values[key]
if set is not None:
attr = st.attrs.get(key) if hasattr(st, "attrs") else None
if attr is not None and attr.loaded_value is not NO_VALUE:
return attr.loaded_value
return None
def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]:
@ -1136,7 +1155,7 @@ def render_form(
base = name[:-3]
rel_prop = mapper.relationships.get(base)
if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE":
values_map[name] = _coerce_fk_value(values, instance, base)
values_map[name] = _coerce_fk_value(values, instance, base, rel_prop) # add rel_prop
else:
# Auto-generate path (your original behavior)
@ -1169,7 +1188,7 @@ def render_form(
fk_fields.add(f"{base}_id")
# NEW: set the current selection for this dropdown
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base)
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base, prop)
# Then plain columns
for col in model_cls.__table__.columns:

View file

@ -35,5 +35,5 @@
{% if submit_attrs %}{% for k,v in submit_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %}
>{{ submit_label if label else 'Save' }}</button>
>{{ submit_label if submit_label else 'Save' }}</button>
</form>