diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f625581d40533..37e1e54d8766b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.adaptive.PlanQueryStage +import org.apache.spark.sql.execution.adaptive.InsertAdaptiveSparkPlan import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -94,32 +94,19 @@ class QueryExecution( * row format conversions as needed. */ protected def prepareForExecution(plan: SparkPlan): SparkPlan = { - val rules = if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { - adaptivePreparations - } else { - preparations - } - rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp)} + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } } /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( PlanSubqueries(sparkSession), + ReuseSubquery(sparkSession.sessionState.conf), EnsureRequirements(sparkSession.sessionState.conf), + // `AdaptiveSparkPlan` is a leaf node. If inserted, all the following rules will be no-op as + // the original plan is hidden behind `AdaptiveSparkPlan`. + InsertAdaptiveSparkPlan(sparkSession), CollapseCodegenStages(sparkSession.sessionState.conf), - ReuseExchange(sparkSession.sessionState.conf), - ReuseSubquery(sparkSession.sessionState.conf)) - - // With adaptive execution, whole stage codegen will be done inside `QueryStageExecutor`. - protected def adaptivePreparations: Seq[Rule[SparkPlan]] = Seq( - PlanSubqueries(sparkSession), - EnsureRequirements(sparkSession.sessionState.conf), - ReuseExchange(sparkSession.sessionState.conf), - ReuseSubquery(sparkSession.sessionState.conf), - // PlanQueryStage needs to be the last rule because it divides the plan into multiple sub-trees - // by inserting leaf node QueryStage. Transforming the plan after applying this rule will - // only transform node in a sub-tree. 
- PlanQueryStage(sparkSession)) + ReuseExchange(sparkSession.sessionState.conf)) def simpleString: String = withRedaction { val concat = new StringConcat() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 53b5fc305a330..9c6d5928259d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -53,7 +53,7 @@ private[execution] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { val children = plan match { case ReusedExchangeExec(_, child) => child :: Nil - case a: AdaptiveSparkPlan => a.resultStage.plan :: Nil + case a: AdaptiveSparkPlan => a.finalPlan.plan :: Nil case stage: QueryStage => stage.plan :: Nil case _ => plan.children ++ plan.subqueries } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala index 32afea26f239d..44f6c6f497918 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.adaptive import java.util.concurrent.CountDownLatch +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -27,48 +28,33 @@ import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, SparkPlanInfo, S import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate /** - * A root node to trigger query stages and execute the query plan adaptively. It incrementally + * A root node to execute the query plan adaptively. It creates query stages, and incrementally * updates the query plan when a query stage is materialized and provides accurate runtime - * statistics. + * data statistics. */ -case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSession) +case class AdaptiveSparkPlan(initialPlan: SparkPlan, session: SparkSession) extends LeafExecNode{ override def output: Seq[Attribute] = initialPlan.output - @volatile private var currentQueryStage: QueryStage = initialPlan + @volatile private var currentPlan: SparkPlan = initialPlan @volatile private var error: Throwable = null - private val readyLock = new CountDownLatch(1) - private def replaceStage(oldStage: QueryStage, newStage: QueryStage): QueryStage = { - if (oldStage.id == newStage.id) { - newStage - } else { - val newPlanForOldStage = oldStage.plan.transform { - case q: QueryStage => replaceStage(q, newStage) - } - oldStage.withNewPlan(newPlanForOldStage) - } - } + // We will release the lock when we finish planning query stages, or when the planning fails. + // Getting `finalPlan` will be blocked until the lock is released. + // This is better than wait()/notify(), as we can easily check if the computation has completed, + // by calling `readyLock.getCount()`.
+ private val readyLock = new CountDownLatch(1) private def createCallback(executionId: Option[Long]): QueryStageTriggerCallback = { new QueryStageTriggerCallback { - override def onStageUpdated(stage: QueryStage): Unit = { - updateCurrentQueryStage(stage, executionId) - if (stage.isInstanceOf[ResultQueryStage]) readyLock.countDown() - } - - override def onStagePlanningFailed(stage: QueryStage, e: Throwable): Unit = { - error = new RuntimeException( - s""" - |Fail to plan stage ${stage.id}: - |${stage.plan.treeString} - """.stripMargin, e) - readyLock.countDown() + override def onPlanUpdate(updatedPlan: SparkPlan): Unit = { + updateCurrentPlan(updatedPlan, executionId) + if (updatedPlan.isInstanceOf[ResultQueryStage]) readyLock.countDown() } override def onStageMaterializingFailed(stage: QueryStage, e: Throwable): Unit = { - error = new RuntimeException( + error = new SparkException( s""" |Fail to materialize stage ${stage.id}: |${stage.plan.treeString} @@ -83,35 +69,34 @@ case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSessio } } - private def updateCurrentQueryStage(newStage: QueryStage, executionId: Option[Long]): Unit = { - currentQueryStage = replaceStage(currentQueryStage, newStage) + private def updateCurrentPlan(newPlan: SparkPlan, executionId: Option[Long]): Unit = { + currentPlan = newPlan executionId.foreach { id => session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate( id, SQLExecution.getQueryExecution(id).toString, - SparkPlanInfo.fromSparkPlan(currentQueryStage))) + SparkPlanInfo.fromSparkPlan(currentPlan))) } } - def resultStage: ResultQueryStage = { + def finalPlan: ResultQueryStage = { if (readyLock.getCount > 0) { val sc = session.sparkContext val executionId = Option(sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(_.toLong) - val trigger = new QueryStageTrigger(session, createCallback(executionId)) - trigger.start() - trigger.trigger(initialPlan) + val creator = new QueryStageCreator(initialPlan, session, createCallback(executionId)) + creator.start() readyLock.await() - trigger.stop() + creator.stop() } if (error != null) throw error - currentQueryStage.asInstanceOf[ResultQueryStage] + currentPlan.asInstanceOf[ResultQueryStage] } - override def executeCollect(): Array[InternalRow] = resultStage.executeCollect() - override def executeTake(n: Int): Array[InternalRow] = resultStage.executeTake(n) - override def executeToIterator(): Iterator[InternalRow] = resultStage.executeToIterator() - override def doExecute(): RDD[InternalRow] = resultStage.execute() + override def executeCollect(): Array[InternalRow] = finalPlan.executeCollect() + override def executeTake(n: Int): Array[InternalRow] = finalPlan.executeTake(n) + override def executeToIterator(): Iterator[InternalRow] = finalPlan.executeToIterator() + override def doExecute(): RDD[InternalRow] = finalPlan.execute() override def generateTreeString( depth: Int, lastChildren: Seq[Boolean], @@ -120,7 +105,7 @@ case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSessio prefix: String = "", addSuffix: Boolean = false, maxFields: Int): Unit = { - currentQueryStage.generateTreeString( + currentPlan.generateTreeString( depth, lastChildren, append, verbose, "", false, maxFields) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala new file mode 100644 index 0000000000000..4a1297b71feb8 --- /dev/null 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.ExecutedCommandExec + +/** + * This rule wraps the query plan with an [[AdaptiveSparkPlan]], which executes the query plan + * adaptively with runtime data statistics. Note that this rule must be run after + * [[org.apache.spark.sql.execution.exchange.EnsureRequirements]], so that the exchange nodes are + * already inserted. + */ +case class InsertAdaptiveSparkPlan(session: SparkSession) extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = plan match { + case _: ExecutedCommandExec => plan + case _ if session.sessionState.conf.adaptiveExecutionEnabled => + AdaptiveSparkPlan(plan, session.cloneSession()) + case _ => plan + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala deleted file mode 100644 index 86ec58c6c77cd..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanQueryStage.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} - -/** - * Divide the spark plan into multiple QueryStages. For each Exchange in the plan, it wraps it with - * a [[QueryStage]]. At the end it adds an [[AdaptiveSparkPlan]] at the top, which will drive the - * execution of query stages.
- */ -case class PlanQueryStage(session: SparkSession) extends Rule[SparkPlan] { - - def apply(plan: SparkPlan): SparkPlan = { - var id = 0 - val exchangeToQueryStage = new java.util.IdentityHashMap[Exchange, QueryStage] - val planWithStages = plan.transformUp { - case e: ShuffleExchangeExec => - val queryStage = ShuffleQueryStage(id, e) - id += 1 - exchangeToQueryStage.put(e, queryStage) - queryStage - case e: BroadcastExchangeExec => - val queryStage = BroadcastQueryStage(id, e) - id += 1 - exchangeToQueryStage.put(e, queryStage) - queryStage - // The `ReusedExchangeExec` was added in the rule `ReuseExchange`, via transforming up the - // query plan. This rule also transform up the query plan, so when we hit `ReusedExchangeExec` - // here, the exchange being reused must already be hit before and there should be an entry - // for it in `exchangeToQueryStage`. - case e: ReusedExchangeExec => - val existingQueryStage = exchangeToQueryStage.get(e.child) - assert(existingQueryStage != null, "The exchange being reused should be hit before.") - ReusedQueryStage(existingQueryStage, e.output) - } - AdaptiveSparkPlan(ResultQueryStage(id, planWithStages), session) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala index 0994651c20c19..6edd1e4eafb6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStage.scala @@ -29,8 +29,9 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange._ /** - * In adaptive execution mode, an execution plan is divided into multiple QueryStages w.r.t. the - * exchange as boundary. Each QueryStage is a sub-tree that runs in a single Spark stage. + * A query stage is an individual sub-tree of a query plan, which can be executed ahead of the rest + * of the plan and provide accurate data statistics. For example, a sub-tree under a shuffle or + * broadcast node is a query stage. Each query stage runs in a single Spark job/stage. */ abstract class QueryStage extends LeafExecNode { @@ -65,6 +66,7 @@ abstract class QueryStage extends LeafExecNode { override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator() override def doExecute(): RDD[InternalRow] = plan.execute() override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() + override def doCanonicalize(): SparkPlan = plan.canonicalized // TODO: maybe we should not hide QueryStage entirely from explain result. override def generateTreeString( @@ -86,7 +88,7 @@ abstract class QueryStage extends LeafExecNode { case class ResultQueryStage(id: Int, plan: SparkPlan) extends QueryStage { override def materialize(): Future[Any] = { - Future.unit + throw new IllegalStateException("Cannot materialize ResultQueryStage.") } override def withNewPlan(newPlan: SparkPlan): QueryStage = { @@ -95,7 +97,7 @@ case class ResultQueryStage(id: Int, plan: SparkPlan) extends QueryStage { } /** - * A shuffle QueryStage whose child is a ShuffleExchangeExec. + * A shuffle QueryStage whose child is a [[ShuffleExchangeExec]]. */ case class ShuffleQueryStage(id: Int, plan: ShuffleExchangeExec) extends QueryStage { @@ -119,7 +121,7 @@ case class ShuffleQueryStage(id: Int, plan: ShuffleExchangeExec) extends QuerySt } /** - * A broadcast QueryStage whose child is a BroadcastExchangeExec. + * A broadcast QueryStage whose child is a [[BroadcastExchangeExec]].
 */ case class BroadcastQueryStage(id: Int, plan: BroadcastExchangeExec) extends QueryStage { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageCreator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageCreator.scala new file mode 100644 index 0000000000000..56f883f5123f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageCreator.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkPlan} +import org.apache.spark.sql.execution.adaptive.rule.{AssertChildStagesMaterialized, ReduceNumShufflePartitions} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ShuffleExchangeExec} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{EventLoop, ThreadUtils} + +/** + * This class dynamically creates [[QueryStage]]s bottom-up, optimizes the query plans of query stages + * and materializes them. It creates as many query stages as possible at the same time, and + * materializes a query stage when all its child stages are materialized. + * + * To create query stages, we traverse the query tree bottom up. When we hit an exchange node, and + * all the child query stages of this exchange node are materialized, we try to create a new query + * stage for this exchange node. + * + * To create a new query stage, we first optimize the sub-tree of the exchange. After optimization, + * we check the output partitioning of the optimized sub-tree, and see if the exchange node is still + * necessary. + * + * If the exchange node becomes unnecessary, we remove it, give up this query stage creation, and + * continue to traverse the query plan tree until we hit the next exchange node. + * + * If the exchange node is still needed, we create the query stage and optimize its sub-tree again. + * It's necessary to have both the pre-creation optimization and post-creation optimization, because + * these 2 optimizations have different assumptions. For pre-creation optimization, the shuffle node + * may be removed later on and the current sub-tree may be only a part of a query stage, so we don't + * have the big picture of the query stage yet. For post-creation optimization, the query stage is + * created and we have the big picture of the query stage.
+ * + * After the query stage is optimized, we materialize it asynchronously, and continue to traverse + * the query plan tree to create more query stages. + * + * When a query stage completes materialization, we trigger the process of query stage creation and + * traverse the query plan tree again. + */ +class QueryStageCreator( + initialPlan: SparkPlan, + session: SparkSession, + callback: QueryStageTriggerCallback) + extends EventLoop[QueryStageCreatorEvent]("QueryStageCreator") { + + private def conf = session.sessionState.conf + + private val readyStages = mutable.HashSet.empty[Int] + + private var currentStageId = 0 + + private val stageCache = mutable.HashMap.empty[StructType, mutable.Buffer[(Exchange, QueryStage)]] + + // The optimizer rules that will be applied to a sub-tree of the query plan before the stage is + // created. Note that we may end up not creating the query stage, so the rules here should not + // assume the given sub-plan-tree is the entire query plan of the query stage. For example, if a + // rule wants to collect all the child query stages, it should not be put here. + private val preStageCreationOptimizerRules: Seq[Rule[SparkPlan]] = Seq( + AssertChildStagesMaterialized + ) + + // The optimizer rules that will be applied to a sub-tree of the query plan after the stage is + // created. Note that once the stage is created, we will not remove it anymore. If a rule changes + // the output partitioning of the sub-plan-tree, which may help to remove the exchange node, it's + // better to put it in `preStageCreationOptimizerRules`, so that we may create fewer query stages. + private val postStageCreationOptimizerRules: Seq[Rule[SparkPlan]] = Seq( + ReduceNumShufflePartitions(conf), + CollapseCodegenStages(conf)) + + private var currentPlan = initialPlan + + private implicit def executionContext: ExecutionContextExecutorService = { + QueryStageCreator.executionContext + } + + override protected def onReceive(event: QueryStageCreatorEvent): Unit = event match { + case StartCreation => + // Set the active session for the event loop thread. + SparkSession.setActiveSession(session) + currentPlan = createQueryStages(initialPlan) + + case MaterializeStage(stage) => + stage.materialize().onComplete { res => + if (res.isSuccess) { + post(StageReady(stage)) + } else { + callback.onStageMaterializingFailed(stage, res.failed.get) + stop() + } + } + + case StageReady(stage) => + if (stage.isInstanceOf[ResultQueryStage]) { + callback.onPlanUpdate(stage) + stop() + } else { + readyStages += stage.id + currentPlan = createQueryStages(currentPlan) + } + } + + override protected def onStart(): Unit = { + post(StartCreation) + } + + private def preStageCreationOptimize(plan: SparkPlan): SparkPlan = { + preStageCreationOptimizerRules.foldLeft(plan) { + case (current, rule) => rule(current) + } + } + + private def postStageCreationOptimize(plan: SparkPlan): SparkPlan = { + postStageCreationOptimizerRules.foldLeft(plan) { + case (current, rule) => rule(current) + } + } + + /** + * Traverses the query plan bottom-up and creates as many query stages as possible.
+ */ + private def createQueryStages(plan: SparkPlan): SparkPlan = { + val result = createQueryStages0(plan) + if (result.allChildStagesReady) { + val finalPlan = postStageCreationOptimize(preStageCreationOptimize(result.newPlan)) + post(StageReady(ResultQueryStage(currentStageId, finalPlan))) + finalPlan + } else { + callback.onPlanUpdate(result.newPlan) + result.newPlan + } + } + + /** + * This method is called recursively to traverse the plan tree bottom-up. This method returns two + * pieces of information: 1) the new plan after we insert query stages, and 2) whether or not the child query + * stages of the new plan are all ready. + * + * If the current plan is an exchange node and all its child query stages are ready, we try to + * create a new query stage. + */ + private def createQueryStages0(plan: SparkPlan): CreateStageResult = plan match { + case e: Exchange => + val similarStages = stageCache.getOrElseUpdate(e.schema, mutable.Buffer.empty) + similarStages.find(_._1.sameResult(e)) match { + case Some((_, existingStage)) if conf.exchangeReuseEnabled => + CreateStageResult( + newPlan = ReusedQueryStage(existingStage, e.output), + allChildStagesReady = readyStages.contains(existingStage.id)) + + case _ => + val result = createQueryStages0(e.child) + // Try to create a query stage only when all the child query stages are ready. + if (result.allChildStagesReady) { + val optimizedPlan = preStageCreationOptimize(result.newPlan) + e match { + case s: ShuffleExchangeExec => + (s.desiredPartitioning, optimizedPlan.outputPartitioning) match { + case (desired: HashPartitioning, actual: HashPartitioning) + if desired.semanticEquals(actual) => + // This shuffle exchange is unnecessary now, remove it. The reason may be: + // 1. the child plan has changed its output partitioning after optimization, + // which makes this exchange node unnecessary. + // 2. this exchange node is user-specified, which turns out to be unnecessary. + CreateStageResult(newPlan = optimizedPlan, allChildStagesReady = true) + case _ => + val queryStage = createQueryStage(s.copy(child = optimizedPlan)) + similarStages.append(e -> queryStage) + // We've created a new stage, which is obviously not ready yet. + CreateStageResult(newPlan = queryStage, allChildStagesReady = false) + } + + case b: BroadcastExchangeExec => + val queryStage = createQueryStage(b.copy(child = optimizedPlan)) + similarStages.append(e -> queryStage) + // We've created a new stage, which is obviously not ready yet.
+ CreateStageResult(newPlan = queryStage, allChildStagesReady = false) + } + } else { + CreateStageResult( + newPlan = e.withNewChildren(Seq(result.newPlan)), + allChildStagesReady = false) + } + } + + case q: QueryStage => + CreateStageResult(newPlan = q, allChildStagesReady = readyStages.contains(q.id)) + + case _ => + if (plan.children.isEmpty) { + CreateStageResult(newPlan = plan, allChildStagesReady = true) + } else { + val results = plan.children.map(createQueryStages0) + CreateStageResult( + newPlan = plan.withNewChildren(results.map(_.newPlan)), + allChildStagesReady = results.forall(_.allChildStagesReady)) + } + } + + private def createQueryStage(e: Exchange): QueryStage = { + val optimizedPlan = postStageCreationOptimize(e.child) + val queryStage = e match { + case s: ShuffleExchangeExec => + ShuffleQueryStage(currentStageId, s.copy(child = optimizedPlan)) + case b: BroadcastExchangeExec => + BroadcastQueryStage(currentStageId, b.copy(child = optimizedPlan)) + } + currentStageId += 1 + post(MaterializeStage(queryStage)) + queryStage + } + + override protected def onError(e: Throwable): Unit = callback.onError(e) +} + +case class CreateStageResult(newPlan: SparkPlan, allChildStagesReady: Boolean) + +object QueryStageCreator { + private val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16)) +} + +trait QueryStageTriggerCallback { + def onPlanUpdate(updatedPlan: SparkPlan): Unit + def onStageMaterializingFailed(stage: QueryStage, e: Throwable): Unit + def onError(e: Throwable): Unit +} + +sealed trait QueryStageCreatorEvent + +object StartCreation extends QueryStageCreatorEvent + +case class MaterializeStage(stage: QueryStage) extends QueryStageCreatorEvent + +case class StageReady(stage: QueryStage) extends QueryStageCreatorEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala deleted file mode 100644 index 5dd65f42442b2..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageTrigger.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future} - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.{EventLoop, ThreadUtils} - -/** - * This class triggers [[QueryStage]] bottom-up, apply planner rules for query stages and - * materialize them. 
It triggers as many query stages as possible at the same time, and triggers - * the parent query stage when all its child stages are materialized. - */ -class QueryStageTrigger(session: SparkSession, callback: QueryStageTriggerCallback) - extends EventLoop[QueryStageTriggerEvent]("QueryStageTrigger") { - - private val stageToParentStages = HashMap.empty[Int, ListBuffer[QueryStage]] - - private val idToUpdatedStage = HashMap.empty[Int, QueryStage] - - private val stageToNumPendingChildStages = HashMap.empty[Int, Int] - - private val submittedStages = HashSet.empty[Int] - - private val readyStages = HashSet.empty[Int] - - private val planner = new QueryStagePlanner(session.sessionState.conf) - - def trigger(stage: QueryStage): Unit = { - post(SubmitStage(stage)) - } - - private implicit def executionContext: ExecutionContextExecutorService = { - QueryStageTrigger.executionContext - } - - override protected def onReceive(event: QueryStageTriggerEvent): Unit = event match { - case SubmitStage(stage) => - // We may submit a query stage multiple times, because of stage reuse. Here we avoid - // re-submitting a query stage. - if (!submittedStages.contains(stage.id)) { - submittedStages += stage.id - val pendingChildStages = stage.plan.collect { - // The stage being submitted may have child stages that are already ready, if the child - // stage is a reused stage. - case stage: QueryStage if !readyStages.contains(stage.id) => stage - } - if (pendingChildStages.isEmpty) { - // This is a leaf stage, or all its child stages are ready, we can plan it now. - post(PlanStage(stage)) - } else { - // This stage has some pending child stages, we store the connection of this stage and - // its child stages, and submit all the child stages, so that we can plan this stage - // later when all its child stages are ready. - stageToNumPendingChildStages(stage.id) = pendingChildStages.length - pendingChildStages.foreach { child => - // a child may have multiple parents, because of query stage reuse. - val parentStages = stageToParentStages.getOrElseUpdate(child.id, new ListBuffer) - parentStages += stage - post(SubmitStage(child)) - } - } - } - - case PlanStage(stage) => - Future { - // planning needs active SparkSession in current thread. - SparkSession.setActiveSession(session) - planner.execute(stage.plan) - }.onComplete { res => - if (res.isSuccess) { - post(StagePlanned(stage, res.get)) - } else { - callback.onStagePlanningFailed(stage, res.failed.get) - stop() - } - } - submittedStages += stage.id - - case StagePlanned(stage, optimizedPlan) => - val newStage = stage.withNewPlan(optimizedPlan) - // We store the new stage with the new query plan after planning, so that later on we can - // update the query plan of its parent stage. - idToUpdatedStage(newStage.id) = newStage - // This stage has optimized its plan, notify the callback about this change. - callback.onStageUpdated(newStage) - - newStage.materialize().onComplete { res => - if (res.isSuccess) { - post(StageReady(stage)) - } else { - callback.onStageMaterializingFailed(newStage, res.failed.get) - stop() - } - } - - case StageReady(stage) => - readyStages += stage.id - stageToParentStages.remove(stage.id).foreach { parentStages => - parentStages.foreach { parent => - val numPendingChildStages = stageToNumPendingChildStages(parent.id) - if (numPendingChildStages == 1) { - stageToNumPendingChildStages.remove(parent.id) - // All its child stages are ready, here we update the query plan via replacing the old - // child stages with new ones that are planned. 
- val newPlan = parent.plan.transform { - case q: QueryStage => idToUpdatedStage(q.id) - } - // We can plan this stage now. - post(PlanStage(parent.withNewPlan(newPlan))) - } else { - assert(numPendingChildStages > 1) - stageToNumPendingChildStages(parent.id) = numPendingChildStages - 1 - } - } - } - } - - override protected def onError(e: Throwable): Unit = callback.onError(e) -} - -object QueryStageTrigger { - private val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("QueryStageTrigger", 16)) -} - -trait QueryStageTriggerCallback { - def onStageUpdated(stage: QueryStage): Unit - def onStagePlanningFailed(stage: QueryStage, e: Throwable): Unit - def onStageMaterializingFailed(stage: QueryStage, e: Throwable): Unit - def onError(e: Throwable): Unit -} - -sealed trait QueryStageTriggerEvent - -case class SubmitStage(stage: QueryStage) extends QueryStageTriggerEvent - -case class PlanStage(stage: QueryStage) extends QueryStageTriggerEvent - -case class StagePlanned(stage: QueryStage, optimizedPlan: SparkPlan) extends QueryStageTriggerEvent - -case class StageReady(stage: QueryStage) extends QueryStageTriggerEvent diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/AssertChildStagesMaterialized.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/AssertChildStagesMaterialized.scala index 0319d146c30dc..33cd6f74ad29e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStagePlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/AssertChildStagesMaterialized.scala @@ -15,25 +15,14 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.adaptive +package org.apache.spark.sql.execution.adaptive.rule -import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -import org.apache.spark.sql.execution.{CollapseCodegenStages, SparkPlan} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.QueryStage -class QueryStagePlanner(conf: SQLConf) extends RuleExecutor[SparkPlan] { - - override protected def batches: Seq[Batch] = Seq( - Batch("QueryStage Optimization", Once, - AssertChildStagesMaterialized, - ReduceNumShufflePartitions(conf), - CollapseCodegenStages(conf) - ) - ) -} - -// A sanity check rule to make sure we are running `QueryStagePlanner` on a sub-tree of query plan -// with all input stages materialized. +// A sanity check rule to make sure we are running query stage optimizer rules on a sub-tree of +// query plan with all input stages materialized. 
object AssertChildStagesMaterialized extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = plan.transform { case q: QueryStage if !q.materialize().isCompleted => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/ReduceNumShufflePartitions.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/ReduceNumShufflePartitions.scala index 1f3e85c153dad..9849ef9f4503b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/ReduceNumShufflePartitions.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.adaptive +package org.apache.spark.sql.execution.adaptive.rule import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration.Duration @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStage import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala index 70a7ed0e8d1d1..00bcee330afde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql._ import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.adaptive.rule.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -314,7 +315,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val finalPlan = agg.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan val shuffleReaders = finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader } @@ -361,7 +362,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val finalPlan = join.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan val shuffleReaders = finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader } @@ -413,7 +414,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
val finalPlan = join.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan val shuffleReaders = finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader } @@ -465,7 +466,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val finalPlan = join.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan val shuffleReaders = finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader } @@ -508,7 +509,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // Then, let's make sure we do not reduce number of ppst shuffle partitions. val finalPlan = join.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan val shuffleReaders = finalPlan.collect { case reader: CoalescedShuffleReaderExec => reader } @@ -533,7 +534,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA // ReusedQueryStage 0 val resultDf = df.join(df, "key").join(df, "key") val finalPlan = resultDf.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage.plan + .asInstanceOf[AdaptiveSparkPlan].finalPlan.plan assert(finalPlan.collect { case p: ReusedQueryStage => p }.length == 2) assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 3) checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) @@ -549,7 +550,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA .union(grouped.groupBy(col("key") + 2).max("value")) val resultStage = resultDf2.queryExecution.executedPlan - .asInstanceOf[AdaptiveSparkPlan].resultStage + .asInstanceOf[AdaptiveSparkPlan].finalPlan // The result stage has 2 children val level1Stages = resultStage.plan.collect { case q: QueryStage => q }
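For reference, a minimal usage sketch of the behavior introduced by this change, mirroring the updated tests above: with adaptive execution enabled, InsertAdaptiveSparkPlan wraps the physical plan in an AdaptiveSparkPlan leaf node, and finalPlan blocks until all query stages are materialized. The DataFrame, column names and config values below are illustrative only, not part of the patch.

// Hypothetical sketch, assuming a local SparkSession with adaptive execution enabled.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlan, QueryStage}

val spark = SparkSession.builder()
  .master("local[4]")
  .config("spark.sql.adaptive.enabled", "true")    // makes InsertAdaptiveSparkPlan take effect
  .config("spark.sql.shuffle.partitions", "5")
  .getOrCreate()

val df = spark.range(0, 1000).selectExpr("id % 10 as key", "id as value")
val agg = df.groupBy("key").count()
agg.collect()  // query stages are created, optimized and materialized during execution

// The executed plan is now an AdaptiveSparkPlan leaf node; finalPlan blocks until the
// ResultQueryStage is ready, as the updated tests in this diff do.
val adaptivePlan = agg.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlan]
val stages = adaptivePlan.finalPlan.plan.collect { case q: QueryStage => q }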