Lots of downstream fixes.

This commit is contained in:
Yaro Kasear 2025-10-10 09:23:45 -05:00
parent 90dd16baf4
commit f956e09e2b
6 changed files with 292 additions and 31 deletions

View file

@ -187,6 +187,8 @@ class Config:
"synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"), "synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"),
} }
STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1")))
@classmethod @classmethod
def engine_kwargs(cls) -> Dict[str, Any]: def engine_kwargs(cls) -> Dict[str, Any]:
url = cls.DATABASE_URL url = cls.DATABASE_URL
@ -221,15 +223,18 @@ class Config:
class DevConfig(Config):
    """Development configuration: debug enabled, env-driven flags default to on."""

    # Environment flags are parsed as integers; any non-zero value enables them.
    STRICT_NPLUS1 = int(os.environ.get("CRUDKIT_STRICT_NPLUS1", "1")) != 0
    SQLALCHEMY_ECHO = int(os.environ.get("DB_ECHO", "1")) != 0
    DEBUG = True
class TestConfig(Config):
    """Test configuration: in-memory SQLite database with SQL echo disabled."""

    # STRICT_NPLUS1 is parsed as an integer; any non-zero value enables it.
    STRICT_NPLUS1 = int(os.environ.get("CRUDKIT_STRICT_NPLUS1", "1")) != 0
    SQLALCHEMY_ECHO = False
    DATABASE_URL = build_database_url(backend="sqlite", database=":memory:")
    TESTING = True
class ProdConfig(Config):
    """Production configuration: debug disabled, env-driven flags default to off."""

    # Environment flags are parsed as integers; any non-zero value enables them.
    STRICT_NPLUS1 = int(os.environ.get("CRUDKIT_STRICT_NPLUS1", "0")) != 0
    SQLALCHEMY_ECHO = int(os.environ.get("DB_ECHO", "0")) != 0
    DEBUG = False
def get_config(name: str | None) -> Type[Config]: def get_config(name: str | None) -> Type[Config]:
""" """

View file

@ -4,22 +4,34 @@ from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from flask import current_app from flask import current_app
from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast
from sqlalchemy import and_, func, inspect, or_, text from sqlalchemy import and_, func, inspect, or_, text, select, literal
from sqlalchemy.engine import Engine, Connection from sqlalchemy.engine import Engine, Connection
from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import operators from sqlalchemy.sql import operators, visitors
from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from sqlalchemy.sql.elements import UnaryExpression, ColumnElement
from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload
from crudkit.core.base import Version from crudkit.core.base import Version
from crudkit.core.spec import CRUDSpec from crudkit.core.spec import CRUDSpec, CollPred
from crudkit.core.types import OrderSpec, SeekWindow from crudkit.core.types import OrderSpec, SeekWindow
from crudkit.backend import BackendInfo, make_backend_info from crudkit.backend import BackendInfo, make_backend_info
from crudkit.projection import compile_projection from crudkit.projection import compile_projection
import logging import logging
log = logging.getLogger("crudkit.service") log = logging.getLogger("crudkit.service")
# logging.getLogger("crudkit.service").setLevel(logging.DEBUG)
# Ensure our debug actually prints even if the app/root logger is WARNING+
# if not log.handlers:
# _h = logging.StreamHandler()
# _h.setLevel(logging.DEBUG)
# _h.setFormatter(logging.Formatter(
# "%(asctime)s %(levelname)s %(name)s: %(message)s"
# ))
# log.addHandler(_h)
#
# log.setLevel(logging.DEBUG)
# log.propagate = False
@runtime_checkable @runtime_checkable
class _HasID(Protocol): class _HasID(Protocol):
@ -230,7 +242,9 @@ class CRUDService(Generic[T]):
# Make sure joins/filters match the real query # Make sure joins/filters match the real query
query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_firsthop_strategies(query, root_alias, plan)
if plan.filters: if plan.filters:
query = query.filter(*plan.filters) filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_spec = self._extract_order_spec(root_alias, plan.order_by) order_spec = self._extract_order_spec(root_alias, plan.order_by)
@ -358,10 +372,11 @@ class CRUDService(Generic[T]):
spec.parse_includes() spec.parse_includes()
join_paths = tuple(spec.get_join_paths()) join_paths = tuple(spec.get_join_paths())
filter_tables = _collect_tables_from_filters(filters) filter_tables = _collect_tables_from_filters(filters)
fkeys = set()
_, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], [])
filter_tables = () # filter_tables = ()
fkeys = set() # fkeys = set()
return self._Plan( return self._Plan(
spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset, spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset,
@ -377,6 +392,9 @@ class CRUDService(Generic[T]):
def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan): def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan):
nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 } nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 }
joined_rel_keys = set()
# Existing behavior: join everything in join_paths (to-one), selectinload collections
for base_alias, rel_attr, target_alias in plan.join_paths: for base_alias, rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias: if base_alias is not root_alias:
continue continue
@ -385,9 +403,9 @@ class CRUDService(Generic[T]):
if not is_collection: if not is_collection:
query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
joined_rel_keys.add(prop.key if prop is not None else rel_attr.key)
else: else:
opt = selectinload(rel_attr) opt = selectinload(rel_attr)
if is_collection:
child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) child_names = (plan.collection_field_names or {}).get(rel_attr.key, [])
if child_names: if child_names:
target_cls = prop.mapper.class_ target_cls = prop.mapper.class_
@ -396,6 +414,39 @@ class CRUDService(Generic[T]):
if cols: if cols:
opt = opt.load_only(*cols) opt = opt.load_only(*cols)
query = query.options(opt) query = query.options(opt)
# NEW: if a first-hop to-one relationship's target table is present in filter expressions,
# make sure we actually JOIN it (outer) so filters don't create a cartesian product.
if plan.filter_tables:
mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model))
for rel in mapper.relationships:
if rel.uselist:
continue # only first-hop to-one here
target_tbl = getattr(rel.mapper.class_, "__table__", None)
if target_tbl is None:
continue
if target_tbl in plan.filter_tables:
if rel.key in joined_rel_keys:
continue # already joined via join_paths
query = query.join(getattr(root_alias, rel.key), isouter=True)
joined_rel_keys.add(rel.key)
if log.isEnabledFor(logging.DEBUG):
info = []
for base_alias, rel_attr, target_alias in plan.join_paths:
if base_alias is not root_alias:
continue
prop = getattr(rel_attr, "property", None)
sel = getattr(target_alias, "selectable", None)
info.append({
"rel": (getattr(prop, "key", getattr(rel_attr, "key", "?"))),
"collection": bool(getattr(prop, "uselist", False)),
"target_keys": list(_selectable_keys(sel)) if sel is not None else [],
"joined": (getattr(prop, "key", None) in joined_rel_keys),
})
log.debug("FIRSTHOP: %s.%s first-hop paths: %s",
self.model.__name__, getattr(root_alias, "__table__", type(root_alias)).key,
info)
return query return query
def _apply_proj_opts(self, query, plan: _Plan): def _apply_proj_opts(self, query, plan: _Plan):
@ -428,6 +479,127 @@ class CRUDService(Generic[T]):
except Exception: except Exception:
pass pass
def _rebind_filters_to_firsthop_aliases(self, filters, root_alias, plan):
    """Make filter expressions use the exact same alias objects as our JOINs.

    For every first-hop entry in ``plan.join_paths`` (those whose base alias
    is ``root_alias``) the target's selectable is indexed by the frozenset of
    its ``_selectable_keys``. Each filter expression is then traversed and any
    column whose ``table`` has a matching key set, but is a different object,
    is swapped for the same-named column on the joined selectable.

    Args:
        filters: Filter expressions; may also contain ``CollPred`` entries,
            which are returned unchanged.
        root_alias: The alias the query is rooted on.
        plan: Query plan exposing ``join_paths`` as
            ``(base_alias, rel_attr, target_alias)`` triples.

    Returns:
        ``filters`` itself when there is nothing to rebind (empty input or no
        first-hop selectables), otherwise a new list of rewritten filters.
    """
    if not filters:
        return filters

    # Map first-hop target selectable keysets -> the exact selectable object we JOINed with
    alias_map = {}
    for base_alias, _rel_attr, target_alias in plan.join_paths:
        if base_alias is not root_alias:
            continue
        sel = getattr(target_alias, "selectable", None)
        if sel is not None:
            alias_map[frozenset(_selectable_keys(sel))] = sel

    if not alias_map:
        return filters

    def replace(elem):
        # replacement_traverse contract: a non-None return value replaces the
        # element AND stops descent into it; None keeps the element and lets
        # the traversal continue into its children. Returning ``elem`` here
        # would therefore halt at the top-level expression and rebind nothing.
        tbl = getattr(elem, "table", None)
        if tbl is None:
            return None
        new_sel = alias_map.get(frozenset(_selectable_keys(tbl)))
        if new_sel is None or new_sel is tbl:
            return None
        colkey = getattr(elem, "key", None) or getattr(elem, "name", None)
        if not colkey:
            return None
        try:
            return getattr(new_sel.c, colkey)
        except Exception:
            # Best effort: leave the column as-is when the joined selectable
            # does not expose a column under this key.
            return None

    # CollPred entries are plain data objects, not SQL elements, so they cannot
    # be traversed; pass them through untouched.
    return [
        f if isinstance(f, CollPred) else visitors.replacement_traverse(f, {}, replace)
        for f in filters
    ]
def _final_filters(self, root_alias, plan):
    """Return the filter list that is actually applied to the query.

    - Ordinary SQLAlchemy expressions in ``plan.filters`` are kept as-is.
    - ``CollPred`` entries are grouped by target table and each group is
      rebuilt into a single EXISTS via ``rel.of_type(alias).any(...)``, using
      one alias per collection table.

    Only first-hop collection relationships found in ``plan.join_paths``
    (base alias is ``root_alias`` and the relationship has ``uselist``) can be
    resolved. Predicates for any other table, and predicates with an
    unrecognised operator, are skipped and reported with a warning, because
    dropping them makes the query less restrictive than requested.

    Args:
        root_alias: The alias the query is rooted on.
        plan: Query plan exposing ``filters`` and ``join_paths``.

    Returns:
        A new list: the ordinary filters followed by the EXISTS filters.
        Empty when the plan has no filters.
    """
    raw_filters = list(plan.filters or [])
    if not raw_filters:
        return []

    # 1) Build a map of first-hop collection relationships: TABLE -> (rel_attr, target_cls)
    coll_map = {}
    for base_alias, rel_attr, _target_alias in plan.join_paths:
        if base_alias is not root_alias:
            continue
        prop = getattr(rel_attr, "property", None)
        if not prop or not getattr(prop, "uselist", False):
            continue
        target_cls = prop.mapper.class_
        tbl = getattr(target_cls, "__table__", None)
        if tbl is not None:
            coll_map[tbl] = (rel_attr, target_cls)

    # 2) Split raw filters into normal SQLA and CollPreds (by target table)
    normal_filters = []
    by_table = {}
    for f in raw_filters:
        if isinstance(f, CollPred):
            by_table.setdefault(f.table, []).append(f)
        else:
            normal_filters.append(f)

    def _as_seq(val):
        # 'in' / 'nin' accept a scalar and treat it as a one-element list.
        return val if isinstance(val, (list, tuple, set)) else [val]

    builders = {
        "icontains": lambda col, val: col.ilike(f"%{val}%"),
        "eq": lambda col, val: col == val,
        "ne": lambda col, val: col != val,
        "in": lambda col, val: col.in_(_as_seq(val)),
        "nin": lambda col, val: ~col.in_(_as_seq(val)),
        "lt": lambda col, val: col < val,
        "lte": lambda col, val: col <= val,
        "gt": lambda col, val: col > val,
        "gte": lambda col, val: col >= val,
    }

    # 3) Rebuild each table group into ONE .any(...) using one alias
    exists_filters = []
    for tbl, preds in by_table.items():
        if tbl not in coll_map:
            log.warning(
                "Ignoring %d collection filter(s) on %r: not a first-hop collection of %s",
                len(preds), tbl, self.model.__name__,
            )
            continue

        rel_attr, target_cls = coll_map[tbl]
        ta = aliased(target_cls)

        built = []
        for p in preds:
            col = getattr(ta, p.col_key)
            build = builders.get(p.op)
            if build is None:
                log.warning(
                    "Ignoring collection filter on %r.%s: unsupported operator %r",
                    tbl, p.col_key, p.op,
                )
                continue
            built.append(build(col, p.value))

        # enforce child soft delete inside the EXISTS
        if hasattr(target_cls, "is_deleted"):
            # `== False` is intentional: it builds a SQL comparison, not a Python one.
            built.append(ta.is_deleted == False)  # noqa: E712

        rel = rel_attr.of_type(ta)
        exists_filters.append(rel.any(and_(*built)) if built else rel.any())

    # 4) Final filter list = normal SQLA filters + all EXISTS filters
    return normal_filters + exists_filters
# ---- public read ops # ---- public read ops
def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True): def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True):
@ -469,7 +641,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params) query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters: if plan.filters:
query = query.filter(*plan.filters) filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_spec = self._extract_order_spec(root_alias, plan.order_by) order_spec = self._extract_order_spec(root_alias, plan.order_by)
limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit) limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit)
@ -529,7 +703,9 @@ class CRUDService(Generic[T]):
if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)):
base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True)
if plan.filters: if plan.filters:
base = base.filter(*plan.filters) filters = self._final_filters(root_alias, plan)
if filters:
base = base.filter(*filters) # <-- use base, not query
total = session.query(func.count()).select_from( total = session.query(func.count()).select_from(
base.order_by(None).distinct().subquery() base.order_by(None).distinct().subquery()
).scalar() or 0 ).scalar() or 0
@ -556,7 +732,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params) query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters: if plan.filters:
query = query.filter(*plan.filters) filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
query = query.filter(getattr(root_alias, "id") == id) query = query.filter(getattr(root_alias, "id") == id)
query = self._apply_proj_opts(query, plan) query = self._apply_proj_opts(query, plan)
@ -577,7 +755,9 @@ class CRUDService(Generic[T]):
query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_firsthop_strategies(query, root_alias, plan)
query = self._apply_soft_delete_criteria_for_children(query, plan, params) query = self._apply_soft_delete_criteria_for_children(query, plan, params)
if plan.filters: if plan.filters:
query = query.filter(*plan.filters) filters = self._final_filters(root_alias, plan)
if filters:
query = query.filter(*filters)
order_by = plan.order_by order_by = plan.order_by
paginating = (plan.limit is not None) or (plan.offset not in (None, 0)) paginating = (plan.limit is not None) or (plan.offset not in (None, 0))

View file

@ -1,9 +1,17 @@
from dataclasses import dataclass
from typing import Any, List, Tuple, Set, Dict, Optional, Iterable from typing import Any, List, Tuple, Set, Dict, Optional, Iterable
from sqlalchemy import and_, asc, desc, or_ from sqlalchemy import and_, asc, desc, or_
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm import aliased, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
@dataclass(frozen=True)
class CollPred:
    """Deferred filter predicate against a collection relationship's target.

    Instead of a ready-made SQL expression, this records *what* to compare so
    that a consumer can build the actual clause later against an alias of its
    choosing. Instances are immutable (``frozen=True``).
    """
    # Table of the relationship's target class (its ``__table__``).
    table: Any
    # Key/name of the column on the target class to compare.
    col_key: str
    # Operator name as given in the filter (e.g. 'eq', 'icontains').
    op: str
    # Right-hand value for the comparison.
    value: Any
OPERATORS = { OPERATORS = {
'eq': lambda col, val: col == val, 'eq': lambda col, val: col == val,
'lt': lambda col, val: col < val, 'lt': lambda col, val: col < val,
@ -68,16 +76,48 @@ class CRUDSpec:
exprs = [] exprs = []
for col, join_path in pairs: for col, join_path in pairs:
# Track eager path for each involved relationship chain
if join_path: if join_path:
self.eager_paths.add(join_path) self.eager_paths.add(join_path)
try:
cur_cls = self.model
names = list(join_path)
last_name = names[-1]
is_collection = False
for nm in names:
rel_attr = getattr(cur_cls, nm)
prop = rel_attr.property
cur_cls = prop.mapper.class_
is_collection = bool(getattr(getattr(self.model, last_name), "property", None)
and getattr(getattr(self.model, last_name).property, "uselist", False))
except Exception:
is_collection = False
if is_collection:
target_cls = cur_cls
key = getattr(col, "key", None) or getattr(col, "name", None)
if key and hasattr(target_cls, key):
target_tbl = getattr(target_cls, "__table__", None)
if target_tbl is not None:
exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value))
continue
exprs.append(OPERATORS[op](col, value)) exprs.append(OPERATORS[op](col, value))
if not exprs: if not exprs:
return None return None
if len(exprs) == 1:
# If any CollPred is in exprs, do NOT or_ them. Keep it single for now.
if any(isinstance(x, CollPred) for x in exprs):
# If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds,
# keep the first or raise for now (your choice).
if len(exprs) > 1:
# raise NotImplementedError("OR across collection paths not supported yet")
exprs = [next(x for x in exprs if isinstance(x, CollPred))]
return exprs[0] return exprs[0]
return or_(*exprs)
# Otherwise, standard SQLA clause(s)
return exprs[0] if len(exprs) == 1 else or_(*exprs)
def _collect_filters(self, params: dict) -> list: def _collect_filters(self, params: dict) -> list:
""" """

View file

@ -1,7 +1,8 @@
# engines.py
from __future__ import annotations from __future__ import annotations
from typing import Type, Optional from typing import Type, Optional
from sqlalchemy import create_engine from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, raiseload, Mapper, RelationshipProperty
from .backend import make_backend_info, BackendInfo from .backend import make_backend_info, BackendInfo
from .config import Config, get_config from .config import Config, get_config
from ._sqlite import apply_sqlite_pragmas from ._sqlite import apply_sqlite_pragmas
@ -12,15 +13,31 @@ def build_engine(config_cls: Type[Config] | None = None):
apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS) apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS)
return engine return engine
def _install_nplus1_guards(SessionMaker, *, strict: bool):
    """Register a ``do_orm_execute`` hook that adds ``raiseload("*")`` to ORM statements.

    Does nothing unless ``strict`` is true. When installed, every statement
    executed through ``SessionMaker`` that exposes a non-empty
    ``column_descriptions`` gets the wildcard ``raiseload`` option appended;
    other statements are left untouched.
    """
    if strict:
        def _add_global_raiseload(execute_state):
            statement = execute_state.statement
            # Only touch ORM statements (have column_descriptions)
            if not getattr(statement, "column_descriptions", None):
                return
            execute_state.statement = statement.options(raiseload("*"))

        event.listen(SessionMaker, "do_orm_execute", _add_global_raiseload)
def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None):
    """Build a ``sessionmaker`` bound to ``engine``, with optional N+1 guards.

    Args:
        config_cls: Config class to use; when falsy, ``get_config(None)`` is used.
        engine: Engine to bind; when falsy, one is built from ``config_cls``.

    Returns:
        The configured ``sessionmaker``.
    """
    if not config_cls:
        config_cls = get_config(None)
    if not engine:
        engine = build_engine(config_cls)

    session_factory = sessionmaker(bind=engine, **config_cls.session_kwargs())

    # Controlled by the config's STRICT_NPLUS1 flag; treated as off when the
    # config class does not define it.
    _install_nplus1_guards(
        session_factory,
        strict=bool(getattr(config_cls, "STRICT_NPLUS1", False)),
    )
    return session_factory
class CRUDKitRuntime: class CRUDKitRuntime:
""" """
Lightweight container so CRUDKit can be given either: Lightweight container so CRUDKit can be given either:
- prebuild engine/sessionmaker, or - prebuilt engine/sessionmaker, or
- a Config to build them lazily - a Config to build them lazily
""" """
def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None): def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None):

View file

@ -413,7 +413,7 @@ def _value_label_for_field(field: dict, mapper, values_map: dict, instance, sess
if not rel_prop: if not rel_prop:
return None return None
rid = _coerce_fk_value(values_map, instance, base) rid = _coerce_fk_value(values_map, instance, base, rel_prop)
rel_obj = _resolve_rel_obj(values_map, instance, base) rel_obj = _resolve_rel_obj(values_map, instance, base)
label_spec = ( label_spec = (
@ -493,7 +493,7 @@ class _SafeObj:
val = _get_loaded_attr(self._obj, name) val = _get_loaded_attr(self._obj, name)
return "" if val is None else _SafeObj(val) return "" if val is None else _SafeObj(val)
def _coerce_fk_value(values: dict | None, instance: Any, base: str): def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Optional[RelationshipProperty] = None):
""" """
Resolve current selection for relationship 'base': Resolve current selection for relationship 'base':
1) values['<base>_id'] 1) values['<base>_id']
@ -540,6 +540,25 @@ def _coerce_fk_value(values: dict | None, instance: Any, base: str):
except Exception: except Exception:
pass pass
# Fallback: if we know the relationship, try its local FK column names
if rel_prop is not None:
try:
st = inspect(instance) if instance is not None else None
except Exception:
st = None
# Try values[...] first
for col in getattr(rel_prop, "local_columns", []) or []:
key = getattr(col, "key", None) or getattr(col, "name", None)
if not key:
continue
if isinstance(values, dict) and key in values and values[key] not in (None, ""):
return values[key]
if set is not None:
attr = st.attrs.get(key) if hasattr(st, "attrs") else None
if attr is not None and attr.loaded_value is not NO_VALUE:
return attr.loaded_value
return None return None
def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]: def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]:
@ -1136,7 +1155,7 @@ def render_form(
base = name[:-3] base = name[:-3]
rel_prop = mapper.relationships.get(base) rel_prop = mapper.relationships.get(base)
if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE": if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE":
values_map[name] = _coerce_fk_value(values, instance, base) values_map[name] = _coerce_fk_value(values, instance, base, rel_prop) # add rel_prop
else: else:
# Auto-generate path (your original behavior) # Auto-generate path (your original behavior)
@ -1169,7 +1188,7 @@ def render_form(
fk_fields.add(f"{base}_id") fk_fields.add(f"{base}_id")
# NEW: set the current selection for this dropdown # NEW: set the current selection for this dropdown
values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base) values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base, prop)
# Then plain columns # Then plain columns
for col in model_cls.__table__.columns: for col in model_cls.__table__.columns:

View file

@ -35,5 +35,5 @@
{% if submit_attrs %}{% for k,v in submit_attrs.items() %} {% if submit_attrs %}{% for k,v in submit_attrs.items() %}
{{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %}
{% endfor %}{% endif %} {% endfor %}{% endif %}
>{{ submit_label if label else 'Save' }}</button> >{{ submit_label if submit_label else 'Save' }}</button>
</form> </form>