From 091db0b443cca4127e7bf1dc0b858dce3a3f457f Mon Sep 17 00:00:00 2001 From: Yaro Kasear Date: Fri, 29 Aug 2025 09:30:26 -0500 Subject: [PATCH] Lots of fixes coming from downstream. --- crudkit/dsl.py | 35 ++++- crudkit/eager.py | 73 ++++++--- crudkit/html/templates/crudkit/_macros.html | 159 ++++++++++++-------- crudkit/html/ui_fragments.py | 7 +- crudkit/service.py | 129 +++++++++++++++- 5 files changed, 305 insertions(+), 98 deletions(-) diff --git a/crudkit/dsl.py b/crudkit/dsl.py index 5a4e5b8..3953a83 100644 --- a/crudkit/dsl.py +++ b/crudkit/dsl.py @@ -74,6 +74,17 @@ def _related_predicate(Model, path_parts, op_key, value): # wrap at this hop using the *attribute*, not the RelationshipProperty return attr.any(pred) if rel.uselist else attr.has(pred) +def split_sort_tokens(tokens): + simple, dotted = [], [] + for tok in (tokens or []): + if not tok: + continue + key = tok.lstrip("-") + if ":" in key: + key = key.split(":", 1)[0] + (dotted if "." in key else simple).append(tok) + return simple, dotted + def build_query(Model, spec: QuerySpec, eager_policy=None): stmt = select(Model) @@ -102,11 +113,25 @@ def build_query(Model, spec: QuerySpec, eager_policy=None): continue stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val)) - # order_by - for key in spec.order_by: - desc_ = key.startswith("-") - col = getattr(Model, key[1:] if desc_ else key) - stmt = stmt.order_by(desc(col) if desc_ else asc(col)) + simple_sorts, _ = split_sort_tokens(spec.order_by) + + for token in simple_sorts: + direction = "asc" + key = token + if token.startswith("-"): + direction = "desc" + key = token[1:] + if ":" in key: + key, d = key.rsplit(":", 1) + direction = "desc" if d.lower().startswith("d") else "asc" + + if "." in key: + continue + + col = getattr(Model, key, None) + if col is None: + continue + stmt = stmt.order_by(desc(col) if direction == "desc" else asc(col)) if not spec.order_by and spec.page and spec.per_page: pk_cols = inspect(Model).primary_key diff --git a/crudkit/eager.py b/crudkit/eager.py index f32efc2..34e7884 100644 --- a/crudkit/eager.py +++ b/crudkit/eager.py @@ -1,8 +1,21 @@ -from typing import List +from __future__ import annotations +from typing import Iterable, List, Sequence, Set from sqlalchemy.inspection import inspect -from sqlalchemy.orm import Load, joinedload, selectinload +from sqlalchemy.orm import Load, joinedload, selectinload, RelationshipProperty -def default_eager_policy(Model, expand: List[str]) -> List[Load]: +class EagerConfig: + def __init__(self, strict: bool = False, max_depth: int = 4): + self.strict = strict + self.max_depth = max_depth + +def _rel(cls, name: str) -> RelationshipProperty | None: + return inspect(cls).relationships.get(name) + +def _is_expandable(rel: RelationshipProperty) -> bool: + # Skip dynamic or viewonly collections; they don’t support eagerload + return rel.lazy != "dynamic" + +def default_eager_policy(Model, expand: Sequence[str], cfg: EagerConfig | None = None) -> List[Load]: """ Heuristic: - many-to-one / one-to-one: joinedload @@ -12,31 +25,51 @@ def default_eager_policy(Model, expand: List[str]) -> List[Load]: if not expand: return [] + cfg = cfg or EagerConfig() + # normalize, dedupe, and prefer longer paths over their prefixes + raw: Set[str] = {p.strip() for p in expand if p and p.strip()} + # drop prefixes if a longer path exists (author, author.publisher -> keep only author.publisher) + pruned: Set[str] = set(raw) + for p in raw: + parts = p.split(".") + for i in range(1, len(parts)): + pruned.discard(".".join(parts[:i])) + opts: List[Load] = [] + seen: Set[tuple] = set() - for path in expand: + for path in sorted(pruned): parts = path.split(".") + if len(parts) > cfg.max_depth: + if cfg.strict: + raise ValueError(f"expand path too deep: {path} (max {cfg.max_depth})") + continue + current_model = Model - current_inspect = inspect(current_model) + # build the chain incrementally + loader: Load | None = None + ok = True - # first hop - rel = current_inspect.relationships.get(parts[0]) - if not rel: - continue # silently skip bad names - attr = getattr(current_model, parts[0]) - loader: Load = selectinload(attr) if rel.uselist else joinedload(attr) - current_model = rel.mapper.class_ - - # nested hops, if any - for name in parts[1:]: - current_inspect = inspect(current_model) - rel = current_inspect.relationships.get(name) - if not rel: + for i, name in enumerate(parts): + rel = _rel(current_model, name) + if not rel or not _is_expandable(rel): + ok = False break attr = getattr(current_model, name) - loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr) + if loader is None: + loader = selectinload(attr) if rel.uselist else joinedload(attr) + else: + loader = loader.selectinload(attr) if rel.uselist else loader.joinedload(attr) current_model = rel.mapper.class_ - opts.append(loader) + if not ok: + if cfg.strict: + raise ValueError(f"unknown or non-expandable relationship in expand path: {path}") + continue + + key = (tuple(parts),) + if loader is not None and key not in seen: + opts.append(loader) + seen.add(key) return opts diff --git a/crudkit/html/templates/crudkit/_macros.html b/crudkit/html/templates/crudkit/_macros.html index f8c504e..342c731 100644 --- a/crudkit/html/templates/crudkit/_macros.html +++ b/crudkit/html/templates/crudkit/_macros.html @@ -32,78 +32,109 @@ {%- endfor -%} {%- endmacro %} +{# helper: centralize the query string once #} +{% macro _q(model, page, per_page, sort, filters, fields_csv) -%} +/ui/{{ model }}/frag/rows +?page={{ page }}&per_page={{ per_page }} +{%- if sort %}&sort={{ sort }}{% endif -%} +{%- if fields_csv %}&fields_csv={{ fields_csv|urlencode }}{% endif -%} +{%- for k, v in (filters or {}).items() %}&{{ k }}={{ v|urlencode }}{% endfor -%} +{%- endmacro %} + {% macro pager(model, page, pages, per_page, sort, filters, fields_csv) -%} {% set p = page|int %} {% set pg = pages|int %} -{% set prev = [1, p-1]|max %} -{% set nxt = [pg, p+1]|min %} - -
- -
- -{%- endmacro %} \ No newline at end of file + {# one tiny listener to keep #pager-state in sync for every button #} + + {%- endmacro %} + + {% macro form(schema, action, method="POST", obj_id=None, hx=False, csrf_token=None) -%} +
+ {%- if csrf_token %}{% endif -%} + {%- if obj_id %}{% endif -%} + + + {%- for f in schema -%} +
+ {% set fid = 'f-' ~ f.name ~ '-' ~ (obj_id or 'new') %} + + {%- if f.type == "textarea" -%} + + {%- elif f.type == "select" -%} + + {%- elif f.type == "checkbox" -%} + + + {%- else -%} + + {%- endif -%} + {%- if f.help %}
{{ f.help }}
{% endif -%} +
+ {%- endfor -%} + +
+ +
+
+ {%- endmacro %} \ No newline at end of file diff --git a/crudkit/html/ui_fragments.py b/crudkit/html/ui_fragments.py index a75cd59..c2da1da 100644 --- a/crudkit/html/ui_fragments.py +++ b/crudkit/html/ui_fragments.py @@ -5,6 +5,7 @@ from flask import Blueprint, request, render_template, abort, make_response from sqlalchemy import select from sqlalchemy.orm import scoped_session from sqlalchemy.inspection import inspect +from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric from ..dsl import QuerySpec @@ -115,12 +116,12 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na page = request.args.get("page", type=int) or 1 per_page = request.args.get("per_page", type=int) or 20 - expand = _collect_expand_from_paths(fields) + expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else [])) spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand) s = session(); svc = CrudService(s, default_eager_policy) rows, _ = svc.list(Model, spec) - html = render_template("crudkit/rows.html", items=rows, fields=fields, getp=_getp) + html = render_template("crudkit/rows.html", items=rows, fields=fields, getp=_getp, model=model) return html @@ -134,7 +135,7 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na sort = request.args.get("sort") fields_csv = request.args.get("fields_csv") or "id,name" fields = _paths_from_csv(fields_csv) - expand = _collect_expand_from_paths(fields) + expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else [])) spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand) s = session(); svc = CrudService(s, default_eager_policy) diff --git a/crudkit/service.py b/crudkit/service.py index 1950d06..a62ae18 100644 --- a/crudkit/service.py +++ b/crudkit/service.py @@ -1,10 +1,93 @@ -from sqlalchemy import func +import sqlalchemy as sa +from sqlalchemy import func, asc from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, aliased +from sqlalchemy.inspection import inspect +from sqlalchemy.sql.elements import UnaryExpression -from .dsl import QuerySpec, build_query +from .dsl import QuerySpec, build_query, split_sort_tokens from .eager import default_eager_policy +def _dedup_order_by(ordering): + seen = set() + result = [] + for ob in ordering: + col = ob.element if isinstance(ob, UnaryExpression) else ob + key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}" + if key in seen: + continue + seen.add(key) + result.append(ob) + return result + +def _parse_sort_token(token: str): + token = token.strip() + direction = "asc" + if token.startswith('-'): + direction = "desc" + token = token[1:] + if ":" in token: + key, dirpart = token.rsplit(":", 1) + direction = "desc" if dirpart.lower().startswith("d") else "asc" + return key, direction + return token, direction + +def _apply_dotted_ordering(stmt, Model, sort_tokens): + """ + stmt: a select(Model) statement + sort_tokens: list[str] like ["owner.identifier", "-brand.name"] + Returns: (stmt, alias_cache) + """ + mapper = inspect(Model) + alias_cache = {} # maps a path like "owner" or "brand" to its alias + + for tok in sort_tokens: + path, direction = _parse_sort_token(tok) + parts = [p for p in path.split(".") if p] + if not parts: + continue + + entity = Model + current_mapper = mapper + alias_path = [] + + # Walk relationships for all but the last part + for rel_name in parts[:-1]: + rel = current_mapper.relationships.get(rel_name) + if rel is None: + # invalid sort key; skip quietly or raise + # raise ValueError(f"Unknown relationship {current_mapper.class_.__name__}.{rel_name}") + entity = None + break + + alias_path.append(rel_name) + key = ".".join(alias_path) + + if key in alias_cache: + entity_alias = alias_cache[key] + else: + # build an alias and join + entity_alias = aliased(rel.mapper.class_) + stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key)) + alias_cache[key] = entity_alias + + entity = entity_alias + current_mapper = inspect(rel.mapper.class_) + + if entity is None: + continue + + col_name = parts[-1] + # Validate final column + if col_name not in current_mapper.columns: + # raise ValueError(f"Unknown column {current_mapper.class_.__name__}.{col_name}") + continue + + col = getattr(entity, col_name) if entity is not Model else getattr(Model, col_name) + stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc()) + + return stmt + class CrudService: def __init__(self, session: Session, eager_policy=default_eager_policy): self.s = session @@ -25,10 +108,44 @@ class CrudService: def list(self, Model, spec: QuerySpec): stmt = build_query(Model, spec, self.eager_policy) - count_stmt = stmt.with_only_columns(func.count()).order_by(None) - total = self.s.execute(count_stmt).scalar_one() + + simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by) + if dotted_sorts: + stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts) + + # count query + pk = getattr(Model, "id") # adjust if not 'id' + count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None) + total = self.s.execute( + sa.select(sa.func.count()).select_from(count_base.subquery()) + ).scalar_one() + if spec.page and spec.per_page: - stmt = stmt.limit(spec.per_page).offset((spec.page - 1) * spec.per_page) + offset = (spec.page - 1) * spec.per_page + stmt = stmt.limit(spec.per_page).offset(offset) + + # ---- ORDER BY handling ---- + mapper = inspect(Model) + pk_cols = mapper.primary_key + + # Gather all clauses added so far + ordering = list(stmt._order_by_clauses) + + # Append pk tie-breakers if not already present + existing_cols = { + str(ob.element if isinstance(ob, UnaryExpression) else ob) + for ob in ordering + } + for c in pk_cols: + if str(c) not in existing_cols: + ordering.append(asc(c)) + + # Dedup *before* applying + ordering = _dedup_order_by(ordering) + + # Now wipe old order_bys and set once + stmt = stmt.order_by(None).order_by(*ordering) + rows = self.s.execute(stmt).scalars().all() return rows, total