diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index eacd35b0771f..731e7daf2acd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -674,6 +674,8 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
+  override def doCanonicalize(): SparkPlan = child.canonicalized
+
   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e8d1eccd329d..5916cbb7e681 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -25,7 +25,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.{aggregate, ScalarSubquery, SubqueryExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -113,6 +113,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("Reuse Subquery") {
+    Seq(true, false).foreach { reuse =>
+      withSQLConf(SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) {
+        val df = sql(
+          """
+            |SELECT (SELECT avg(key) FROM testData) + (SELECT avg(key) FROM testData)
+            |FROM testData
+            |LIMIT 1
+          """.stripMargin)
+
+        import scala.collection.mutable.ArrayBuffer
+        val subqueries = ArrayBuffer[SubqueryExec]()
+        df.queryExecution.executedPlan.transformAllExpressions {
+          case s @ ScalarSubquery(plan: SubqueryExec, _) =>
+            subqueries += plan
+            s
+        }
+
+        if (reuse) {
+          assert(subqueries.distinct.size == 1, "Subquery reusing not working correctly")
+        } else {
+          assert(subqueries.distinct.size == 2, "There should be 2 subqueries when not reusing")
+        }
+      }
+    }
+  }
+
   test("SPARK-6743: no columns from cache") {
     Seq(
       (83, 0, 38),