diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala index c02e48c9815fa..f8a6f1d0d8cbb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -25,4 +25,3 @@ package org.apache.spark * (may be inexact due to use of compressed map statuses) */ private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) - extends Serializable diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f1044e1a32f98..f625581d40533 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -94,11 +94,12 @@ class QueryExecution( * row format conversions as needed. */ protected def prepareForExecution(plan: SparkPlan): SparkPlan = { - if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - adaptivePreparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} + val rules = if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { + adaptivePreparations } else { - preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} + preparations } + rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} } /** A sequence of rules that will be applied in order to the physical plan before execution. */ @@ -109,14 +110,16 @@ class QueryExecution( ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf)) + // With adaptive execution, whole stage codegen will be done inside `QueryStagePlanner`. protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq( PlanSubqueries(sparkSession), EnsureRequirements(sparkSession.sessionState.conf), + ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf), // PlanQueryStage needs to be the last rule because it divides the plan into multiple sub-trees - // by inserting leaf node QueryStageInput. Transforming the plan after applying this rule will + // by inserting leaf node QueryStage. Transforming the plan after applying this rule will // only transform nodes in a sub-tree.
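To make the ordering constraint above concrete, here is a self-contained sketch outside the patch (MiniPlan, Op and StageLeaf are made-up stand-ins for SparkPlan, physical operators and QueryStage, not Spark classes): once a sub-tree is hidden behind a leaf wrapper, a later transformUp can no longer reach it.

// Illustrative only; compiles on its own and does not use any Spark API.
object RuleOrderingSketch {
  sealed trait MiniPlan {
    def children: Seq[MiniPlan]
    def withChildren(newChildren: Seq[MiniPlan]): MiniPlan
    def transformUp(rule: PartialFunction[MiniPlan, MiniPlan]): MiniPlan = {
      val afterChildren = withChildren(children.map(_.transformUp(rule)))
      rule.applyOrElse(afterChildren, identity[MiniPlan])
    }
  }
  case class Op(name: String, children: Seq[MiniPlan] = Nil) extends MiniPlan {
    override def withChildren(newChildren: Seq[MiniPlan]): MiniPlan = copy(children = newChildren)
  }
  // Like QueryStage: a leaf node that hides the wrapped sub-tree from further transforms.
  case class StageLeaf(hidden: MiniPlan) extends MiniPlan {
    override def children: Seq[MiniPlan] = Nil
    override def withChildren(newChildren: Seq[MiniPlan]): MiniPlan = this
  }

  def main(args: Array[String]): Unit = {
    val plan = Op("Project", Seq(Op("Exchange", Seq(Op("Scan")))))
    // "PlanQueryStage": wrap every Exchange in a leaf stage.
    val staged = plan.transformUp { case e @ Op("Exchange", _) => StageLeaf(e) }
    // A rule applied afterwards never sees anything below the stage boundary.
    val after = staged.transformUp { case Op("Scan", cs) => Op("OptimizedScan", cs) }
    assert(after == staged) // the Scan under the stage leaf was not touched
  }
}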
- PlanQueryStage(sparkSession.sessionState.conf)) + PlanQueryStage(sparkSession)) def simpleString: String = withRedaction { val concat = new StringConcat() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index ca46b1e940e46..53b5fc305a330 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.adaptive.QueryStageInput +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlan, QueryStage} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.sql.internal.SQLConf @@ -53,7 +53,8 @@ private[execution] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { val children = plan match { case ReusedExchangeExec(_, child) => child :: Nil - case i: QueryStageInput => i.childStage :: Nil + case a: AdaptiveSparkPlan => a.resultStage.plan :: Nil + case stage: QueryStage => stage.plan :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala new file mode 100644 index 0000000000000..32afea26f239d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate + +/** + * A root node to trigger query stages and execute the query plan adaptively. It incrementally + * updates the query plan when a query stage is materialized and provides accurate runtime + * statistics. 
+ */ +case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSession) + extends LeafExecNode{ + + override def output: Seq[Attribute] = initialPlan.output + + @volatile private var currentQueryStage: QueryStage = initialPlan + @volatile private var error: Throwable = null + private val readyLock = new CountDownLatch(1) + + private def replaceStage(oldStage: QueryStage, newStage: QueryStage): QueryStage = { + if (oldStage.id == newStage.id) { + newStage + } else { + val newPlanForOldStage = oldStage.plan.transform { + case q: QueryStage => replaceStage(q, newStage) + } + oldStage.withNewPlan(newPlanForOldStage) + } + } + + private def createCallback(executionId: Option[Long]): QueryStageTriggerCallback = { + new QueryStageTriggerCallback { + override def onStageUpdated(stage: QueryStage): Unit = { + updateCurrentQueryStage(stage, executionId) + if (stage.isInstanceOf[ResultQueryStage]) readyLock.countDown() + } + + override def onStagePlanningFailed(stage: QueryStage, e: Throwable): Unit = { + error = new RuntimeException( + s""" + |Fail to plan stage ${stage.id}: + |${stage.plan.treeString} + """.stripMargin, e) + readyLock.countDown() + } + + override def onStageMaterializingFailed(stage: QueryStage, e: Throwable): Unit = { + error = new RuntimeException( + s""" + |Fail to materialize stage ${stage.id}: + |${stage.plan.treeString} + """.stripMargin, e) + readyLock.countDown() + } + + override def onError(e: Throwable): Unit = { + error = e + readyLock.countDown() + } + } + } + + private def updateCurrentQueryStage(newStage: QueryStage, executionId: Option[Long]): Unit = { + currentQueryStage = replaceStage(currentQueryStage, newStage) + executionId.foreach { id => + session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( + id, + SQLExecution.getQueryExecution(id).toString, + SparkPlanInfo.fromSparkPlan(currentQueryStage))) + } + } + + def resultStage: ResultQueryStage = { + if (readyLock.getCount > 0) { + val sc = session.sparkContext + val executionId = Option(sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(_.toLong) + val trigger = new QueryStageTrigger(session, createCallback(executionId)) + trigger.start() + trigger.trigger(initialPlan) + readyLock.await() + trigger.stop() + } + + if (error != null) throw error + currentQueryStage.asInstanceOf[ResultQueryStage] + } + + override def executeCollect(): Array[InternalRow] = resultStage.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = resultStage.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = resultStage.executeToIterator() + override def doExecute(): RDD[InternalRow] = resultStage.execute() + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int): Unit = { + currentQueryStage.generateTreeString( + depth, lastChildren, append, verbose, "", false, maxFields) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala index 0f7ab622b75b9..86ec58c6c77cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala @@ -17,65 +17,41 @@ package org.apache.spark.sql.execution.adaptive -import scala.collection.mutable -import 
scala.collection.mutable.ArrayBuffer - +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} /** - * Divide the spark plan into multiple QueryStages. For each Exchange in the plan, it adds a - * QueryStage and a QueryStageInput. If reusing Exchange is enabled, it finds duplicated exchanges - * and uses the same QueryStage for all the references. Note this rule must be run after - * EnsureRequirements rule. The rule divides the plan into multiple sub-trees as QueryStageInput - * is a leaf node. Transforming the plan after applying this rule will only transform node in a - * sub-tree. + * Divide the spark plan into multiple QueryStages. For each Exchange in the plan, it wraps it with + * a [[QueryStage]]. At the end it adds an [[AdaptiveSparkPlan]] at the top, which will drive the + * execution of query stages. */ -case class PlanQueryStage(conf: SQLConf) extends Rule[SparkPlan] { +case class PlanQueryStage(session: SparkSession) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { - - val newPlan = if (!conf.exchangeReuseEnabled) { - plan.transformUp { - case e: ShuffleExchangeExec => - ShuffleQueryStageInput(ShuffleQueryStage(e), e.output) - case e: BroadcastExchangeExec => - BroadcastQueryStageInput(BroadcastQueryStage(e), e.output) - } - } else { - // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls. - val stages = mutable.HashMap[StructType, ArrayBuffer[QueryStage]]() - - plan.transformUp { - case exchange: Exchange => - val sameSchema = stages.getOrElseUpdate(exchange.schema, ArrayBuffer[QueryStage]()) - val samePlan = sameSchema.find { s => - exchange.sameResult(s.child) - } - if (samePlan.isDefined) { - // Keep the output of this exchange, the following plans require that to resolve - // attributes. - exchange match { - case e: ShuffleExchangeExec => ShuffleQueryStageInput( - samePlan.get.asInstanceOf[ShuffleQueryStage], exchange.output) - case e: BroadcastExchangeExec => BroadcastQueryStageInput( - samePlan.get.asInstanceOf[BroadcastQueryStage], exchange.output) - } - } else { - val queryStageInput = exchange match { - case e: ShuffleExchangeExec => - ShuffleQueryStageInput(ShuffleQueryStage(e), e.output) - case e: BroadcastExchangeExec => - BroadcastQueryStageInput(BroadcastQueryStage(e), e.output) - } - sameSchema += queryStageInput.childStage - queryStageInput - } - } + var id = 0 + val exchangeToQueryStage = new java.util.IdentityHashMap[Exchange, QueryStage] + val planWithStages = plan.transformUp { + case e: ShuffleExchangeExec => + val queryStage = ShuffleQueryStage(id, e) + id += 1 + exchangeToQueryStage.put(e, queryStage) + queryStage + case e: BroadcastExchangeExec => + val queryStage = BroadcastQueryStage(id, e) + id += 1 + exchangeToQueryStage.put(e, queryStage) + queryStage + // The `ReusedExchangeExec` was added in the rule `ReuseExchange`, via transforming up the + // query plan. This rule also transform up the query plan, so when we hit `ReusedExchangeExec` + // here, the exchange being reused must already be hit before and there should be an entry + // for it in `exchangeToQueryStage`. 
+ case e: ReusedExchangeExec => + val existingQueryStage = exchangeToQueryStage.get(e.child) + assert(existingQueryStage != null, "The exchange being reused should be hit before.") + ReusedQueryStage(existingQueryStage, e.output) } - ResultQueryStage(newPlan) + AdaptiveSparkPlan(ResultQueryStage(id, planWithStages), session) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala index 617f80fde6ff6..0994651c20c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala @@ -17,159 +17,56 @@ package org.apache.spark.sql.execution.adaptive -import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.Duration +import scala.concurrent.Future import org.apache.spark.MapOutputStatistics -import org.apache.spark.broadcast +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange._ -import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate -import org.apache.spark.util.ThreadUtils /** - * In adaptive execution mode, an execution plan is divided into multiple QueryStages. Each - * QueryStage is a sub-tree that runs in a single stage. Before executing current stage, we will - * first submit all its child stages, wait for their completions and collect their statistics. - * Based on the collected data, we can potentially optimize the execution plan in current stage, - * change the number of reducer and do other optimizations. + * In adaptive execution mode, an execution plan is divided into multiple QueryStages w.r.t. the + * exchange as boundary. Each QueryStage is a sub-tree that runs in a single Spark stage. */ -abstract class QueryStage extends UnaryExecNode { - - var child: SparkPlan - - // Ignore this wrapper for canonicalizing. - override def doCanonicalize(): SparkPlan = child.canonicalized - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering +abstract class QueryStage extends LeafExecNode { /** - * Execute childStages and wait until all stages are completed. Use a thread pool to avoid - * blocking on one child stage. + * An id of this query stage which is unique in the entire query plan. 
*/ - def executeChildStages(): Unit = { - val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - - // Handle broadcast stages - val broadcastQueryStages: Seq[BroadcastQueryStage] = child.collect { - case bqs: BroadcastQueryStageInput => bqs.childStage - } - val broadcastFutures = broadcastQueryStages.map { queryStage => - Future { - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { - queryStage.prepareBroadcast() - } - }(QueryStage.executionContext) - } - - // Submit shuffle stages - val shuffleQueryStages: Seq[ShuffleQueryStage] = child.collect { - case sqs: ShuffleQueryStageInput => sqs.childStage - } - val shuffleStageFutures = shuffleQueryStages.map { queryStage => - Future { - SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { - queryStage.execute() - } - }(QueryStage.executionContext) - } - - ThreadUtils.awaitResult( - Future.sequence(broadcastFutures)(implicitly, QueryStage.executionContext), Duration.Inf) - ThreadUtils.awaitResult( - Future.sequence(shuffleStageFutures)(implicitly, QueryStage.executionContext), Duration.Inf) - } - - private var prepared = false + def id: Int /** - * Before executing the plan in this query stage, we execute all child stages, optimize the plan - * in this stage and determine the reducer number based on the child stages' statistics. Finally - * we do a codegen for this query stage and update the UI with the new plan. + * The sub-tree of the query plan that belongs to this query stage. */ - def prepareExecuteStage(): Unit = synchronized { - // Ensure the prepareExecuteStage method only be executed once. - if (prepared) { - return - } - // 1. Execute childStages - executeChildStages() - // It is possible to optimize this stage's plan here based on the child stages' statistics. - - // 2. Determine reducer number - val queryStageInputs: Seq[ShuffleQueryStageInput] = child.collect { - case input: ShuffleQueryStageInput => input - } - // mapOutputStatistics can be null if the childStage's RDD has 0 partition. In that case, we - // don't submit that stage and mapOutputStatistics is null. - val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics) - .filter(_ != null).toArray - if (childMapOutputStatistics.length > 0) { - val exchangeCoordinator = new ExchangeCoordinator( - conf.targetPostShuffleInputSize, - conf.minNumPostShufflePartitions) - - val partitionStartIndices = - exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics) - child = child.transform { - case ShuffleQueryStageInput(childStage, output, _) => - ShuffleQueryStageInput(childStage, output, Some(partitionStartIndices)) - } - } - - // 3. Codegen and update the UI - child = CollapseCodegenStages(sqlContext.conf).apply(child) - val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - if (executionId != null && executionId.nonEmpty) { - val queryExecution = SQLExecution.getQueryExecution(executionId.toLong) - sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( - executionId.toLong, - queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan))) - } - prepared = true - } - - // Caches the created ShuffleRowRDD so we can reuse that. - private var cachedRDD: RDD[InternalRow] = null - - def executeStage(): RDD[InternalRow] = child.execute() + def plan: SparkPlan /** - * A QueryStage can be reused like Exchange. It is possible that multiple threads try to submit - * the same QueryStage. 
Use synchronized to make sure it is executed only once. + * Returns a new query stage with a new plan, which is optimized based on accurate runtime + * statistics. */ - override def doExecute(): RDD[InternalRow] = synchronized { - if (cachedRDD == null) { - prepareExecuteStage() - cachedRDD = executeStage() - } - cachedRDD - } - - override def executeCollect(): Array[InternalRow] = { - prepareExecuteStage() - child.executeCollect() - } - - override def executeToIterator(): Iterator[InternalRow] = { - prepareExecuteStage() - child.executeToIterator() - } - - override def executeTake(n: Int): Array[InternalRow] = { - prepareExecuteStage() - child.executeTake(n) - } + def withNewPlan(newPlan: SparkPlan): QueryStage + /** + * Materialize this QueryStage, to prepare for the execution, like submitting map stages, + * broadcasting data, etc. The caller side can use the returned [[Future]] to wait until this + * stage is ready. + */ + def materialize(): Future[Any] + + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering + override def executeCollect(): Array[InternalRow] = plan.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = plan.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator() + override def doExecute(): RDD[InternalRow] = plan.execute() + override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() + + // TODO: maybe we should not hide QueryStage entirely from explain result. override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], @@ -178,7 +75,7 @@ abstract class QueryStage extends UnaryExecNode { prefix: String = "", addSuffix: Boolean = false, maxFields: Int): Unit = { - child.generateTreeString( + plan.generateTreeString( depth, lastChildren, append, verbose, "", false, maxFields) } } @@ -186,56 +83,86 @@ abstract class QueryStage extends UnaryExecNode { /** * The last QueryStage of an execution plan. */ -case class ResultQueryStage(var child: SparkPlan) extends QueryStage +case class ResultQueryStage(id: Int, plan: SparkPlan) extends QueryStage { + + override def materialize(): Future[Any] = { + Future.unit + } + + override def withNewPlan(newPlan: SparkPlan): QueryStage = { + copy(plan = newPlan) + } +} /** - * A shuffle QueryStage whose child is a ShuffleExchange. + * A shuffle QueryStage whose child is a ShuffleExchangeExec. */ -case class ShuffleQueryStage(var child: SparkPlan) extends QueryStage { +case class ShuffleQueryStage(id: Int, plan: ShuffleExchangeExec) extends QueryStage { - protected var _mapOutputStatistics: MapOutputStatistics = null - - def mapOutputStatistics: MapOutputStatistics = _mapOutputStatistics + override def withNewPlan(newPlan: SparkPlan): QueryStage = { + copy(plan = newPlan.asInstanceOf[ShuffleExchangeExec]) + } - override def executeStage(): RDD[InternalRow] = { - child match { - case e: ShuffleExchangeExec => - val result = e.eagerExecute() - _mapOutputStatistics = e.mapOutputStatistics - result - case _ => throw new IllegalArgumentException( - "The child of ShuffleQueryStage must be a ShuffleExchange.") + @transient lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (plan.inputRDD.getNumPartitions == 0) { + // `submitMapStage` does not accept RDD with 0 partition. Here we return null and the caller + // side should take care of it. 
+ Future.successful(null) + } else { + sparkContext.submitMapStage(plan.shuffleDependency) } } + + override def materialize(): Future[Any] = { + mapOutputStatisticsFuture + } } /** * A broadcast QueryStage whose child is a BroadcastExchangeExec. */ -case class BroadcastQueryStage(var child: SparkPlan) extends QueryStage { - override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - child.executeBroadcast() +case class BroadcastQueryStage(id: Int, plan: BroadcastExchangeExec) extends QueryStage { + + override def withNewPlan(newPlan: SparkPlan): QueryStage = { + copy(plan = newPlan.asInstanceOf[BroadcastExchangeExec]) } - private var prepared = false + override def materialize(): Future[Any] = { + plan.relationFuture + } +} - def prepareBroadcast() : Unit = synchronized { - if (!prepared) { - executeChildStages() - child = CollapseCodegenStages(sqlContext.conf).apply(child) - // After child stages are completed, prepare() triggers the broadcast. - prepare() - prepared = true +/** + * A wrapper of QueryStage to indicate that it's reused. Note that this is not a query stage. + */ +case class ReusedQueryStage(child: SparkPlan, output: Seq[Attribute]) extends UnaryExecNode { + + // Ignore this wrapper for canonicalizing. + override def doCanonicalize(): SparkPlan = child.canonicalized + + override def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override def doExecuteBroadcast[T](): Broadcast[T] = { + child.executeBroadcast() + } + + // `ReusedQueryStage` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(child.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) } } - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException( - "BroadcastExchange does not support the execute() code path.") + override def outputPartitioning: Partitioning = child.outputPartitioning match { + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] + case other => other } -} -object QueryStage { - private[execution] val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("adaptive-query-stage")) + override def outputOrdering: Seq[SortOrder] = { + child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala deleted file mode 100644 index 8c33e83a91d9b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageInput.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.spark.broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.execution._ - -/** - * QueryStageInput is the leaf node of a QueryStage and serves as its input. A QueryStage knows - * its child stages by collecting all the QueryStageInputs. For a ShuffleQueryStageInput, it - * controls how to read the ShuffledRowRDD generated by its child stage. It gets the ShuffledRowRDD - * from its child stage and creates a new ShuffledRowRDD with different partitions by specifying - * an array of partition start indices. For example, a ShuffledQueryStage can be reused by two - * different QueryStages. One QueryStageInput can let the first task read partition 0 to 3, while - * in another stage, the QueryStageInput can let the first task read partition 0 to 1. - */ -abstract class QueryStageInput extends LeafExecNode { - - def childStage: QueryStage - - // Ignore this wrapper for canonicalizing. - override def doCanonicalize(): SparkPlan = childStage.canonicalized - - // Similar to ReusedExchangeExec, two QueryStageInputs can reference to the same childStage. - // QueryStageInput can have distinct set of output attribute ids from its childStage, we need - // to update the attribute ids in outputPartitioning and outputOrdering. - private lazy val updateAttr: Expression => Expression = { - val originalAttrToNewAttr = AttributeMap(childStage.output.zip(output)) - e => e.transform { - case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) - } - } - - override def outputPartitioning: Partitioning = childStage.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) - case other => other - } - - override def outputOrdering: Seq[SortOrder] = { - childStage.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) - } - - override def generateTreeString( - depth: Int, - lastChildren: Seq[Boolean], - append: String => Unit, - verbose: Boolean, - prefix: String = "", - addSuffix: Boolean = false, - maxFields: Int): Unit = { - childStage.generateTreeString( - depth, lastChildren, append, verbose, "", false, maxFields) - } -} - -/** - * A QueryStageInput whose child stage is a ShuffleQueryStage. It returns a new ShuffledRowRDD - * based on the the child stage's result RDD and the specified partitionStartIndices. If the - * child stage is reused by another ShuffleQueryStageInput, they can return RDDs with different - * partitionStartIndices. 
- */ -case class ShuffleQueryStageInput( - childStage: ShuffleQueryStage, - override val output: Seq[Attribute], - partitionStartIndices: Option[Array[Int]] = None) - extends QueryStageInput { - - override def outputPartitioning: Partitioning = partitionStartIndices.map { - indices => UnknownPartitioning(indices.length) - }.getOrElse(super.outputPartitioning) - - override def doExecute(): RDD[InternalRow] = { - val childRDD = childStage.execute().asInstanceOf[ShuffledRowRDD] - new ShuffledRowRDD(childRDD.dependency, childStage.child.metrics, partitionStartIndices) - } -} - -/** A QueryStageInput whose child stage is a BroadcastQueryStage. */ -case class BroadcastQueryStageInput( - childStage: BroadcastQueryStage, - override val output: Seq[Attribute]) - extends QueryStageInput { - - override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - childStage.executeBroadcast() - } - - override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException( - "BroadcastStageInput does not support the execute() code path.") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala new file mode 100644 index 0000000000000..0319d146c30dc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkPlan} +import org.apache.spark.sql.internal.SQLConf + +class QueryStagePlanner(conf: SQLConf) extends RuleExecutor[SparkPlan] { + + override protected def batches: Seq[Batch] = Seq( + Batch("QueryStage Optimization", Once, + AssertChildStagesMaterialized, + ReduceNumShufflePartitions(conf), + CollapseCodegenStages(conf) + ) + ) +} + +// A sanity check rule to make sure we are running `QueryStagePlanner` on a sub-tree of query plan +// with all input stages materialized. 
+object AssertChildStagesMaterialized extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transform { + case q: QueryStage if !q.materialize().isCompleted => + throw new IllegalArgumentException( + s"The input stages should all be materialize, but ${q.id} is not.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala new file mode 100644 index 0000000000000..5dd65f42442b2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.util.{EventLoop, ThreadUtils} + +/** + * This class triggers [[QueryStage]] bottom-up, apply planner rules for query stages and + * materialize them. It triggers as many query stages as possible at the same time, and triggers + * the parent query stage when all its child stages are materialized. + */ +class QueryStageTrigger(session: SparkSession, callback: QueryStageTriggerCallback) + extends EventLoop[QueryStageTriggerEvent]("QueryStageTrigger") { + + private val stageToParentStages = HashMap.empty[Int, ListBuffer[QueryStage]] + + private val idToUpdatedStage = HashMap.empty[Int, QueryStage] + + private val stageToNumPendingChildStages = HashMap.empty[Int, Int] + + private val submittedStages = HashSet.empty[Int] + + private val readyStages = HashSet.empty[Int] + + private val planner = new QueryStagePlanner(session.sessionState.conf) + + def trigger(stage: QueryStage): Unit = { + post(SubmitStage(stage)) + } + + private implicit def executionContext: ExecutionContextExecutorService = { + QueryStageTrigger.executionContext + } + + override protected def onReceive(event: QueryStageTriggerEvent): Unit = event match { + case SubmitStage(stage) => + // We may submit a query stage multiple times, because of stage reuse. Here we avoid + // re-submitting a query stage. + if (!submittedStages.contains(stage.id)) { + submittedStages += stage.id + val pendingChildStages = stage.plan.collect { + // The stage being submitted may have child stages that are already ready, if the child + // stage is a reused stage. + case stage: QueryStage if !readyStages.contains(stage.id) => stage + } + if (pendingChildStages.isEmpty) { + // This is a leaf stage, or all its child stages are ready, we can plan it now. 
+ post(PlanStage(stage)) + } else { + // This stage has some pending child stages, we store the connection of this stage and + // its child stages, and submit all the child stages, so that we can plan this stage + // later when all its child stages are ready. + stageToNumPendingChildStages(stage.id) = pendingChildStages.length + pendingChildStages.foreach { child => + // a child may have multiple parents, because of query stage reuse. + val parentStages = stageToParentStages.getOrElseUpdate(child.id, new ListBuffer) + parentStages += stage + post(SubmitStage(child)) + } + } + } + + case PlanStage(stage) => + Future { + // planning needs active SparkSession in current thread. + SparkSession.setActiveSession(session) + planner.execute(stage.plan) + }.onComplete { res => + if (res.isSuccess) { + post(StagePlanned(stage, res.get)) + } else { + callback.onStagePlanningFailed(stage, res.failed.get) + stop() + } + } + submittedStages += stage.id + + case StagePlanned(stage, optimizedPlan) => + val newStage = stage.withNewPlan(optimizedPlan) + // We store the new stage with the new query plan after planning, so that later on we can + // update the query plan of its parent stage. + idToUpdatedStage(newStage.id) = newStage + // This stage has optimized its plan, notify the callback about this change. + callback.onStageUpdated(newStage) + + newStage.materialize().onComplete { res => + if (res.isSuccess) { + post(StageReady(stage)) + } else { + callback.onStageMaterializingFailed(newStage, res.failed.get) + stop() + } + } + + case StageReady(stage) => + readyStages += stage.id + stageToParentStages.remove(stage.id).foreach { parentStages => + parentStages.foreach { parent => + val numPendingChildStages = stageToNumPendingChildStages(parent.id) + if (numPendingChildStages == 1) { + stageToNumPendingChildStages.remove(parent.id) + // All its child stages are ready, here we update the query plan via replacing the old + // child stages with new ones that are planned. + val newPlan = parent.plan.transform { + case q: QueryStage => idToUpdatedStage(q.id) + } + // We can plan this stage now. 
+ post(PlanStage(parent.withNewPlan(newPlan))) + } else { + assert(numPendingChildStages > 1) + stageToNumPendingChildStages(parent.id) = numPendingChildStages - 1 + } + } + } + } + + override protected def onError(e: Throwable): Unit = callback.onError(e) +} + +object QueryStageTrigger { + private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("QueryStageTrigger", 16)) +} + +trait QueryStageTriggerCallback { + def onStageUpdated(stage: QueryStage): Unit + def onStagePlanningFailed(stage: QueryStage, e: Throwable): Unit + def onStageMaterializingFailed(stage: QueryStage, e: Throwable): Unit + def onError(e: Throwable): Unit +} + +sealed trait QueryStageTriggerEvent + +case class SubmitStage(stage: QueryStage) extends QueryStageTriggerEvent + +case class PlanStage(stage: QueryStage) extends QueryStageTriggerEvent + +case class StagePlanned(stage: QueryStage, optimizedPlan: SparkPlan) extends QueryStageTriggerEvent + +case class StageReady(stage: QueryStage) extends QueryStageTriggerEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala similarity index 61% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala index 99ba2daf36ca1..1f3e85c153dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala @@ -15,26 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.exchange +package org.apache.spark.sql.execution.adaptive import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration.Duration import org.apache.spark.MapOutputStatistics -import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils /** - * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. - * Right now, the work of this coordinator is to determine the number of post-shuffle partitions - * for a stage that needs to fetch shuffle data from one or multiple stages. - * - * A coordinator is constructed with two parameters, `targetPostShuffleInputSize`, - * and `minNumPostShufflePartitions`. - * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's - * input data size. With this parameter, we can estimate the number of post-shuffle partitions. - * This parameter is configured through - * `spark.sql.adaptive.shuffle.targetPostShuffleInputSize`. - * - `minNumPostShufflePartitions` is used to make sure that there are at least - * `minNumPostShufflePartitions` post-shuffle partitions. + * A rule to adjust the post shuffle partitions based on the map output statistics. * * The strategy used to determine the number of post-shuffle partitions is described as follows. 
* To determine the number of post-shuffle partitions, we have a target input size for a @@ -53,17 +50,54 @@ import org.apache.spark.internal.Logging * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB) * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB) */ -class ExchangeCoordinator( - advisoryTargetPostShuffleInputSize: Long, - minNumPostShufflePartitions: Int = 1) - extends Logging { +case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = { + val shuffleMetrics: Seq[MapOutputStatistics] = plan.collect { + case stage: ShuffleQueryStage => + val metricsFuture = stage.mapOutputStatisticsFuture + assert(metricsFuture.isCompleted, "ShuffleQueryStage should already be ready") + ThreadUtils.awaitResult(metricsFuture, Duration.Zero) + } + + val leafNodes = plan.collect { + case s: SparkPlan if s.children.isEmpty => s + } + + if (shuffleMetrics.length == leafNodes.length) { + // ShuffleQueryStage gives null mapOutputStatistics when the input RDD has 0 partitions, + // we should skip it when calculating the `partitionStartIndices`. + val validMetrics = shuffleMetrics.filter(_ != null) + if (validMetrics.nonEmpty) { + val partitionStartIndices = estimatePartitionStartIndices(validMetrics.toArray) + // This transformation adds new nodes, so we must use `transformUp` here. + plan.transformUp { + // even for shuffle exchange whose input RDD has 0 partition, we should still update its + // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same number + // of output partitions. + case stage: ShuffleQueryStage => + CoalescedShuffleReaderExec(stage, partitionStartIndices) + } + } else { + plan + } + } else { + // If not all leaf nodes are shuffle query stages, it's not safe to reduce the number of + // shuffle partitions, because we may break the assumption that all children of a spark plan + // have same number of output partitions. + plan + } + } /** * Estimates partition start indices for post-shuffle partitions based on * mapOutputStatistics provided by all pre-shuffle stages. */ - def estimatePartitionStartIndices( + // visible for testing. + private[sql] def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { + val minNumPostShufflePartitions = conf.minNumPostShufflePartitions + val advisoryTargetPostShuffleInputSize = conf.targetPostShuffleInputSize // If minNumPostShufflePartitions is defined, it is possible that we need to use a // value less than advisoryTargetPostShuffleInputSize as the target input size of // a post shuffle task. @@ -79,7 +113,7 @@ class ExchangeCoordinator( logInfo( s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + - s"targetPostShuffleInputSize $targetPostShuffleInputSize.") + s"targetPostShuffleInputSize $targetPostShuffleInputSize.") // Make sure we do get the same number of pre-shuffle partitions for those stages. 
val distinctNumPreShufflePartitions = @@ -126,8 +160,24 @@ class ExchangeCoordinator( partitionStartIndices.toArray } +} + +case class CoalescedShuffleReaderExec( + child: ShuffleQueryStage, + partitionStartIndices: Array[Int]) extends UnaryExecNode { + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = { + UnknownPartitioning(partitionStartIndices.length) + } - override def toString: String = { - s"coordinator[target post-shuffle partition size: $advisoryTargetPostShuffleInputSize]" + private var cachedShuffleRDD: ShuffledRowRDD = null + + override protected def doExecute(): RDD[InternalRow] = { + if (cachedShuffleRDD == null) { + cachedShuffleRDD = child.plan.createShuffledRDD(Some(partitionStartIndices)) + } + cachedShuffleRDD } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 703d351bea7c0..18f13cf2eb5ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -66,7 +66,7 @@ case class BroadcastExchangeExec( } @transient - private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) Future { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8184baf50b042..126e8e6dd1104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, - SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index b1adc396e398d..987e73e52950f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -38,10 +38,10 @@ import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** - * Performs a shuffle that will result in the desired `newPartitioning`. + * Performs a shuffle that will result in the desired partitioning. 
*/ case class ShuffleExchangeExec( - var newPartitioning: Partitioning, + desiredPartitioning: Partitioning, child: SparkPlan) extends Exchange { // NOTE: coordinator can be null after serialization/deserialization, @@ -58,43 +58,31 @@ case class ShuffleExchangeExec( "Exchange" } - override def outputPartitioning: Partitioning = newPartitioning + override def outputPartitioning: Partitioning = { + desiredPartitioning + } private val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + @transient lazy val inputRDD: RDD[InternalRow] = child.execute() + /** - * Returns a [[ShuffleDependency]] that will partition rows of its child based on - * the partitioning scheme defined in `newPartitioning`. Those partitions of - * the returned ShuffleDependency will be the input of shuffle. + * A [[ShuffleDependency]] that will partition rows of its child based on the desired + * partitioning/ Those partitions of the returned ShuffleDependency will be the input of shuffle. */ - private[exchange] def prepareShuffleDependency() - : ShuffleDependency[Int, InternalRow, InternalRow] = { + @transient + lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] = { ShuffleExchangeExec.prepareShuffleDependency( - child.execute(), + inputRDD, child.output, - newPartitioning, + outputPartitioning, serializer, writeMetrics) } - /** - * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset. - * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional - * partition start indices array. If this optional array is defined, the returned - * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. - */ - private[exchange] def preparePostShuffleRDD( - shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], - specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { - // If an array of partition start indices is provided, we need to use this array - // to create the ShuffledRowRDD. Also, we need to update newPartitioning to - // update the number of post-shuffle partitions. - specifiedPartitionStartIndices.foreach { indices => - assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = UnknownPartitioning(indices.length) - } - new ShuffledRowRDD(shuffleDependency, readMetrics, specifiedPartitionStartIndices) + def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledRowRDD = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices) } /** @@ -105,26 +93,7 @@ case class ShuffleExchangeExec( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { // Returns the same ShuffleRowRDD if this plan is used by multiple plans. if (cachedShuffleRDD == null) { - val shuffleDependency = prepareShuffleDependency() - cachedShuffleRDD = preparePostShuffleRDD(shuffleDependency) - } - cachedShuffleRDD - } - - private var _mapOutputStatistics: MapOutputStatistics = null - - def mapOutputStatistics: MapOutputStatistics = _mapOutputStatistics - - def eagerExecute(): RDD[InternalRow] = { - if (cachedShuffleRDD == null) { - val shuffleDependency = prepareShuffleDependency() - if (shuffleDependency.rdd.partitions.length != 0) { - // submitMapStage does not accept RDD with 0 partition. - // So, we will not submit this dependency. 
- val submittedStageFuture = sqlContext.sparkContext.submitMapStage(shuffleDependency) - _mapOutputStatistics = submittedStageFuture.get() - } - cachedShuffleRDD = preparePostShuffleRDD(shuffleDependency) + cachedShuffleRDD = createShuffledRDD(None) } cachedShuffleRDD } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 15b4acfb662b9..113b205367a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -96,7 +96,7 @@ object SparkPlanGraph { case "InputAdapter" => buildSparkPlanGraphNode( planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) - case "QueryStage" | "BroadcastQueryStage" | "ResultQueryStage" | "ShuffleQueryStage" => + case "BroadcastQueryStage" | "ResultQueryStage" | "ShuffleQueryStage" => if (exchanges.contains(planInfo.children.head)) { // Point to the re-used exchange val node = exchanges(planInfo.children.head) @@ -105,9 +105,6 @@ object SparkPlanGraph { buildSparkPlanGraphNode( planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) } - case "QueryStageInput" | "ShuffleQueryStageInput" | "BroadcastQueryStageInput" => - buildSparkPlanGraphNode( - planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges) case "Subquery" if subgraph != null => // Subquery should not be included in WholeStageCodegen buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c97041a8f341c..e6aa066d8f9db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -448,7 +448,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e } assert(shuffle.size === 1) - assert(shuffle.head.newPartitioning === finalPartitioning) + assert(shuffle.head.outputPartitioning === finalPartitioning) } test("Reuse exchanges") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala similarity index 68% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala index a88360a90e9eb..70a7ed0e8d1d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala @@ -22,12 +22,11 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageInput -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.internal.SQLConf -class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { +class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalActiveSparkSession: Option[SparkSession] = _ private var originalInstantiatedSparkSession: Option[SparkSession] = _ @@ -52,7 +51,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } private def checkEstimation( - coordinator: ExchangeCoordinator, + rule: ReduceNumShufflePartitions, bytesByPartitionIdArray: Array[Array[Long]], expectedPartitionStartIndices: Array[Int]): Unit = { val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { @@ -60,18 +59,27 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { new MapOutputStatistics(index, bytesByPartitionId) } val estimatedPartitionStartIndices = - coordinator.estimatePartitionStartIndices(mapOutputStatistics) + rule.estimatePartitionStartIndices(mapOutputStatistics) assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) } + private def createReduceNumShufflePartitionsRule( + advisoryTargetPostShuffleInputSize: Long, + minNumPostShufflePartitions: Int = 1): ReduceNumShufflePartitions = { + val conf = new SQLConf().copy( + SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE -> advisoryTargetPostShuffleInputSize, + SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS -> minNumPostShufflePartitions) + ReduceNumShufflePartitions(conf) + } + test("test estimatePartitionStartIndices - 1 Exchange") { - val coordinator = new ExchangeCoordinator(100L) + val rule = createReduceNumShufflePartitionsRule(100L) { // All bytes per partition are 0. val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) } { @@ -79,40 +87,40 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // 1 post-shuffle partition is needed. val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) } { // 2 post-shuffle partitions are needed. val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) val expectedPartitionStartIndices = Array[Int](0, 3) - checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) } { // There are a few large pre-shuffle partitions. val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) } { // All pre-shuffle partitions are larger than the targeted size. val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices) + checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) } { // The last pre-shuffle partition is in a single post-shuffle partition. 
       val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110)
       val expectedPartitionStartIndices = Array[Int](0, 4)
-      checkEstimation(coordinator, Array(bytesByPartitionId), expectedPartitionStartIndices)
+      checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices)
     }
   }
 
   test("test estimatePartitionStartIndices - 2 Exchanges") {
-    val coordinator = new ExchangeCoordinator(100L)
+    val rule = createReduceNumShufflePartitionsRule(100L)
 
     {
       // If there are multiple values of the number of pre-shuffle partitions,
@@ -123,7 +131,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
         Array(
           new MapOutputStatistics(0, bytesByPartitionId1),
           new MapOutputStatistics(1, bytesByPartitionId2))
-      intercept[AssertionError](coordinator.estimatePartitionStartIndices(mapOutputStatistics))
+      intercept[AssertionError](rule.estimatePartitionStartIndices(mapOutputStatistics))
     }
 
     {
@@ -132,7 +140,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
       val expectedPartitionStartIndices = Array[Int](0)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -144,7 +152,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20)
       val expectedPartitionStartIndices = Array[Int](0)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -155,7 +163,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
       val expectedPartitionStartIndices = Array[Int](0, 2, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -166,7 +174,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
       val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -177,7 +185,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30)
       val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -188,7 +196,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110)
       val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -199,14 +207,14 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110)
       val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
   }
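Editor's note on the estimates exercised above: ReduceNumShufflePartitions decides post-shuffle partition boundaries with a single greedy pass over the per-partition byte counts reported by every participating map stage, starting a new post-shuffle partition whenever adding the next pre-shuffle partition would push the running total past the advisory target size. The standalone sketch below mirrors that behaviour for the inputs used in these tests; it is illustrative only (the helper name is ours, and the real rule additionally lowers the effective target so that at least the configured minimum number of post-shuffle partitions is produced, which the next test exercises).

    // Illustrative sketch, not the rule's actual code: greedy packing of pre-shuffle
    // partitions by size. One Array[Long] per shuffle exchange, as in the tests above.
    def sketchPartitionStartIndices(
        bytesByPartitionIdArray: Array[Array[Long]],
        targetSize: Long): Array[Int] = {
      val numPartitions = bytesByPartitionIdArray.head.length
      // All exchanges that are coalesced together must agree on the partition count.
      assert(bytesByPartitionIdArray.forall(_.length == numPartitions))
      val startIndices = scala.collection.mutable.ArrayBuffer(0)
      var currentSize = 0L
      for (i <- 0 until numPartitions) {
        // The size of a post-shuffle partition is summed across all participating exchanges.
        val partitionSize = bytesByPartitionIdArray.map(_(i)).sum
        if (i > 0 && currentSize + partitionSize > targetSize) {
          startIndices += i      // start a new post-shuffle partition here
          currentSize = 0L
        }
        currentSize += partitionSize
      }
      startIndices.toArray
    }

    // With a 100-byte target, sizes (10, 0, 90, 20, 0) pack into two post-shuffle
    // partitions starting at indices 0 and 3, matching the single-exchange test above.
    sketchPartitionStartIndices(Array(Array(10L, 0L, 90L, 20L, 0L)), 100L)   // Array(0, 3)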
 
   test("test estimatePartitionStartIndices and enforce minimal number of reducers") {
-    val coordinator = new ExchangeCoordinator(100L, 2)
+    val rule = createReduceNumShufflePartitionsRule(100L, 2)
 
     {
       // The minimal number of post-shuffle partitions is not enforced because
@@ -215,7 +223,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0)
       val expectedPartitionStartIndices = Array[Int](0)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -226,7 +234,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5)
       val expectedPartitionStartIndices = Array[Int](0, 3)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -237,7 +245,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
       val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30)
       val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4)
       checkEstimation(
-        coordinator,
+        rule,
         Array(bytesByPartitionId1, bytesByPartitionId2),
         expectedPartitionStartIndices)
     }
@@ -258,7 +266,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
 
   def withSparkSession(
       f: SparkSession => Unit,
-      targetNumPostShufflePartitions: Int,
+      targetPostShuffleInputSize: Int,
       minNumPostShufflePartitions: Option[Int]): Unit = {
     val sparkConf =
       new SparkConf(false)
@@ -270,7 +278,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
         .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
         .set(
           SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key,
-          targetNumPostShufflePartitions.toString)
+          targetPostShuffleInputSize.toString)
     minNumPostShufflePartitions match {
       case Some(numPartitions) =>
         sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString)
@@ -305,21 +313,21 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val queryStageInputs = agg.queryExecution.executedPlan.collect {
-          case q: ShuffleQueryStageInput => q
+        val finalPlan = agg.queryExecution.executedPlan
+          .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+        val shuffleReaders = finalPlan.collect {
+          case reader: CoalescedShuffleReaderExec => reader
         }
-        assert(queryStageInputs.length === 1)
+        assert(shuffleReaders.length === 1)
 
         minNumPostShufflePartitions match {
          case Some(numPartitions) =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 5)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === numPartitions)
             }
          case None =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 3)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === 3)
             }
        }
      }
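The cast-and-collect pattern introduced in the aggregation test above (executedPlan cast to AdaptiveSparkPlan, then resultStage.plan, then collecting CoalescedShuffleReaderExec nodes) repeats verbatim in each of the join tests below. Purely as an illustration, and not part of this patch, a small suite-level helper could factor it out; AdaptiveSparkPlan, resultStage and CoalescedShuffleReaderExec come from this change, while the helper itself is hypothetical and assumes the suite's existing org.apache.spark.sql._ import for DataFrame:

    // Hypothetical helper for this suite; not part of the patch.
    private def collectCoalescedReaders(df: DataFrame): Seq[CoalescedShuffleReaderExec] = {
      // The adaptive root node exposes the final (result) query stage once the query has run.
      val finalPlan = df.queryExecution.executedPlan
        .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
      finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader }
    }

    // Usage in the tests would then shrink to:
    //   val shuffleReaders = collectCoalescedReaders(agg)
    //   assert(shuffleReaders.length === 1)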
@@ -352,21 +360,21 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val queryStageInputs = join.queryExecution.executedPlan.collect {
-          case q: ShuffleQueryStageInput => q
+        val finalPlan = join.queryExecution.executedPlan
+          .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+        val shuffleReaders = finalPlan.collect {
+          case reader: CoalescedShuffleReaderExec => reader
         }
-        assert(queryStageInputs.length === 2)
+        assert(shuffleReaders.length === 2)
 
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 5)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === numPartitions)
             }
           case None =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 2)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === 2)
             }
         }
       }
@@ -404,21 +412,21 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val queryStageInputs = join.queryExecution.executedPlan.collect {
-          case q: ShuffleQueryStageInput => q
+        val finalPlan = join.queryExecution.executedPlan
+          .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+        val shuffleReaders = finalPlan.collect {
+          case reader: CoalescedShuffleReaderExec => reader
         }
-        assert(queryStageInputs.length === 2)
+        assert(shuffleReaders.length === 2)
 
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 5)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === numPartitions)
             }
           case None =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 2)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === 2)
             }
         }
       }
@@ -456,40 +464,111 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val queryStageInputs = join.queryExecution.executedPlan.collect {
-          case q: ShuffleQueryStageInput => q
+        val finalPlan = join.queryExecution.executedPlan
+          .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+        val shuffleReaders = finalPlan.collect {
+          case reader: CoalescedShuffleReaderExec => reader
        }
-        assert(queryStageInputs.length === 2)
+        assert(shuffleReaders.length === 2)
 
         minNumPostShufflePartitions match {
           case Some(numPartitions) =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 5)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === numPartitions)
             }
           case None =>
-            queryStageInputs.foreach { q =>
-              assert(q.partitionStartIndices.isDefined)
-              assert(q.outputPartitioning.numPartitions === 3)
+            shuffleReaders.foreach { reader =>
+              assert(reader.outputPartitioning.numPartitions === 3)
             }
         }
       }
 
       withSparkSession(test, 12000, minNumPostShufflePartitions)
     }
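A note on the new test that follows: it writes the left-hand table with bucketBy(30, "id"), so the scan of `t` already reports a hash partitioning on the join key, and the expectation is that no CoalescedShuffleReaderExec appears, since coalescing only the shuffled side would break the co-partitioning the join relies on. The snippet below is only an illustration of how one can inspect the partitioning a bucketed scan reports; it assumes a session named `spark` outside this suite, with bucketing enabled and adaptive execution turned off so that executedPlan is the plain physical plan:

    // Illustrative only; not part of the patch or the suite.
    import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

    spark.range(1000).write.bucketBy(30, "id").saveAsTable("t")
    val scan = spark.table("t").queryExecution.executedPlan.collectLeaves().head
    scan.outputPartitioning match {
      // A bucketed data source scan reports its bucket spec as the output partitioning.
      case h: HashPartitioning =>
        println(s"hash partitioned on ${h.expressions} into ${h.numPartitions} partitions")
      case other =>
        println(s"unexpected partitioning: $other")
    }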
+
+    test(s"determining the number of reducers: plan already partitioned$testNameNote") {
+      val test: SparkSession => Unit = { spark: SparkSession =>
+        try {
+          spark.range(1000).write.bucketBy(30, "id").saveAsTable("t")
+          // `df1` is hash partitioned by `id`.
+          val df1 = spark.read.table("t")
+          val df2 =
+            spark
+              .range(0, 1000, 1, numInputPartitions)
+              .selectExpr("id % 500 as key2", "id as value2")
+
+          val join = df1.join(df2, col("id") === col("key2")).select(col("id"), col("value2"))
+
+          // Check the answer first.
+          val expectedAnswer = spark.range(0, 500).selectExpr("id % 500", "id as value")
+            .union(spark.range(500, 1000).selectExpr("id % 500", "id as value"))
+          checkAnswer(
+            join,
+            expectedAnswer.collect())
+
+          // Then, let's make sure we do not reduce the number of post-shuffle partitions.
+          val finalPlan = join.queryExecution.executedPlan
+            .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+          val shuffleReaders = finalPlan.collect {
+            case reader: CoalescedShuffleReaderExec => reader
+          }
+          assert(shuffleReaders.length === 0)
+        } finally {
+          spark.sql("drop table t")
+        }
+      }
+      withSparkSession(test, 12000, minNumPostShufflePartitions)
+    }
   }
 
   test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") {
     val test = { spark: SparkSession =>
       spark.sql("SET spark.sql.exchange.reuse=true")
       val df = spark.range(1).selectExpr("id AS key", "id AS value")
+
+      // test case 1: a stage has 3 child stages, but they are all the same stage.
+      // ResultQueryStage 1
+      //   ShuffleQueryStage 0
+      //   ReusedQueryStage 0
+      //   ReusedQueryStage 0
       val resultDf = df.join(df, "key").join(df, "key")
-      val sparkPlan = resultDf.queryExecution.executedPlan
-      val queryStageInputs = sparkPlan.collect { case p: ShuffleQueryStageInput => p }
-      assert(queryStageInputs.length === 3)
-      assert(queryStageInputs(0).childStage === queryStageInputs(1).childStage)
-      assert(queryStageInputs(1).childStage === queryStageInputs(2).childStage)
+      val finalPlan = resultDf.queryExecution.executedPlan
+        .asInstanceOf[AdaptiveSparkPlan].resultStage.plan
+      assert(finalPlan.collect { case p: ReusedQueryStage => p }.length == 2)
+      assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 3)
       checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
+
+      // test case 2: a stage has 2 parent stages.
+      // ResultQueryStage 3
+      //   ShuffleQueryStage 1
+      //     ShuffleQueryStage 0
+      //   ShuffleQueryStage 2
+      //     ReusedQueryStage 0
+      val grouped = df.groupBy("key").agg(max("value").as("value"))
+      val resultDf2 = grouped.groupBy(col("key") + 1).max("value")
+        .union(grouped.groupBy(col("key") + 2).max("value"))
+
+      val resultStage = resultDf2.queryExecution.executedPlan
+        .asInstanceOf[AdaptiveSparkPlan].resultStage
+
+      // The result stage has 2 child stages.
+      val level1Stages = resultStage.plan.collect { case q: QueryStage => q }
+      assert(level1Stages.length == 2)
+
+      val leafStages = level1Stages.flatMap { stage =>
+        // Each child stage of the result stage has exactly one child stage of its own.
+        val children = stage.plan.collect { case q: QueryStage => q }
+        assert(children.length == 1)
+        children
+      }
+      assert(leafStages.length == 2)
+
+      val reusedStages = level1Stages.flatMap { stage =>
+        stage.plan.collect { case r: ReusedQueryStage => r }
+      }
+      assert(reusedStages.length == 1)
+
+      checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
     }
     withSparkSession(test, 4, None)
   }
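One subtlety worth keeping in mind when reading the reuse test above: collect only traverses the plan of a single stage, because each QueryStage carries its sub-plan as a separate field rather than as a child node, which is why the test gathers the level-1 stages first and then looks inside each stage's plan. A recursive walk in the same spirit could flatten the whole stage tree; the helper below is purely illustrative (its name is ours) and leans only on the QueryStage node and its plan member introduced by this patch:

    // Illustrative sketch: flatten the stage tree under an adaptive plan by recursing
    // into each stage's own sub-plan, since a plain collect stops at stage boundaries.
    def collectAllStages(stage: QueryStage): Seq[QueryStage] = {
      val childStages = stage.plan.collect { case q: QueryStage => q }
      stage +: childStages.flatMap(collectAllStages)
    }

    // For the union query in test case 2 above, something like
    //   collectAllStages(resultDf2.queryExecution.executedPlan
    //     .asInstanceOf[AdaptiveSparkPlan].resultStage)
    // should return the result stage plus every stage nested below it.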