Finally got order_by fixed so sorting actually works properly with relationships and "identifier."
This commit is contained in:
parent
8100d221a1
commit
f47fb6b505
6 changed files with 179 additions and 18 deletions
|
|
@ -1,10 +1,93 @@
|
|||
from sqlalchemy import func
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func, asc
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, aliased
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.sql.elements import UnaryExpression
|
||||
|
||||
from .dsl import QuerySpec, build_query
|
||||
from .dsl import QuerySpec, build_query, split_sort_tokens
|
||||
from .eager import default_eager_policy
|
||||
|
||||
def _dedup_order_by(ordering):
|
||||
seen = set()
|
||||
result = []
|
||||
for ob in ordering:
|
||||
col = ob.element if isinstance(ob, UnaryExpression) else ob
|
||||
key = f"{col}-{getattr(ob, 'modifier', '')}-{getattr(ob, 'operator', '')}"
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
result.append(ob)
|
||||
return result
|
||||
|
||||
def _parse_sort_token(token: str):
|
||||
token = token.strip()
|
||||
direction = "asc"
|
||||
if token.startswith('-'):
|
||||
direction = "desc"
|
||||
token = token[1:]
|
||||
if ":" in token:
|
||||
key, dirpart = token.rsplit(":", 1)
|
||||
direction = "desc" if dirpart.lower().startswith("d") else "asc"
|
||||
return key, direction
|
||||
return token, direction
|
||||
|
||||
def _apply_dotted_ordering(stmt, Model, sort_tokens):
|
||||
"""
|
||||
stmt: a select(Model) statement
|
||||
sort_tokens: list[str] like ["owner.identifier", "-brand.name"]
|
||||
Returns: (stmt, alias_cache)
|
||||
"""
|
||||
mapper = inspect(Model)
|
||||
alias_cache = {} # maps a path like "owner" or "brand" to its alias
|
||||
|
||||
for tok in sort_tokens:
|
||||
path, direction = _parse_sort_token(tok)
|
||||
parts = [p for p in path.split(".") if p]
|
||||
if not parts:
|
||||
continue
|
||||
|
||||
entity = Model
|
||||
current_mapper = mapper
|
||||
alias_path = []
|
||||
|
||||
# Walk relationships for all but the last part
|
||||
for rel_name in parts[:-1]:
|
||||
rel = current_mapper.relationships.get(rel_name)
|
||||
if rel is None:
|
||||
# invalid sort key; skip quietly or raise
|
||||
# raise ValueError(f"Unknown relationship {current_mapper.class_.__name__}.{rel_name}")
|
||||
entity = None
|
||||
break
|
||||
|
||||
alias_path.append(rel_name)
|
||||
key = ".".join(alias_path)
|
||||
|
||||
if key in alias_cache:
|
||||
entity_alias = alias_cache[key]
|
||||
else:
|
||||
# build an alias and join
|
||||
entity_alias = aliased(rel.mapper.class_)
|
||||
stmt = stmt.outerjoin(entity_alias, getattr(entity, rel.key))
|
||||
alias_cache[key] = entity_alias
|
||||
|
||||
entity = entity_alias
|
||||
current_mapper = inspect(rel.mapper.class_)
|
||||
|
||||
if entity is None:
|
||||
continue
|
||||
|
||||
col_name = parts[-1]
|
||||
# Validate final column
|
||||
if col_name not in current_mapper.columns:
|
||||
# raise ValueError(f"Unknown column {current_mapper.class_.__name__}.{col_name}")
|
||||
continue
|
||||
|
||||
col = getattr(entity, col_name) if entity is not Model else getattr(Model, col_name)
|
||||
stmt = stmt.order_by(col.desc() if direction == "desc" else col.asc())
|
||||
|
||||
return stmt
|
||||
|
||||
class CrudService:
|
||||
def __init__(self, session: Session, eager_policy=default_eager_policy):
|
||||
self.s = session
|
||||
|
|
@ -25,10 +108,44 @@ class CrudService:
|
|||
|
||||
def list(self, Model, spec: QuerySpec):
|
||||
stmt = build_query(Model, spec, self.eager_policy)
|
||||
count_stmt = stmt.with_only_columns(func.count()).order_by(None)
|
||||
total = self.s.execute(count_stmt).scalar_one()
|
||||
|
||||
simple_sorts, dotted_sorts = split_sort_tokens(spec.order_by)
|
||||
if dotted_sorts:
|
||||
stmt = _apply_dotted_ordering(stmt, Model, dotted_sorts)
|
||||
|
||||
# count query
|
||||
pk = getattr(Model, "id") # adjust if not 'id'
|
||||
count_base = stmt.with_only_columns(sa.distinct(pk)).order_by(None)
|
||||
total = self.s.execute(
|
||||
sa.select(sa.func.count()).select_from(count_base.subquery())
|
||||
).scalar_one()
|
||||
|
||||
if spec.page and spec.per_page:
|
||||
stmt = stmt.limit(spec.per_page).offset((spec.page - 1) * spec.per_page)
|
||||
offset = (spec.page - 1) * spec.per_page
|
||||
stmt = stmt.limit(spec.per_page).offset(offset)
|
||||
|
||||
# ---- ORDER BY handling ----
|
||||
mapper = inspect(Model)
|
||||
pk_cols = mapper.primary_key
|
||||
|
||||
# Gather all clauses added so far
|
||||
ordering = list(stmt._order_by_clauses)
|
||||
|
||||
# Append pk tie-breakers if not already present
|
||||
existing_cols = {
|
||||
str(ob.element if isinstance(ob, UnaryExpression) else ob)
|
||||
for ob in ordering
|
||||
}
|
||||
for c in pk_cols:
|
||||
if str(c) not in existing_cols:
|
||||
ordering.append(asc(c))
|
||||
|
||||
# Dedup *before* applying
|
||||
ordering = _dedup_order_by(ordering)
|
||||
|
||||
# Now wipe old order_bys and set once
|
||||
stmt = stmt.order_by(None).order_by(*ordering)
|
||||
|
||||
rows = self.s.execute(stmt).scalars().all()
|
||||
return rows, total
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue