3 files changed, +26 −4 lines changed

catalyst/src/main/scala/org/apache/spark/sql/catalyst
core/src/test/scala/org/apache/spark/sql/execution

@@ -105,12 +105,22 @@ case class AggregateExpression(
   }

   // We compute the same thing regardless of our final result.
-  override lazy val canonicalized: Expression =
+  override lazy val canonicalized: Expression = {
+    val normalizedAggFunc = mode match {
+      // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate
+      // buffers, and the actual children of `aggregateFunction` are not used, so here we
+      // normalize the expr id.
+      case PartialMerge | Final => aggregateFunction.transform {
+        case a: AttributeReference => a.withExprId(ExprId(0))
+      }
+      case Partial | Complete => aggregateFunction
+    }
+
     AggregateExpression(
-      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
       mode,
       isDistinct,
       ExprId(0))
+  }

   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
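
Why the new PartialMerge/Final branch matters: sameResult compares plans through their canonicalized form, and in those modes the children of the aggregate function are bound to aggregate-buffer attributes whose ExprIds are freshly generated for each plan instance. Without normalizing those ids, two otherwise identical partial aggregates would never canonicalize to the same tree. Below is a minimal standalone sketch of the idea using hypothetical toy classes (Attr, Sum), not Spark's actual expression types:

// Toy model: attributes carry per-plan unique ids, much as Spark's
// AttributeReference carries an ExprId.
sealed trait Expr { def canonicalized: Expr }

case class Attr(name: String, id: Long) extends Expr {
  // Normalizing the id makes structurally equal trees compare equal.
  override def canonicalized: Expr = copy(id = 0L)
}

case class Sum(child: Expr) extends Expr {
  override def canonicalized: Expr = Sum(child.canonicalized)
}

object CanonicalizeDemo extends App {
  val plan1 = Sum(Attr("id", 42L)) // buffer attribute from one plan instance
  val plan2 = Sum(Attr("id", 99L)) // same attribute, fresh id in another
  assert(plan1 != plan2)                             // raw trees differ only by id
  assert(plan1.canonicalized == plan2.canonicalized) // canonical forms match
}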

@@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {

     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
-      case Some(e: Expression) => Some(transformExpression(e))
+      case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)

@@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {

     productIterator.flatMap {
       case e: Expression => e :: Nil
-      case Some(e: Expression) => e :: Nil
+      case s: Some[_] => seqToExpressions(s.toSeq)
       case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
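
Both QueryPlan changes generalize the Option handling in the same way: instead of matching only Some(e: Expression), the code recurses into whatever the Some wraps, so expressions nested inside an Option[Seq[...]], for example, are no longer silently skipped. A self-contained sketch of the pattern (toy Node/Leaf types standing in for Spark's expressions; Traversable as in Spark's Scala 2.11/2.12 era):

sealed trait Node
case class Leaf(n: Int) extends Node

object OptionRecursionDemo extends App {
  def transformNode(node: Node): Node = node match {
    case Leaf(x) => Leaf(x + 1)
  }

  // Mirrors the shape of recursiveTransform: recurse into the Option's
  // contents rather than requiring it to wrap a Node directly.
  def recursiveTransform(arg: Any): Any = arg match {
    case n: Node             => transformNode(n)
    case Some(value)         => Some(recursiveTransform(value))
    case seq: Traversable[_] => seq.map(recursiveTransform)
    case other               => other
  }

  assert(recursiveTransform(Some(Leaf(1))) == Some(Leaf(2)))
  // A Some(e: Expression)-style pattern would have missed this nested shape:
  assert(recursiveTransform(Some(Seq(Leaf(1), Leaf(2)))) == Some(Seq(Leaf(2), Leaf(3))))
}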

@@ -18,12 +18,14 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext

 /**
  * Tests for the sameResult function for [[SparkPlan]]s.
  */
 class SameResultSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._

   test("FileSourceScanExec: different orders of data filters and partition filters") {
     withTempPath { path =>

@@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
       .asInstanceOf[FileSourceScanExec]
   }
+
+  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+    val df1 = spark.range(10).agg(sum($"id"))
+    val df2 = spark.range(10).agg(sum($"id"))
+    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+    val df3 = spark.range(10).agg(sumDistinct($"id"))
+    val df4 = spark.range(10).agg(sumDistinct($"id"))
+    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+  }
 }
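
The same check can be reproduced interactively in spark-shell (Spark 2.x era API as used in this PR; sumDistinct has since been deprecated in favor of sum_distinct in newer releases). Before this fix, the second plan received fresh buffer-attribute ids, sameResult returned false, and features that depend on it, such as exchange reuse, could miss the match:

import org.apache.spark.sql.functions._
import spark.implicits._

val df1 = spark.range(10).agg(sum($"id"))
val df2 = spark.range(10).agg(sum($"id"))

// With the fix, two independently built but identical partial aggregates
// canonicalize to the same tree, so the plans compare equal.
println(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan)) // true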