Finally got order_by fixed so sorting actually works properly with relationships and "identifier."
This commit is contained in:
parent
8100d221a1
commit
f47fb6b505
6 changed files with 179 additions and 18 deletions
|
|
@ -74,6 +74,17 @@ def _related_predicate(Model, path_parts, op_key, value):
|
||||||
# wrap at this hop using the *attribute*, not the RelationshipProperty
|
# wrap at this hop using the *attribute*, not the RelationshipProperty
|
||||||
return attr.any(pred) if rel.uselist else attr.has(pred)
|
return attr.any(pred) if rel.uselist else attr.has(pred)
|
||||||
|
|
||||||
|
def split_sort_tokens(tokens):
|
||||||
|
simple, dotted = [], []
|
||||||
|
for tok in (tokens or []):
|
||||||
|
if not tok:
|
||||||
|
continue
|
||||||
|
key = tok.lstrip("-")
|
||||||
|
if ":" in key:
|
||||||
|
key = key.split(":", 1)[0]
|
||||||
|
(dotted if "." in key else simple).append(tok)
|
||||||
|
return simple, dotted
|
||||||
|
|
||||||
def build_query(Model, spec: QuerySpec, eager_policy=None):
|
def build_query(Model, spec: QuerySpec, eager_policy=None):
|
||||||
stmt = select(Model)
|
stmt = select(Model)
|
||||||
|
|
||||||
|
|
@ -102,11 +113,25 @@ def build_query(Model, spec: QuerySpec, eager_policy=None):
|
||||||
continue
|
continue
|
||||||
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
|
stmt = stmt.where(FILTER_OPS[op_key](col, val) if op_key else (col == val))
|
||||||
|
|
||||||
# order_by
|
simple_sorts, _ = split_sort_tokens(spec.order_by)
|
||||||
for key in spec.order_by:
|
|
||||||
desc_ = key.startswith("-")
|
for token in simple_sorts:
|
||||||
col = getattr(Model, key[1:] if desc_ else key)
|
direction = "asc"
|
||||||
stmt = stmt.order_by(desc(col) if desc_ else asc(col))
|
key = token
|
||||||
|
if token.startswith("-"):
|
||||||
|
direction = "desc"
|
||||||
|
key = token[1:]
|
||||||
|
if ":" in key:
|
||||||
|
key, d = key.rsplit(":", 1)
|
||||||
|
direction = "desc" if d.lower().startswith("d") else "asc"
|
||||||
|
|
||||||
|
if "." in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
col = getattr(Model, key, None)
|
||||||
|
if col is None:
|
||||||
|
continue
|
||||||
|
stmt = stmt.order_by(desc(col) if direction == "desc" else asc(col))
|
||||||
|
|
||||||
if not spec.order_by and spec.page and spec.per_page:
|
if not spec.order_by and spec.page and spec.per_page:
|
||||||
pk_cols = inspect(Model).primary_key
|
pk_cols = inspect(Model).primary_key
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from flask import Blueprint, request, render_template, abort, make_response
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import scoped_session
|
from sqlalchemy.orm import scoped_session
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
|
from sqlalchemy.sql.elements import UnaryExpression
|
||||||
from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric
|
from sqlalchemy.sql.sqltypes import Integer, Boolean, Date, DateTime, Float, Numeric
|
||||||
|
|
||||||
from ..dsl import QuerySpec
|
from ..dsl import QuerySpec
|
||||||
|
|
@ -115,7 +116,7 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na
|
||||||
page = request.args.get("page", type=int) or 1
|
page = request.args.get("page", type=int) or 1
|
||||||
per_page = request.args.get("per_page", type=int) or 20
|
per_page = request.args.get("per_page", type=int) or 20
|
||||||
|
|
||||||
expand = _collect_expand_from_paths(fields)
|
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
|
||||||
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
||||||
s = session(); svc = CrudService(s, default_eager_policy)
|
s = session(); svc = CrudService(s, default_eager_policy)
|
||||||
rows, _ = svc.list(Model, spec)
|
rows, _ = svc.list(Model, spec)
|
||||||
|
|
@ -134,7 +135,7 @@ def make_fragments_blueprint(db_session_factory, registry: Dict[str, Any], *, na
|
||||||
sort = request.args.get("sort")
|
sort = request.args.get("sort")
|
||||||
fields_csv = request.args.get("fields_csv") or "id,name"
|
fields_csv = request.args.get("fields_csv") or "id,name"
|
||||||
fields = _paths_from_csv(fields_csv)
|
fields = _paths_from_csv(fields_csv)
|
||||||
expand = _collect_expand_from_paths(fields)
|
expand = _collect_expand_from_paths(fields + ([sort.split(":")[0]] if sort else []))
|
||||||
|
|
||||||
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
spec = QuerySpec(filters=filters, order_by=[sort] if sort else [], page=page, per_page=per_page, expand=expand)
|
||||||
s = session(); svc = CrudService(s, default_eager_policy)
|
s = session(); svc = CrudService(s, default_eager_policy)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,93 @@
|
||||||
from sqlalchemy import func
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import func, asc
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, aliased
|
||||||
|
from sqlalchemy.inspection import inspect
|
||||||
|
from sqlalchemy.sql.elements import UnaryExpression
|
||||||
|
|
||||||
from .dsl import QuerySpec, build_query
|
from .dsl import QuerySpec, build_query, split_sort_tokens
|
||||||
from .eager import default_eager_policy
|
from .eager import default_eager_policy
|
||||||
|
|
||||||
|
def _dedup_order_by(ordering):
|
||||||
|
seen = set()
|
||||||
|
result = []
|
||||||
|
for ob in ordering:
|
||||||
|
col = ob.element if isinstance(ob, UnaryExpression) else ob
|
||||||
|
key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}"
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
result.append(ob)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _parse_sort_token(token: str):
|
||||||
|
token = token.strip()
|
||||||
|
direction = "asc"
|
||||||
|
if token.startswith('-'):
|
||||||
|
direction = "desc"
|
||||||
|
token = token[1:]
|
||||||
|
if ":" in token:
|
||||||
|
key, dirpart = token.rsplit(":", 1)
|
||||||
|
direction = "desc" if dirpart.lower().startswith("d") else "asc"
|
||||||
|
return key, direction
|
||||||
|
return token, direction
|
||||||
|
|
||||||
|
def _apply_dotted_ordering(stmt, Model, sort_tokens):
|
||||||
|
"""
|
||||||
|
stmt: a select(Model) statement
|
||||||
|
sort_tokens: list[str] like ["owner.identifier", "-brand.name"]
|
||||||
|
Returns: (stmt, alias_cache)
|
||||||
|
"""
|
||||||
|
mapper = inspect(Model)
|
||||||
|
alias_cache = {} # maps a path like "owner" or "brand" to its alias
|
||||||
|
|
||||||
|
for tok in sort_tokens:
|
||||||
|
path, direction = _parse_sort_token(tok)
|
||||||
|
parts = [p for p in path.split(".") if p]
|
||||||
|
if not parts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
entity = Model
|
||||||
|
current_mapper = mapper
|
||||||
|
alias_path = []
|
||||||
|
|
||||||
|
# Walk relationships for all but the last part
|
||||||
|
for rel_name in parts[:-1]:
|
||||||
|
rel = current_mapper.relationships.get(rel_name)
|
||||||
|
if rel is None:
|
||||||
|
# invalid sort key; skip quietly or raise
|
||||||
|
# raise ValueError(f"Unknown relationship {current_mapper.class_.__name__}.{rel_name}")
|
||||||
|
entity = None
|
||||||
|
break
|
||||||
|
|
||||||
|
alias_path.append(rel_name)
|
||||||
|
key = ".".join(alias_path)
|
||||||
|
|
||||||
|
if key in alias_cache:
|
||||||
|
entity_alias = alias_cache[key]
|
||||||
|
else:
|
||||||
|
# build an alias and join
|
||||||
|
entity_alias = aliased(rel.mapper.class_)
|
||||||
|
stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
|
||||||
|
alias_cache[key] = entity_alias
|
||||||
|
|
||||||
|
entity = entity_alias
|
||||||
|
current_mapper = inspect(rel.mapper.class_)
|
||||||
|
|
||||||
|
if entity is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
col_name = parts[-1]
|
||||||
|
# Validate final column
|
||||||
|
if col_name not in current_mapper.columns:
|
||||||
|
# raise ValueError(f"Unknown column {current_mapper.class_.__name__}.{col_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
col = getattr(entity, col_name) if entity is not Model else getattr(Model, col_name)
|
||||||
|
stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
|
||||||
class CrudService:
|
class CrudService:
|
||||||
def __init__(self, session: Session, eager_policy=default_eager_policy):
|
def __init__(self, session: Session, eager_policy=default_eager_policy):
|
||||||
self.s = session
|
self.s = session
|
||||||
|
|
@ -25,10 +108,44 @@ class CrudService:
|
||||||
|
|
||||||
def list(self, Model, spec: QuerySpec):
|
def list(self, Model, spec: QuerySpec):
|
||||||
stmt = build_query(Model, spec, self.eager_policy)
|
stmt = build_query(Model, spec, self.eager_policy)
|
||||||
count_stmt = stmt.with_only_columns(func.count()).order_by(None)
|
|
||||||
total = self.s.execute(count_stmt).scalar_one()
|
simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
|
||||||
|
if dotted_sorts:
|
||||||
|
stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)
|
||||||
|
|
||||||
|
# count query
|
||||||
|
pk = getattr(Model, "id") # adjust if not 'id'
|
||||||
|
count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None)
|
||||||
|
total = self.s.execute(
|
||||||
|
sa.select(sa.func.count()).select_from(count_base.subquery())
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
if spec.page and spec.per_page:
|
if spec.page and spec.per_page:
|
||||||
stmt = stmt.limit(spec.per_page).offset((spec.page - 1) * spec.per_page)
|
offset = (spec.page - 1) * spec.per_page
|
||||||
|
stmt = stmt.limit(spec.per_page).offset(offset)
|
||||||
|
|
||||||
|
# ---- ORDER BY handling ----
|
||||||
|
mapper = inspect(Model)
|
||||||
|
pk_cols = mapper.primary_key
|
||||||
|
|
||||||
|
# Gather all clauses added so far
|
||||||
|
ordering = list(stmt._order_by_clauses)
|
||||||
|
|
||||||
|
# Append pk tie-breakers if not already present
|
||||||
|
existing_cols = {
|
||||||
|
str(ob.element if isinstance(ob, UnaryExpression) else ob)
|
||||||
|
for ob in ordering
|
||||||
|
}
|
||||||
|
for c in pk_cols:
|
||||||
|
if str(c) not in existing_cols:
|
||||||
|
ordering.append(asc(c))
|
||||||
|
|
||||||
|
# Dedup *before* applying
|
||||||
|
ordering = _dedup_order_by(ordering)
|
||||||
|
|
||||||
|
# Now wipe old order_bys and set once
|
||||||
|
stmt = stmt.order_by(None).order_by(*ordering)
|
||||||
|
|
||||||
rows = self.s.execute(stmt).scalars().all()
|
rows = self.s.execute(stmt).scalars().all()
|
||||||
return rows, total
|
return rows, total
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,8 @@ if TYPE_CHECKING:
|
||||||
from .users import User
|
from .users import User
|
||||||
|
|
||||||
from crudkit import CrudMixin
|
from crudkit import CrudMixin
|
||||||
from sqlalchemy import Boolean, ForeignKey, Identity, Index, Integer, Unicode, DateTime, text
|
from sqlalchemy import Boolean, ForeignKey, Identity, Index, Integer, Unicode, DateTime, text, cast, func
|
||||||
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
@ -86,7 +87,7 @@ class Inventory(db.Model, ImageAttachable, CrudMixin):
|
||||||
|
|
||||||
return f"<Inventory({', '.join(parts)})>"
|
return f"<Inventory({', '.join(parts)})>"
|
||||||
|
|
||||||
@property
|
@hybrid_property
|
||||||
def identifier(self) -> str:
|
def identifier(self) -> str:
|
||||||
if self.name:
|
if self.name:
|
||||||
return f"Name: {self.name}"
|
return f"Name: {self.name}"
|
||||||
|
|
@ -97,6 +98,15 @@ class Inventory(db.Model, ImageAttachable, CrudMixin):
|
||||||
else:
|
else:
|
||||||
return f"ID: {self.id}"
|
return f"ID: {self.id}"
|
||||||
|
|
||||||
|
@identifier.expression
|
||||||
|
def identifier(cls):
|
||||||
|
return func.coalesce(
|
||||||
|
cls.name,
|
||||||
|
cls.barcode,
|
||||||
|
cls.serial,
|
||||||
|
cast(cls.id, Unicode)
|
||||||
|
)
|
||||||
|
|
||||||
def serialize(self) -> dict[str, Any]:
|
def serialize(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ if TYPE_CHECKING:
|
||||||
from .image import Image
|
from .image import Image
|
||||||
|
|
||||||
from crudkit import CrudMixin
|
from crudkit import CrudMixin
|
||||||
from sqlalchemy import Boolean, ForeignKey, Identity, Integer, Unicode, text
|
from sqlalchemy import Boolean, ForeignKey, Identity, Integer, Unicode, text, func
|
||||||
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from . import db
|
from . import db
|
||||||
|
|
@ -35,10 +36,18 @@ class User(db.Model, ImageAttachable, CrudMixin):
|
||||||
ui_eagerload = tuple()
|
ui_eagerload = tuple()
|
||||||
ui_order_cols = ('first_name', 'last_name',)
|
ui_order_cols = ('first_name', 'last_name',)
|
||||||
|
|
||||||
@property
|
@hybrid_property
|
||||||
def identifier(self) -> str:
|
def identifier(self) -> str:
|
||||||
return f"{self.first_name or ''} {self.last_name or ''}{', ' + (''.join(word[0].upper() for word in self.title.split())) if self.title else ''}".strip()
|
return f"{self.first_name or ''} {self.last_name or ''}{', ' + (''.join(word[0].upper() for word in self.title.split())) if self.title else ''}".strip()
|
||||||
|
|
||||||
|
@identifier.expression
|
||||||
|
def identifier(cls):
|
||||||
|
return func.concat(
|
||||||
|
func.coalesce(cls.first_name, ''),
|
||||||
|
' ',
|
||||||
|
func.coalesce(cls.last_name, '')
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, first_name: Optional[str] = None, last_name: Optional[str] = None,
|
def __init__(self, first_name: Optional[str] = None, last_name: Optional[str] = None,
|
||||||
title: Optional[str] = None,location_id: Optional[int] = None,
|
title: Optional[str] = None,location_id: Optional[int] = None,
|
||||||
supervisor_id: Optional[int] = None, staff: Optional[bool] = False,
|
supervisor_id: Optional[int] = None, staff: Optional[bool] = False,
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,6 @@
|
||||||
{% macro dynamic_table(id, headers=none, fields=none, entry_route=None, title=None, page=1, per_page=15, offset=0,
|
{% macro dynamic_table(id, headers=none, fields=none, entry_route=None, title=None, page=1, per_page=15, offset=0,
|
||||||
refresh_url=none, model=none, sort=none) %}
|
refresh_url=none, model=none, sort=none) %}
|
||||||
<!-- Table Fragment -->
|
<!-- Table Fragment -->
|
||||||
|
|
||||||
{% if title %}
|
{% if title %}
|
||||||
<label for="datatable-{{ id|default('table')|replace(' ', '-')|lower }}" class="form-label">{{ title }}</label>
|
<label for="datatable-{{ id|default('table')|replace(' ', '-')|lower }}" class="form-label">{{ title }}</label>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue