Skip to content

Commit ab2341a

Browse files
committed
support order by non-attribute grouping expression on Aggregate
1 parent edc5095 commit ab2341a

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -576,43 +576,47 @@ class Analyzer(
576576
filter
577577
}
578578

579-
case sort @ Sort(sortOrder, global,
580-
aggregate @ Aggregate(grouping, originalAggExprs, child))
579+
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
581580
if aggregate.resolved && !sort.resolved =>
582581

583582
// Try resolving the ordering as though it is in the aggregate clause.
584583
try {
585-
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
586-
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
587-
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
588-
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
589-
590-
// Expressions that have an aggregate can be pushed down.
591-
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
592-
593-
// Attribute references, that are missing from the order but are present in the grouping
594-
// expressions can also be pushed down.
595-
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
596-
val missingAttributes = requiredAttributes -- aggregate.outputSet
597-
val validPushdownAttributes =
598-
missingAttributes.filter(a => grouping.exists(a.semanticEquals))
599-
600-
// If resolution was successful and we see the ordering either has an aggregate in it or
601-
// it is missing something that is projected away by the aggregate, add the ordering
602-
// the original aggregate operator.
603-
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
604-
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
605-
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
606-
}
607-
val aggExprsWithOrdering: Seq[NamedExpression] =
608-
resolvedAggregateOrdering ++ originalAggExprs
609-
610-
Project(aggregate.output,
611-
Sort(evaluatedOrderings, global,
612-
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
613-
} else {
614-
sort
584+
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
585+
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
586+
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
587+
val resolvedAliasedOrdering: Seq[Alias] =
588+
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
589+
590+
// If we pass the analysis check, then the ordering expressions should only reference
591+
// aggregate expressions or grouping expressions, and it's safe to push them down to
592+
// Aggregate.
593+
checkAnalysis(resolvedAggregate)
594+
595+
val originalAggExprs = aggregate.aggregateExpressions.map(
596+
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
597+
598+
// If the ordering expression equals an original aggregate expression, we don't need
599+
// to push down this ordering expression and can reference the original aggregate
600+
// expression instead.
601+
val needsPushDown = ArrayBuffer.empty[NamedExpression]
602+
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
603+
case (evaluated, order) =>
604+
val index = originalAggExprs.indexWhere {
605+
case Alias(child, _) => child semanticEquals evaluated.child
606+
case other => other semanticEquals evaluated.child
607+
}
608+
609+
if (index == -1) {
610+
needsPushDown += evaluated
611+
order.copy(child = evaluated.toAttribute)
612+
} else {
613+
order.copy(child = originalAggExprs(index).toAttribute)
614+
}
615615
}
616+
617+
Project(aggregate.output,
618+
Sort(evaluatedOrderings, global,
619+
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
616620
} catch {
617621
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
618622
// just return the original plan.
@@ -621,9 +625,7 @@ class Analyzer(
621625
}
622626

623627
protected def containsAggregate(condition: Expression): Boolean = {
624-
condition
625-
.collect { case ae: AggregateExpression => ae }
626-
.nonEmpty
628+
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
627629
}
628630
}
629631

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,4 +1745,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
17451745
df1.withColumn("diff", lit(0)))
17461746
}
17471747
}
1748+
1749+
test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
1750+
withTempTable("src") {
1751+
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
1752+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
1753+
Seq(Row(1), Row(1)))
1754+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
1755+
Seq(Row(1), Row(1)))
1756+
}
1757+
}
17481758
}

0 commit comments

Comments
 (0)