Finally got order_by fixed so sorting actually works properly with relationships and "identifier."
This commit is contained in:
parent
8100d221a1
commit
f47fb6b505
6 changed files with 179 additions and 18 deletions
|
|
@ -74,6 +74,17 @@ def _related_predicate(Model, path_parts, op_key, value):
|
|||
# wrap at this hop using the *attribute*, not the RelationshipProperty
|
||||
return attr.any(pred) if rel.uselist else attr.has(pred)
|
||||
|
||||
def split_sort_tokens(tokens):
|
||||
simple, dotted = [], []
|
||||
for tok in (tokens or []):
|
||||
if not tok:
|
||||
continue
|
||||
key = tok.lstrip("-")
|
||||
if ":" in key:
|
||||
key = key.split(":", 1)[0]
|
||||
(dotted if "." in key else simple).append(tok)
|
||||
return simple, dotted
|
||||
|
||||
def build_query(Model, spec: QuerySpec, eager_policy=None):
|
||||
stmt = select(Model)
|
||||
|
||||
|
|
@ -102,11 +113,25 @@ def build_query(Model, spec: QuerySpec, eager_policy=None):
|
|||
continue
|
||||
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
|
||||
|
||||
# order_by
|
||||
for key in spec.order_by:
|
||||
desc_ = key.startswith("-")
|
||||
col = getattr(Model, key[1:] if desc_ else key)
|
||||
stmt = stmt.order_by(desc(col) if desc_ else asc(col))
|
||||
simple_sorts, _ = split_sort_tokens(spec.order_by)
|
||||
|
||||
for token in simple_sorts:
|
||||
direction = "asc"
|
||||
key = token
|
||||
if token.startswith("-"):
|
||||
direction = "desc"
|
||||
key = token[1:]
|
||||
if ":" in key:
|
||||
key, d = key.rsplit(":", 1)
|
||||
direction = "desc" if d.lower().startswith("d") else "asc"
|
||||
|
||||
if "." in key:
|
||||
continue
|
||||
|
||||
col = getattr(Model, key, None)
|
||||
if col is None:
|
||||
continue
|
||||
stmt = stmt.order_by(desc(col) if direction == "desc" else asc(col))
|
||||
|
||||
if not spec.order_by and spec.page and spec.per_page:
|
||||
pk_cols = inspect(Model).primary_key
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from flask import Blueprint, request, render_template, abort, make_response
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric
|
||||
|
||||
from ..dsl import QuerySpec
|
||||
|
|
@ -115,7 +116,7 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na
|
|||
page = request.args.get("page", type=int) or 1
|
||||
per_page = request.args.get("per_page", type=int) or 20
|
||||
|
||||
expand = _collect_expand_from_paths(fields)
|
||||
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
|
||||
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
||||
s = session(); svc = CrudService(s, default_eager_policy)
|
||||
rows, _ = svc.list(Model, spec)
|
||||
|
|
@ -134,7 +135,7 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na
|
|||
sort = request.args.get("sort")
|
||||
fields_csv = request.args.get("fields_csv") or "id,name"
|
||||
fields = _paths_from_csv(fields_csv)
|
||||
expand = _collect_expand_from_paths(fields)
|
||||
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
|
||||
|
||||
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
||||
s = session(); svc = CrudService(s, default_eager_policy)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,93 @@
|
|||
from sqlalchemy import func
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func, asc
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, aliased
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
|
||||
from .dsl import QuerySpec, build_query
|
||||
from .dsl import QuerySpec, build_query, split_sort_tokens
|
||||
from .eager import default_eager_policy
|
||||
|
||||
def _dedup_order_by(ordering):
|
||||
seen = set()
|
||||
result = []
|
||||
for ob in ordering:
|
||||
col = ob.element if isinstance(ob, UnaryExpression) else ob
|
||||
key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}"
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
result.append(ob)
|
||||
return result
|
||||
|
||||
def _parse_sort_token(token: str):
|
||||
token = token.strip()
|
||||
direction = "asc"
|
||||
if token.startswith('-'):
|
||||
direction = "desc"
|
||||
token = token[1:]
|
||||
if ":" in token:
|
||||
key, dirpart = token.rsplit(":", 1)
|
||||
direction = "desc" if dirpart.lower().startswith("d") else "asc"
|
||||
return key, direction
|
||||
return token, direction
|
||||
|
||||
def _apply_dotted_ordering(stmt, Model, sort_tokens):
|
||||
"""
|
||||
stmt: a select(Model) statement
|
||||
sort_tokens: list[str] like ["owner.identifier", "-brand.name"]
|
||||
Returns: (stmt, alias_cache)
|
||||
"""
|
||||
mapper = inspect(Model)
|
||||
alias_cache = {} # maps a path like "owner" or "brand" to its alias
|
||||
|
||||
for tok in sort_tokens:
|
||||
path, direction = _parse_sort_token(tok)
|
||||
parts = [p for p in path.split(".") if p]
|
||||
if not parts:
|
||||
continue
|
||||
|
||||
entity = Model
|
||||
current_mapper = mapper
|
||||
alias_path = []
|
||||
|
||||
# Walk relationships for all but the last part
|
||||
for rel_name in parts[:-1]:
|
||||
rel = current_mapper.relationships.get(rel_name)
|
||||
if rel is None:
|
||||
# invalid sort key; skip quietly or raise
|
||||
# raise ValueError(f"Unknown relationship {current_mapper.class_.__name__}.{rel_name}")
|
||||
entity = None
|
||||
break
|
||||
|
||||
alias_path.append(rel_name)
|
||||
key = ".".join(alias_path)
|
||||
|
||||
if key in alias_cache:
|
||||
entity_alias = alias_cache[key]
|
||||
else:
|
||||
# build an alias and join
|
||||
entity_alias = aliased(rel.mapper.class_)
|
||||
stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
|
||||
alias_cache[key] = entity_alias
|
||||
|
||||
entity = entity_alias
|
||||
current_mapper = inspect(rel.mapper.class_)
|
||||
|
||||
if entity is None:
|
||||
continue
|
||||
|
||||
col_name = parts[-1]
|
||||
# Validate final column
|
||||
if col_name not in current_mapper.columns:
|
||||
# raise ValueError(f"Unknown column {current_mapper.class_.__name__}.{col_name}")
|
||||
continue
|
||||
|
||||
col = getattr(entity, col_name) if entity is not Model else getattr(Model, col_name)
|
||||
stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())
|
||||
|
||||
return stmt
|
||||
|
||||
class CrudService:
|
||||
def __init__(self, session: Session, eager_policy=default_eager_policy):
|
||||
self.s = session
|
||||
|
|
@ -25,10 +108,44 @@ class CrudService:
|
|||
|
||||
def list(self, Model, spec: QuerySpec):
|
||||
stmt = build_query(Model, spec, self.eager_policy)
|
||||
count_stmt = stmt.with_only_columns(func.count()).order_by(None)
|
||||
total = self.s.execute(count_stmt).scalar_one()
|
||||
|
||||
simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
|
||||
if dotted_sorts:
|
||||
stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)
|
||||
|
||||
# count query
|
||||
pk = getattr(Model, "id") # adjust if not 'id'
|
||||
count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None)
|
||||
total = self.s.execute(
|
||||
sa.select(sa.func.count()).select_from(count_base.subquery())
|
||||
).scalar_one()
|
||||
|
||||
if spec.page and spec.per_page:
|
||||
stmt = stmt.limit(spec.per_page).offset((spec.page - 1) * spec.per_page)
|
||||
offset = (spec.page - 1) * spec.per_page
|
||||
stmt = stmt.limit(spec.per_page).offset(offset)
|
||||
|
||||
# ---- ORDER BY handling ----
|
||||
mapper = inspect(Model)
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
# Gather all clauses added so far
|
||||
ordering = list(stmt._order_by_clauses)
|
||||
|
||||
# Append pk tie-breakers if not already present
|
||||
existing_cols = {
|
||||
str(ob.element if isinstance(ob, UnaryExpression) else ob)
|
||||
for ob in ordering
|
||||
}
|
||||
for c in pk_cols:
|
||||
if str(c) not in existing_cols:
|
||||
ordering.append(asc(c))
|
||||
|
||||
# Dedup *before* applying
|
||||
ordering = _dedup_order_by(ordering)
|
||||
|
||||
# Now wipe old order_bys and set once
|
||||
stmt = stmt.order_by(None).order_by(*ordering)
|
||||
|
||||
rows = self.s.execute(stmt).scalars().all()
|
||||
return rows, total
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ if TYPE_CHECKING:
|
|||
from .users import User
|
||||
|
||||
from crudkit import CrudMixin
|
||||
from sqlalchemy import Boolean, ForeignKey, Identity, Index, Integer, Unicode, DateTime, text
|
||||
from sqlalchemy import Boolean, ForeignKey, Identity, Index, Integer, Unicode, DateTime, text, cast, func
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
import datetime
|
||||
|
||||
|
|
@ -86,7 +87,7 @@ class Inventory(db.Model, ImageAttachable, CrudMixin):
|
|||
|
||||
return f"<Inventory({', '.join(parts)})>"
|
||||
|
||||
@property
|
||||
@hybrid_property
|
||||
def identifier(self) -> str:
|
||||
if self.name:
|
||||
return f"Name: {self.name}"
|
||||
|
|
@ -97,6 +98,15 @@ class Inventory(db.Model, ImageAttachable, CrudMixin):
|
|||
else:
|
||||
return f"ID: {self.id}"
|
||||
|
||||
@identifier.expression
|
||||
def identifier(cls):
|
||||
return func.coalesce(
|
||||
cls.name,
|
||||
cls.barcode,
|
||||
cls.serial,
|
||||
cast(cls.id, Unicode)
|
||||
)
|
||||
|
||||
def serialize(self) -> dict[str, Any]:
|
||||
return {
|
||||
'id': self.id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ if TYPE_CHECKING:
|
|||
from .image import Image
|
||||
|
||||
from crudkit import CrudMixin
|
||||
from sqlalchemy import Boolean, ForeignKey, Identity, Integer, Unicode, text
|
||||
from sqlalchemy import Boolean, ForeignKey, Identity, Integer, Unicode, text, func
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from . import db
|
||||
|
|
@ -35,10 +36,18 @@ class User(db.Model, ImageAttachable, CrudMixin):
|
|||
ui_eagerload = tuple()
|
||||
ui_order_cols = ('first_name', 'last_name',)
|
||||
|
||||
@property
|
||||
@hybrid_property
|
||||
def identifier(self) -> str:
|
||||
return f"{self.first_name or ''} {self.last_name or ''}{', ' + (''.join(word[0].upper() for word in self.title.split())) if self.title else ''}".strip()
|
||||
|
||||
@identifier.expression
|
||||
def identifier(cls):
|
||||
return func.concat(
|
||||
func.coalesce(cls.first_name, ''),
|
||||
' ',
|
||||
func.coalesce(cls.last_name, '')
|
||||
)
|
||||
|
||||
def __init__(self, first_name: Optional[str] = None, last_name: Optional[str] = None,
|
||||
title: Optional[str] = None,location_id: Optional[int] = None,
|
||||
supervisor_id: Optional[int] = None, staff: Optional[bool] = False,
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@
|
|||
{% macro dynamic_table(id, headers=none, fields=none, entry_route=None, title=None, page=1, per_page=15, offset=0,
|
||||
refresh_url=none, model=none, sort=none) %}
|
||||
<!-- Table Fragment -->
|
||||
|
||||
{% if title %}
|
||||
<label for="datatable-{{ id|default('table')|replace(' ', '-')|lower }}" class="form-label">{{ title }}</label>
|
||||
{% endif %}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue