
Commit 1283c3d

cloud-fan authored and hvanhovell committed
[SPARK-20725][SQL] partial aggregate should behave correctly for sameResult
## What changes were proposed in this pull request?

For aggregate functions with `PartialMerge` or `Final` mode, the input is aggregate buffers instead of the actual children expressions. Since the actual children expressions do not affect the result, we should normalize their expr ids during canonicalization.

## How was this patch tested?

A new regression test.

Author: Wenchen Fan <[email protected]>

Closes #17964 from cloud-fan/tmp.
1 parent 3f98375 commit 1283c3d
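
For context, here is a sketch (not part of the patch; run in a 2.2-era spark-shell, output abridged, expr ids and split counts will vary) of how a simple sum is planned as a two-stage aggregate. The `Final`-mode `HashAggregate` above the exchange consumes the partial aggregation buffer rather than the original `id` column, which is why its nominal children expressions can safely be normalized:

```scala
val df = spark.range(10).agg(sum($"id"))
df.explain()
// == Physical Plan ==
// *HashAggregate(keys=[], functions=[sum(id#0L)])               <- Final mode
// +- Exchange SinglePartition
//    +- *HashAggregate(keys=[], functions=[partial_sum(id#0L)]) <- Partial mode
//       +- *Range (0, 10, step=1, splits=8)
```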

File tree

3 files changed, +26 −4 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 12 additions & 2 deletions

```diff
@@ -105,12 +105,22 @@ case class AggregateExpression(
   }

   // We compute the same thing regardless of our final result.
-  override lazy val canonicalized: Expression =
+  override lazy val canonicalized: Expression = {
+    val normalizedAggFunc = mode match {
+      // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate
+      // buffers, and the actual children of `aggregateFunction` are not used, so we
+      // normalize their expr ids here.
+      case PartialMerge | Final => aggregateFunction.transform {
+        case a: AttributeReference => a.withExprId(ExprId(0))
+      }
+      case Partial | Complete => aggregateFunction
+    }
+
     AggregateExpression(
-      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
       mode,
       isDistinct,
       ExprId(0))
+  }

   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
```
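
A minimal sketch of what this change buys, assuming 2.2-era catalyst internals (`AggregateExpression.apply`, `withExprId`, `semanticEquals`); this is illustrative, not code from the patch. Two `Final`-mode aggregates whose buffer attributes differ only in expr id should now canonicalize to the same tree:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, Sum}
import org.apache.spark.sql.types.LongType

// Two attributes identical except for their expr ids, standing in for the
// aggregate buffers produced by two separately-analyzed plans.
val bufA = AttributeReference("sum", LongType)(exprId = ExprId(1))
val bufB = AttributeReference("sum", LongType)(exprId = ExprId(2))

val aggA = AggregateExpression(Sum(bufA), Final, isDistinct = false)
val aggB = AggregateExpression(Sum(bufB), Final, isDistinct = false)

// semanticEquals compares canonicalized trees; with the normalization above,
// the differing expr ids no longer break equality in PartialMerge/Final mode.
assert(aggA.semanticEquals(aggB))
```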

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {

     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
-      case Some(e: Expression) => Some(transformExpression(e))
+      case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
@@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {

     productIterator.flatMap {
       case e: Expression => e :: Nil
-      case Some(e: Expression) => e :: Nil
+      case s: Some[_] => seqToExpressions(s.toSeq)
       case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
```
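
The QueryPlan change generalizes how `Option`-wrapped plan arguments are traversed: the old patterns only handled `Some(e: Expression)`, so an expression nested deeper inside a `Some` was silently skipped. A self-contained sketch of the idea, using simplified stand-in types rather than the real QueryPlan machinery:

```scala
sealed trait Expr
case class Attr(name: String) extends Expr

// Mirrors the fixed logic: recurse into Options and collections instead of
// matching only the exact shape Some(e: Expr).
def collectExprs(arg: Any): Seq[Expr] = arg match {
  case e: Expr => e :: Nil
  case s: Some[_] => collectExprs(s.get)
  case seq: Traversable[_] => seq.toSeq.flatMap(x => collectExprs(x))
  case _ => Nil
}

// Some(Seq(...)) was invisible to the old Some(e: Expr) pattern but is
// found by the recursive match:
assert(collectExprs(Some(Seq(Attr("a"), Attr("b")))) == Seq(Attr("a"), Attr("b")))
```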

sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala

Lines changed: 12 additions & 0 deletions

```diff
@@ -18,12 +18,14 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext

 /**
  * Tests for the sameResult function for [[SparkPlan]]s.
  */
 class SameResultSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._

   test("FileSourceScanExec: different orders of data filters and partition filters") {
     withTempPath { path =>
@@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext {
       df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
         .asInstanceOf[FileSourceScanExec]
   }
+
+  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+    val df1 = spark.range(10).agg(sum($"id"))
+    val df2 = spark.range(10).agg(sum($"id"))
+    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+    val df3 = spark.range(10).agg(sumDistinct($"id"))
+    val df4 = spark.range(10).agg(sumDistinct($"id"))
+    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+  }
 }
```
