diff --git a/crudkit/config.py b/crudkit/config.py index 0439a3e..fb87b51 100644 --- a/crudkit/config.py +++ b/crudkit/config.py @@ -187,6 +187,8 @@ class Config: "synchronous": os.getenv("SQLITE_SYNCHRONOUS", "NORMAL"), } + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) + @classmethod def engine_kwargs(cls) -> Dict[str, Any]: url = cls.DATABASE_URL @@ -221,15 +223,18 @@ class Config: class DevConfig(Config): DEBUG = True SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "1"))) + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class TestConfig(Config): TESTING = True DATABASE_URL = build_database_url(backend="sqlite", database=":memory:") SQLALCHEMY_ECHO = False + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "1"))) class ProdConfig(Config): DEBUG = False SQLALCHEMY_ECHO = bool(int(os.getenv("DB_ECHO", "0"))) + STRICT_NPLUS1 = bool(int(os.getenv("CRUDKIT_STRICT_NPLUS1", "0"))) def get_config(name: str | None) -> Type[Config]: """ diff --git a/crudkit/core/service.py b/crudkit/core/service.py index db510cd..fbad3a1 100644 --- a/crudkit/core/service.py +++ b/crudkit/core/service.py @@ -4,22 +4,34 @@ from collections.abc import Iterable from dataclasses import dataclass from flask import current_app from typing import Any, Callable, Type, TypeVar, Generic, Optional, Protocol, runtime_checkable, cast -from sqlalchemy import and_, func, inspect, or_, text +from sqlalchemy import and_, func, inspect, or_, text, select, literal from sqlalchemy.engine import Engine, Connection -from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria +from sqlalchemy.orm import Load, Session, with_polymorphic, Mapper, selectinload, with_loader_criteria, aliased, with_parent from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlalchemy.sql import operators +from sqlalchemy.sql import operators, visitors from sqlalchemy.sql.elements import UnaryExpression, ColumnElement from crudkit.core import to_jsonable, deep_diff, diff_to_patch, filter_to_columns, normalize_payload from crudkit.core.base import Version -from crudkit.core.spec import CRUDSpec +from crudkit.core.spec import CRUDSpec, CollPred from crudkit.core.types import OrderSpec, SeekWindow from crudkit.backend import BackendInfo, make_backend_info from crudkit.projection import compile_projection import logging log = logging.getLogger("crudkit.service") +# logging.getLogger("crudkit.service").setLevel(logging.DEBUG) +# Ensure our debug actually prints even if the app/root logger is WARNING+ +# if not log.handlers: +# _h = logging.StreamHandler() +# _h.setLevel(logging.DEBUG) +# _h.setFormatter(logging.Formatter( +# "%(asctime)s %(levelname)s %(name)s: %(message)s" +# )) +# log.addHandler(_h) +# +# log.setLevel(logging.DEBUG) +# log.propagate = False @runtime_checkable class _HasID(Protocol): @@ -230,7 +242,9 @@ class CRUDService(Generic[T]): # Make sure joins/filters match the real query query = self._apply_firsthop_strategies(query, root_alias, plan) if plan.filters: - query = query.filter(*plan.filters) + filters = self._final_filters(root_alias, plan) + if filters: + query = query.filter(*filters) order_spec = self._extract_order_spec(root_alias, plan.order_by) @@ -358,10 +372,11 @@ class CRUDService(Generic[T]): spec.parse_includes() join_paths = tuple(spec.get_join_paths()) filter_tables = _collect_tables_from_filters(filters) + fkeys = set() _, proj_opts = compile_projection(self.model, req_fields) if req_fields else ([], []) - filter_tables = () - fkeys = set() + # filter_tables = () + # fkeys = set() return self._Plan( spec=spec, filters=filters, order_by=order_by, limit=limit, offset=offset, @@ -377,6 +392,9 @@ class CRUDService(Generic[T]): def _apply_firsthop_strategies(self, query, root_alias, plan: _Plan): nested_first_hops = { p[0] for p in (plan.rel_field_names or {}).keys() if len(p) > 1 } + joined_rel_keys = set() + + # Existing behavior: join everything in join_paths (to-one), selectinload collections for base_alias, rel_attr, target_alias in plan.join_paths: if base_alias is not root_alias: continue @@ -385,17 +403,50 @@ class CRUDService(Generic[T]): if not is_collection: query = query.join(target_alias, rel_attr.of_type(target_alias), isouter=True) + joined_rel_keys.add(prop.key if prop is not None else rel_attr.key) else: opt = selectinload(rel_attr) - if is_collection: - child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) - if child_names: - target_cls = prop.mapper.class_ - cols = [getattr(target_cls, n, None) for n in child_names] - cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] - if cols: - opt = opt.load_only(*cols) + child_names = (plan.collection_field_names or {}).get(rel_attr.key, []) + if child_names: + target_cls = prop.mapper.class_ + cols = [getattr(target_cls, n, None) for n in child_names] + cols = [c for c in cols if isinstance(c, InstrumentedAttribute)] + if cols: + opt = opt.load_only(*cols) query = query.options(opt) + + # NEW: if a first-hop to-one relationship’s target table is present in filter expressions, + # make sure we actually JOIN it (outer) so filters don’t create a cartesian product. + if plan.filter_tables: + mapper: Mapper[Any] = cast(Mapper[Any], inspect(self.model)) + for rel in mapper.relationships: + if rel.uselist: + continue # only first-hop to-one here + target_tbl = getattr(rel.mapper.class_, "__table__", None) + if target_tbl is None: + continue + if target_tbl in plan.filter_tables: + if rel.key in joined_rel_keys: + continue # already joined via join_paths + query = query.join(getattr(root_alias, rel.key), isouter=True) + joined_rel_keys.add(rel.key) + if log.isEnabledFor(logging.DEBUG): + info = [] + for base_alias, rel_attr, target_alias in plan.join_paths: + if base_alias is not root_alias: + continue + prop = getattr(rel_attr, "property", None) + sel = getattr(target_alias, "selectable", None) + info.append({ + "rel": (getattr(prop, "key", getattr(rel_attr, "key", "?"))), + "collection": bool(getattr(prop, "uselist", False)), + "target_keys": list(_selectable_keys(sel)) if sel is not None else [], + "joined": (getattr(prop, "key", None) in joined_rel_keys), + }) + log.debug("FIRSTHOP: %s.%s first-hop paths: %s", + self.model.__name__, getattr(root_alias, "__table__", type(root_alias)).key, + info) + return query def _apply_proj_opts(self, query, plan: _Plan): @@ -428,6 +479,127 @@ class CRUDService(Generic[T]): except Exception: pass + def _rebind_filters_to_firsthop_aliases(self, filters, root_alias, plan): + """Make filter expressions use the exact same alias objects as our JOINs.""" + if not filters: + return filters + + # Map first-hop target selectable keysets -> the exact selectable object we JOINed with + alias_map = {} + for base_alias, _rel_attr, target_alias in plan.join_paths: + if base_alias is not root_alias: + continue + sel = getattr(target_alias, "selectable", None) + if sel is not None: + alias_map[frozenset(_selectable_keys(sel))] = sel + + if not alias_map: + return filters + + def replace(elem): + tbl = getattr(elem, "table", None) + if tbl is None: + return elem + keyset = frozenset(_selectable_keys(tbl)) + new_sel = alias_map.get(keyset) + if new_sel is None or new_sel is tbl: + return elem + + colkey = getattr(elem, "key", None) or getattr(elem, "name", None) + if not colkey: + return elem + try: + return getattr(new_sel.c, colkey) + except Exception: + return elem + + return [visitors.replacement_traverse(f, {}, replace) for f in filters] + + def _final_filters(self, root_alias, plan): + """ + Return filters where: + - root/to-one predicates are kept as SQLAlchemy expressions. + - first-hop collection predicates (CollPred) are rebuilt into a single + EXISTS via rel.any(...) with one alias per collection table. + """ + filters = list(plan.filters or []) + if not filters: + return [] + + # 1) Build a map of first-hop relationships: TABLE -> (rel_attr, target_cls) + coll_map = {} + for base_alias, rel_attr, target_alias in plan.join_paths: + if base_alias is not root_alias: + continue + prop = getattr(rel_attr, "property", None) + if not prop or not getattr(prop, "uselist", False): + continue + target_cls = prop.mapper.class_ + tbl = getattr(target_cls, "__table__", None) + if tbl is not None: + coll_map[tbl] = (rel_attr, target_cls) + + # 2) Split raw filters into normal SQLA and CollPreds (by target table) + normal_filters = [] + by_table: dict[Any, list[CollPred]] = {} + for f in filters: + if isinstance(f, CollPred): + by_table.setdefault(f.table, []).append(f) + else: + normal_filters.append(f) + + # 3) Rebuild each table group into ONE .any(...) using one alias + from sqlalchemy.orm import aliased + from sqlalchemy import and_ + + exists_filters = [] + for tbl, preds in by_table.items(): + if tbl not in coll_map: + # Safety: if it's not a first-hop collection, ignore or raise + continue + rel_attr, target_cls = coll_map[tbl] + ta = aliased(target_cls) + + built = [] + for p in preds: + col = getattr(ta, p.col_key) + op = p.op + val = p.value + if op == 'icontains': + built.append(col.ilike(f"%{val}%")) + elif op == 'eq': + built.append(col == val) + elif op == 'ne': + built.append(col != val) + elif op == 'in': + vs = val if isinstance(val, (list, tuple, set)) else [val] + built.append(col.in_(vs)) + elif op == 'nin': + vs = val if isinstance(val, (list, tuple, set)) else [val] + built.append(~col.in_(vs)) + elif op == 'lt': + built.append(col < val) + elif op == 'lte': + built.append(col <= val) + elif op == 'gt': + built.append(col > val) + elif op == 'gte': + built.append(col >= val) + else: + # unknown op — skip or raise + continue + + # enforce child soft delete inside the EXISTS + if hasattr(target_cls, "is_deleted"): + built.append(ta.is_deleted == False) + + crit = and_(*built) if built else None + exists_filters.append(rel_attr.of_type(ta).any(crit) if crit is not None + else rel_attr.of_type(ta).any()) + + # 4) Final filter list = normal SQLA filters + all EXISTS filters + return normal_filters + exists_filters + # ---- public read ops def page(self, params=None, *, page: int = 1, per_page: int = 50, include_total: bool = True): @@ -469,7 +641,9 @@ class CRUDService(Generic[T]): query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_soft_delete_criteria_for_children(query, plan, params) if plan.filters: - query = query.filter(*plan.filters) + filters = self._final_filters(root_alias, plan) + if filters: + query = query.filter(*filters) order_spec = self._extract_order_spec(root_alias, plan.order_by) limit = 50 if plan.limit is None else (None if plan.limit == 0 else plan.limit) @@ -529,7 +703,9 @@ class CRUDService(Generic[T]): if not bool(getattr(getattr(rel_attr, "property", None), "uselist", False)): base = base.join(target_alias, rel_attr.of_type(target_alias), isouter=True) if plan.filters: - base = base.filter(*plan.filters) + filters = self._final_filters(root_alias, plan) + if filters: + base = base.filter(*filters) # <-- use base, not query total = session.query(func.count()).select_from( base.order_by(None).distinct().subquery() ).scalar() or 0 @@ -556,7 +732,9 @@ class CRUDService(Generic[T]): query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_soft_delete_criteria_for_children(query, plan, params) if plan.filters: - query = query.filter(*plan.filters) + filters = self._final_filters(root_alias, plan) + if filters: + query = query.filter(*filters) query = query.filter(getattr(root_alias, "id") == id) query = self._apply_proj_opts(query, plan) @@ -577,7 +755,9 @@ class CRUDService(Generic[T]): query = self._apply_firsthop_strategies(query, root_alias, plan) query = self._apply_soft_delete_criteria_for_children(query, plan, params) if plan.filters: - query = query.filter(*plan.filters) + filters = self._final_filters(root_alias, plan) + if filters: + query = query.filter(*filters) order_by = plan.order_by paginating = (plan.limit is not None) or (plan.offset not in (None, 0)) diff --git a/crudkit/core/spec.py b/crudkit/core/spec.py index 4ec972f..bfd5f11 100644 --- a/crudkit/core/spec.py +++ b/crudkit/core/spec.py @@ -1,9 +1,17 @@ +from dataclasses import dataclass from typing import Any, List, Tuple, Set, Dict, Optional, Iterable from sqlalchemy import and_, asc, desc, or_ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import aliased, selectinload from sqlalchemy.orm.attributes import InstrumentedAttribute +@dataclass(frozen=True) +class CollPred: + table: Any + col_key: str + op: str + value: Any + OPERATORS = { 'eq': lambda col, val: col == val, 'lt': lambda col, val: col < val, @@ -68,16 +76,48 @@ class CRUDSpec: exprs = [] for col, join_path in pairs: - # Track eager path for each involved relationship chain if join_path: self.eager_paths.add(join_path) + + try: + cur_cls = self.model + names = list(join_path) + last_name = names[-1] + is_collection = False + for nm in names: + rel_attr = getattr(cur_cls, nm) + prop = rel_attr.property + cur_cls = prop.mapper.class_ + is_collection = bool(getattr(getattr(self.model, last_name), "property", None) + and getattr(getattr(self.model, last_name).property, "uselist", False)) + except Exception: + is_collection = False + + if is_collection: + target_cls = cur_cls + key = getattr(col, "key", None) or getattr(col, "name", None) + if key and hasattr(target_cls, key): + target_tbl = getattr(target_cls, "__table__", None) + if target_tbl is not None: + exprs.append(CollPred(table=target_tbl, col_key=key, op=op, value=value)) + continue + exprs.append(OPERATORS[op](col, value)) if not exprs: return None - if len(exprs) == 1: + + # If any CollPred is in exprs, do NOT or_ them. Keep it single for now. + if any(isinstance(x, CollPred) for x in exprs): + # If someone used a pipe 'relA.col|relB.col' that produced multiple CollPreds, + # keep the first or raise for now (your choice). + if len(exprs) > 1: + # raise NotImplementedError("OR across collection paths not supported yet") + exprs = [next(x for x in exprs if isinstance(x, CollPred))] return exprs[0] - return or_(*exprs) + + # Otherwise, standard SQLA clause(s) + return exprs[0] if len(exprs) == 1 else or_(*exprs) def _collect_filters(self, params: dict) -> list: """ diff --git a/crudkit/engines.py b/crudkit/engines.py index b420a8d..4e18fd5 100644 --- a/crudkit/engines.py +++ b/crudkit/engines.py @@ -1,7 +1,8 @@ +# engines.py from __future__ import annotations from typing import Type, Optional -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine, event +from sqlalchemy.orm import sessionmaker, raiseload, Mapper, RelationshipProperty from .backend import make_backend_info, BackendInfo from .config import Config, get_config from ._sqlite import apply_sqlite_pragmas @@ -12,15 +13,31 @@ def build_engine(config_cls: Type[Config] | None = None): apply_sqlite_pragmas(engine, config_cls.SQLITE_PRAGMAS) return engine +def _install_nplus1_guards(SessionMaker, *, strict: bool): + if not strict: + return + + @event.listens_for(SessionMaker, "do_orm_execute") + def _add_global_raiseload(execute_state): + stmt = execute_state.statement + # Only touch ORM statements (have column_descriptions) + if getattr(stmt, "column_descriptions", None): + execute_state.statement = stmt.options(raiseload("*")) + def build_sessionmaker(config_cls: Type[Config] | None = None, engine=None): config_cls = config_cls or get_config(None) engine = engine or build_engine(config_cls) - return sessionmaker(bind=engine, **config_cls.session_kwargs()) + SessionMaker = sessionmaker(bind=engine, **config_cls.session_kwargs()) + + # Toggle with a config flag; default off so you can turn it on when ready + strict = bool(getattr(config_cls, "STRICT_NPLUS1", False)) + _install_nplus1_guards(SessionMaker, strict=strict) + return SessionMaker class CRUDKitRuntime: """ Lightweight container so CRUDKit can be given either: - - prebuild engine/sessionmaker, or + - prebuilt engine/sessionmaker, or - a Config to build them lazily """ def __init__(self, *, engine=None, session_factory=None, config: Optional[Type[Config]] = None): diff --git a/crudkit/ui/fragments.py b/crudkit/ui/fragments.py index b4db83f..08903de 100644 --- a/crudkit/ui/fragments.py +++ b/crudkit/ui/fragments.py @@ -413,7 +413,7 @@ def _value_label_for_field(field: dict, mapper, values_map: dict, instance, sess if not rel_prop: return None - rid = _coerce_fk_value(values_map, instance, base) + rid = _coerce_fk_value(values_map, instance, base, rel_prop) rel_obj = _resolve_rel_obj(values_map, instance, base) label_spec = ( @@ -493,7 +493,7 @@ class _SafeObj: val = _get_loaded_attr(self._obj, name) return "" if val is None else _SafeObj(val) -def _coerce_fk_value(values: dict | None, instance: Any, base: str): +def _coerce_fk_value(values: dict | None, instance: Any, base: str, rel_prop: Optional[RelationshipProperty] = None): """ Resolve current selection for relationship 'base': 1) values['_id'] @@ -540,6 +540,25 @@ def _coerce_fk_value(values: dict | None, instance: Any, base: str): except Exception: pass + # Fallback: if we know the relationship, try its local FK column names + if rel_prop is not None: + try: + st = inspect(instance) if instance is not None else None + except Exception: + st = None + + # Try values[...] first + for col in getattr(rel_prop, "local_columns", []) or []: + key = getattr(col, "key", None) or getattr(col, "name", None) + if not key: + continue + if isinstance(values, dict) and key in values and values[key] not in (None, ""): + return values[key] + if set is not None: + attr = st.attrs.get(key) if hasattr(st, "attrs") else None + if attr is not None and attr.loaded_value is not NO_VALUE: + return attr.loaded_value + return None def _is_many_to_one(mapper, name: str) -> Optional[RelationshipProperty]: @@ -1136,7 +1155,7 @@ def render_form( base = name[:-3] rel_prop = mapper.relationships.get(base) if isinstance(rel_prop, RelationshipProperty) and rel_prop.direction.name == "MANYTOONE": - values_map[name] = _coerce_fk_value(values, instance, base) + values_map[name] = _coerce_fk_value(values, instance, base, rel_prop) # add rel_prop else: # Auto-generate path (your original behavior) @@ -1169,7 +1188,7 @@ def render_form( fk_fields.add(f"{base}_id") # NEW: set the current selection for this dropdown - values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base) + values_map[f"{base}_id"] = _coerce_fk_value(values, instance, base, prop) # Then plain columns for col in model_cls.__table__.columns: diff --git a/crudkit/ui/templates/form.html b/crudkit/ui/templates/form.html index 9046b65..f57074a 100644 --- a/crudkit/ui/templates/form.html +++ b/crudkit/ui/templates/form.html @@ -35,5 +35,5 @@ {% if submit_attrs %}{% for k,v in submit_attrs.items() %} {{k}}{% if v is not sameas true %}="{{ v }}"{% endif %} {% endfor %}{% endif %} - >{{ submit_label if label else 'Save' }} + >{{ submit_label if submit_label else 'Save' }}