diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala index 621c063e5a7d8..52e24cff825d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -51,7 +51,7 @@ case class InsertAdaptiveSparkPlan( // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. // Fall back to non-AQE mode if AQE is not supported in any of the sub-queries. val subqueryMap = buildSubqueryMap(plan) - val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap) + val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap, conf.subqueryReuseEnabled) val preprocessingRules = Seq( planSubqueriesRule) // Run pre-processing rules. @@ -112,21 +112,24 @@ case class InsertAdaptiveSparkPlan( * For each sub-query, generate the adaptive execution plan for each sub-query by applying this * rule, or reuse the execution plan from another sub-query of the same semantics if possible. 
*/ - private def buildSubqueryMap(plan: SparkPlan): Map[Long, SubqueryExec] = { - val subqueryMap = mutable.HashMap.empty[Long, SubqueryExec] + private def buildSubqueryMap(plan: SparkPlan): Map[Long, mutable.Queue[SubqueryExec]] = { + val reuseSubquery = conf.subqueryReuseEnabled + val subqueryMap = mutable.HashMap.empty[Long, mutable.Queue[SubqueryExec]] plan.foreach(_.expressions.foreach(_.foreach { case expressions.ScalarSubquery(p, _, exprId) - if !subqueryMap.contains(exprId.id) => + if !(reuseSubquery && subqueryMap.contains(exprId.id)) => val executedPlan = compileSubquery(p) verifyAdaptivePlan(executedPlan, p) val subquery = SubqueryExec(s"subquery${exprId.id}", executedPlan) - subqueryMap.put(exprId.id, subquery) + val subqueries = subqueryMap.getOrElseUpdate(exprId.id, mutable.Queue()) + subqueries.enqueue(subquery) case expressions.InSubquery(_, ListQuery(query, _, exprId, _)) - if !subqueryMap.contains(exprId.id) => + if !(reuseSubquery && subqueryMap.contains(exprId.id)) => val executedPlan = compileSubquery(query) verifyAdaptivePlan(executedPlan, query) val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan) - subqueryMap.put(exprId.id, subquery) + val subqueries = subqueryMap.getOrElseUpdate(exprId.id, mutable.Queue()) + subqueries.enqueue(subquery) case _ => })) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index f845b6b16ee3a..9389183d6c287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -17,18 +17,27 @@ package org.apache.spark.sql.execution.adaptive +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, ListQuery, Literal} 
import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryExec} -case class PlanAdaptiveSubqueries(subqueryMap: Map[Long, SubqueryExec]) extends Rule[SparkPlan] { +case class PlanAdaptiveSubqueries( + subqueryMap: Map[Long, mutable.Queue[SubqueryExec]], + reuseSubquery: Boolean) extends Rule[SparkPlan] { + + private def subqueryExec(id: Long): SubqueryExec = { + val subqueries = subqueryMap(id) + if (reuseSubquery) subqueries.head else subqueries.dequeue() + } def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case expressions.ScalarSubquery(_, _, exprId) => - execution.ScalarSubquery(subqueryMap(exprId.id), exprId) + execution.ScalarSubquery(subqueryExec(exprId.id), exprId) case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) => val expr = if (values.length == 1) { values.head @@ -39,7 +48,7 @@ case class PlanAdaptiveSubqueries(subqueryMap: Map[Long, SubqueryExec]) extends } ) } - InSubqueryExec(expr, subqueryMap(exprId.id), exprId) + InSubqueryExec(expr, subqueryExec(exprId.id), exprId) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ff8f94c68c5ee..9d2ce26246503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} -import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec} +import org.apache.spark.sql.execution.{BaseSubqueryExec, ColumnarToRowExec, ExecSubqueryExpression, FileSourceScanExec, 
InputAdapter, ReusedSubqueryExec, ScalarSubquery, SubqueryExec, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FileScanRDD import org.apache.spark.sql.internal.SQLConf @@ -1646,4 +1646,24 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df, df2) checkAnswer(df, Nil) } + + test("SPARK-31206: AQE should not use same SubqueryExec when reuse is off") { + Seq(true, false).foreach { reuse => + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> reuse.toString) { + withTempView("t1", "t2") { + spark.range(10).selectExpr("id as a").createOrReplaceTempView("t1") + spark.range(10).selectExpr("id as b").createOrReplaceTempView("t2") + val plan = stripAQEPlan(spark.sql( + "select b as b1, b as b2 from (select a, (select b from t2 where b = 51) b from t1)") + .queryExecution.executedPlan) + val subqueries = collectInPlanAndSubqueries(plan) { + case subquery: BaseSubqueryExec => subquery + } + assert(subqueries.size == 2) + assert(subqueries.map(_.id).toSet.size === (if (reuse) 1 else 2)) + } + } + } + } }