@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.adaptive
1919
2020import java .util .concurrent .CountDownLatch
2121
22+ import org .apache .spark .SparkException
2223import org .apache .spark .rdd .RDD
2324import org .apache .spark .sql .SparkSession
2425import org .apache .spark .sql .catalyst .InternalRow
@@ -31,44 +32,29 @@ import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
3132 * updates the query plan when a query stage is materialized and provides accurate runtime
3233 * statistics.
3334 */
34- case class AdaptiveSparkPlan (initialPlan : ResultQueryStage , session : SparkSession )
35+ case class AdaptiveSparkPlan (initialPlan : SparkPlan , session : SparkSession )
3536 extends LeafExecNode {
3637
3738 override def output : Seq [Attribute ] = initialPlan.output
3839
39- @ volatile private var currentQueryStage : QueryStage = initialPlan
40+ @ volatile private var currentPlan : SparkPlan = initialPlan
4041 @ volatile private var error : Throwable = null
41- private val readyLock = new CountDownLatch (1 )
4242
43- private def replaceStage (oldStage : QueryStage , newStage : QueryStage ): QueryStage = {
44- if (oldStage.id == newStage.id) {
45- newStage
46- } else {
47- val newPlanForOldStage = oldStage.plan.transform {
48- case q : QueryStage => replaceStage(q, newStage)
49- }
50- oldStage.withNewPlan(newPlanForOldStage)
51- }
52- }
43+ // We will release the lock when we finish planning query stages, or we fail to do the planning.
44+ // Getting `finalPlan` will be blocked until the lock is released.
45+ // This is better than wait()/notify(), as we can easily check if the computation has completed,
46+ // by calling `readyLock.getCount()`.
47+ private val readyLock = new CountDownLatch (1 )
5348
5449 private def createCallback (executionId : Option [Long ]): QueryStageTriggerCallback = {
5550 new QueryStageTriggerCallback {
56- override def onStageUpdated (stage : QueryStage ): Unit = {
57- updateCurrentQueryStage(stage, executionId)
58- if (stage.isInstanceOf [ResultQueryStage ]) readyLock.countDown()
59- }
60-
61- override def onStagePlanningFailed (stage : QueryStage , e : Throwable ): Unit = {
62- error = new RuntimeException (
63- s """
64- |Fail to plan stage ${stage.id}:
65- | ${stage.plan.treeString}
66- """ .stripMargin, e)
67- readyLock.countDown()
51+ override def onPlanUpdate (updatedPlan : SparkPlan ): Unit = {
52+ updateCurrentPlan(updatedPlan, executionId)
53+ if (updatedPlan.isInstanceOf [ResultQueryStage ]) readyLock.countDown()
6854 }
6955
7056 override def onStageMaterializingFailed (stage : QueryStage , e : Throwable ): Unit = {
71- error = new RuntimeException (
57+ error = new SparkException (
7258 s """
7359 |Fail to materialize stage ${stage.id}:
7460 | ${stage.plan.treeString}
@@ -83,35 +69,34 @@ case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSessio
8369 }
8470 }
8571
86- private def updateCurrentQueryStage ( newStage : QueryStage , executionId : Option [Long ]): Unit = {
87- currentQueryStage = replaceStage(currentQueryStage, newStage)
72+ private def updateCurrentPlan ( newPlan : SparkPlan , executionId : Option [Long ]): Unit = {
73+ currentPlan = newPlan
8874 executionId.foreach { id =>
8975 session.sparkContext.listenerBus.post(SparkListenerSQLAdaptiveExecutionUpdate (
9076 id,
9177 SQLExecution .getQueryExecution(id).toString,
92- SparkPlanInfo .fromSparkPlan(currentQueryStage )))
78+ SparkPlanInfo .fromSparkPlan(currentPlan )))
9379 }
9480 }
9581
96- def resultStage : ResultQueryStage = {
82+ def finalPlan : ResultQueryStage = {
9783 if (readyLock.getCount > 0 ) {
9884 val sc = session.sparkContext
9985 val executionId = Option (sc.getLocalProperty(SQLExecution .EXECUTION_ID_KEY )).map(_.toLong)
100- val trigger = new QueryStageTrigger (session, createCallback(executionId))
101- trigger.start()
102- trigger.trigger(initialPlan)
86+ val creator = new QueryStageCreator (initialPlan, session, createCallback(executionId))
87+ creator.start()
10388 readyLock.await()
104- trigger .stop()
89+ creator .stop()
10590 }
10691
10792 if (error != null ) throw error
108- currentQueryStage .asInstanceOf [ResultQueryStage ]
93+ currentPlan .asInstanceOf [ResultQueryStage ]
10994 }
11095
111- override def executeCollect (): Array [InternalRow ] = resultStage .executeCollect()
112- override def executeTake (n : Int ): Array [InternalRow ] = resultStage .executeTake(n)
113- override def executeToIterator (): Iterator [InternalRow ] = resultStage .executeToIterator()
114- override def doExecute (): RDD [InternalRow ] = resultStage .execute()
96+ override def executeCollect (): Array [InternalRow ] = finalPlan .executeCollect()
97+ override def executeTake (n : Int ): Array [InternalRow ] = finalPlan .executeTake(n)
98+ override def executeToIterator (): Iterator [InternalRow ] = finalPlan .executeToIterator()
99+ override def doExecute (): RDD [InternalRow ] = finalPlan .execute()
115100 override def generateTreeString (
116101 depth : Int ,
117102 lastChildren : Seq [Boolean ],
@@ -120,7 +105,7 @@ case class AdaptiveSparkPlan(initialPlan: ResultQueryStage, session: SparkSessio
120105 prefix : String = " " ,
121106 addSuffix : Boolean = false ,
122107 maxFields : Int ): Unit = {
123- currentQueryStage .generateTreeString(
108+ currentPlan .generateTreeString(
124109 depth, lastChildren, append, verbose, " " , false , maxFields)
125110 }
126111}
0 commit comments