diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index a14c584fdc6a6..6a7d13fb5421b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -210,6 +210,15 @@ object StaticSQLConf {
       .checkValue(thres => thres > 0 && thres <= 128, "The threshold must be in (0,128].")
       .createWithDefault(16)
 
+  val RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD =
+    buildStaticConf("spark.sql.resultQueryStage.maxThreadThreshold")
+      .internal()
+      .doc("The maximum degree of parallelism to execute ResultQueryStageExec in AQE.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(thres => thres > 0 && thres <= 1024, "The threshold must be in (0,1024].")
+      .createWithDefault(1024)
+
   val SQL_EVENT_TRUNCATE_LENGTH = buildStaticConf("spark.sql.event.truncate.length")
     .doc("Threshold of SQL length beyond which it will be truncated before adding to " +
       "event. Defaults to no truncation. If set to 0, callsite will be logged instead.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index cb3f382914321..9dcb38f8ff10e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
+import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, ExecutorService}
 import java.util.concurrent.atomic.AtomicLong
 
 import scala.jdk.CollectionConverters._
@@ -301,7 +301,7 @@ object SQLExecution extends Logging {
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
-      sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
+      sparkSession: SparkSession, exec: ExecutorService) (body: => T): CompletableFuture[T] = {
     val activeSession = sparkSession
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
@@ -309,7 +309,7 @@
     // mode, we default back to the resources of the current Spark session.
     val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
       activeSession.artifactManager.state)
-    exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
+    CompletableFuture.supplyAsync(() => JobArtifactSet.withActiveJobArtifactState(artifactState) {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
@@ -326,6 +326,6 @@
         SparkSession.clearActiveSession()
       }
       res
-    })
+    }, exec)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index b2c71bd996a2c..07d215f8a186f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(
 
   def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)
 
-  private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
-    if (isFinalPlan) return currentPhysicalPlan
-
+  /**
+   * Run `fun` on the finalized physical plan.
+   */
+  def withFinalPlanUpdate[T](fun: SparkPlan => T): T = lock.synchronized {
+    _isFinalPlan = false
     // In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
     // `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
     // created in the middle of the execution.
@@ -279,7 +281,7 @@ case class AdaptiveSparkPlanExec(
       // Use inputPlan logicalLink here in case some top level physical nodes may be removed
       // during `initialPlan`
       var currentLogicalPlan = inputPlan.logicalLink.get
-      var result = createQueryStages(currentPhysicalPlan)
+      var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true)
       val events = new LinkedBlockingQueue[StageMaterializationEvent]()
       val errors = new mutable.ArrayBuffer[Throwable]()
       var stagesToReplace = Seq.empty[QueryStageExec]
@@ -344,56 +346,53 @@ case class AdaptiveSparkPlanExec(
         if (errors.nonEmpty) {
           cleanUpAndThrowException(errors.toSeq, None)
         }
-
-        // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
-        // than that of the current plan; otherwise keep the current physical plan together with
-        // the current logical plan since the physical plan's logical links point to the logical
-        // plan it has originated from.
-        // Meanwhile, we keep a list of the query stages that have been created since last plan
-        // update, which stands for the "semantic gap" between the current logical and physical
-        // plans. And each time before re-planning, we replace the corresponding nodes in the
-        // current logical plan with logical query stages to make it semantically in sync with
-        // the current physical plan. Once a new plan is adopted and both logical and physical
-        // plans are updated, we can clear the query stage list because at this point the two plans
-        // are semantically and physically in sync again.
-        val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
-        val afterReOptimize = reOptimize(logicalPlan)
-        if (afterReOptimize.isDefined) {
-          val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
-          val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
-          val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
-          if (newCost < origCost ||
-              (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
-            lazy val plans =
-              sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
-            logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
-            cleanUpTempTags(newPhysicalPlan)
-            currentPhysicalPlan = newPhysicalPlan
-            currentLogicalPlan = newLogicalPlan
-            stagesToReplace = Seq.empty[QueryStageExec]
+        if (!currentPhysicalPlan.isInstanceOf[ResultQueryStageExec]) {
+          // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
+          // than that of the current plan; otherwise keep the current physical plan together with
+          // the current logical plan since the physical plan's logical links point to the logical
+          // plan it has originated from.
+          // Meanwhile, we keep a list of the query stages that have been created since last plan
+          // update, which stands for the "semantic gap" between the current logical and physical
+          // plans. And each time before re-planning, we replace the corresponding nodes in the
+          // current logical plan with logical query stages to make it semantically in sync with
+          // the current physical plan. Once a new plan is adopted and both logical and physical
+          // plans are updated, we can clear the query stage list because at this point the two
+          // plans are semantically and physically in sync again.
+          val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
+          val afterReOptimize = reOptimize(logicalPlan)
+          if (afterReOptimize.isDefined) {
+            val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
+            val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
+            val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
+            if (newCost < origCost ||
+                (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
+              lazy val plans = sideBySide(
+                currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString("\n")
+              logOnLevel(log"Plan changed:\n${MDC(QUERY_PLAN, plans)}")
+              cleanUpTempTags(newPhysicalPlan)
+              currentPhysicalPlan = newPhysicalPlan
+              currentLogicalPlan = newLogicalPlan
+              stagesToReplace = Seq.empty[QueryStageExec]
+            }
          }
        }
 
         // Now that some stages have finished, we can try creating new stages.
-        result = createQueryStages(currentPhysicalPlan)
+        result = createQueryStages(fun, currentPhysicalPlan, firstRun = false)
       }
-
-      // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(
-        optimizeQueryStage(result.newPlan, isFinalStage = true),
-        postStageCreationRules(supportsColumnar),
-        Some((planChangeLogger, "AQE Post Stage Creation")))
-      _isFinalPlan = true
-      executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
-      currentPhysicalPlan
     }
+    _isFinalPlan = true
+    finalPlanUpdate
+    // Dereference the result so it can be GCed. After this, resultStage.isMaterialized will
+    // return false, which is expected. If we want to collect the result again, we should invoke
+    // `withFinalPlanUpdate` with another result handler, which will create a new result stage.
+    currentPhysicalPlan.asInstanceOf[ResultQueryStageExec].resultOption.getAndUpdate(_ => None)
+      .get.asInstanceOf[T]
   }
 
   // Use a lazy val to avoid this being called more than once.
   @transient private lazy val finalPlanUpdate: Unit = {
-    // Subqueries that don't belong to any query stage of the main query will execute after the
-    // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
-    // the newly generated nodes of those subqueries are updated.
-    if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
+    // Do the final plan update after the result stage has materialized.
+    if (shouldUpdatePlan) {
       getExecutionId.foreach(onUpdatePlan(_, Seq.empty))
     }
     logOnLevel(log"Final plan:\n${MDC(QUERY_PLAN, currentPhysicalPlan)}")
@@ -426,13 +425,6 @@ case class AdaptiveSparkPlanExec(
     }
   }
 
-  private def withFinalPlanUpdate[T](fun: SparkPlan => T): T = {
-    val plan = getFinalPhysicalPlan()
-    val result = fun(plan)
-    finalPlanUpdate
-    result
-  }
-
   protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")
 
   override def generateTreeString(
@@ -521,6 +513,66 @@ case class AdaptiveSparkPlanExec(
     this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan
   }
 
+  /**
+   * We separate the creation of result and non-result stages because result stage creation has
+   * several edge cases:
+   * - an existing ResultQueryStage was created in a previous `withFinalPlanUpdate`.
+   * - the root node is a non-result query stage and we have to create a result query stage on
+   *   top of it.
+   * - we create a non-result query stage as the root node and the stage is immediately
+   *   materialized due to stage reuse, so we have to create a result stage right after.
+   *
+   * This method wraps `createNonResultQueryStages`; the general logic is:
+   * - Return early if a ResultQueryStageExec has already been created.
+   * - Create a non-result query stage if possible.
+   * - Try to create a result query stage when no new non-result query stage was created and all
+   *   stages are materialized.
+   */
+  private def createQueryStages(
+      resultHandler: SparkPlan => Any,
+      plan: SparkPlan,
+      firstRun: Boolean): CreateStageResult = {
+    plan match {
+      // 1. ResultQueryStageExec is already created, no need to create non-result stages
+      case resultStage @ ResultQueryStageExec(_, optimizedPlan, _) =>
+        assertStageNotFailed(resultStage)
+        if (firstRun) {
+          // There is already a ResultQueryStage created in a previous `withFinalPlanUpdate`,
+          // e.g., when we call `df.collect` multiple times. Here we create a new result stage to
+          // execute it again, as the handler function can be different.
+          val newResultStage = ResultQueryStageExec(currentStageId, optimizedPlan, resultHandler)
+          currentStageId += 1
+          setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
+          CreateStageResult(newPlan = newResultStage,
+            allChildStagesMaterialized = false,
+            newStages = Seq(newResultStage))
+        } else {
+          // We hit this branch after the result query stage has been created in the AQE loop;
+          // we should do nothing.
+          CreateStageResult(newPlan = resultStage,
+            allChildStagesMaterialized = resultStage.isMaterialized,
+            newStages = Seq.empty)
+        }
+      case _ =>
+        // 2. Create non-result query stages
+        val result = createNonResultQueryStages(plan)
+        var allNewStages = result.newStages
+        var newPlan = result.newPlan
+        var allChildStagesMaterialized = result.allChildStagesMaterialized
+        // 3. Create the result stage
+        if (allNewStages.isEmpty && allChildStagesMaterialized) {
+          val resultStage = newResultQueryStage(resultHandler, newPlan)
+          newPlan = resultStage
+          allChildStagesMaterialized = false
+          allNewStages :+= resultStage
+        }
+        CreateStageResult(
+          newPlan = newPlan,
+          allChildStagesMaterialized = allChildStagesMaterialized,
+          newStages = allNewStages)
+    }
+  }
+
   /**
    * This method is called recursively to traverse the plan tree bottom-up and create a new query
    * stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of
@@ -531,7 +583,7 @@ case class AdaptiveSparkPlanExec(
    * 2) Whether the child query stages (if any) of the current node have all been materialized.
    * 3) A list of the new query stages that have been created.
    */
-  private def createQueryStages(plan: SparkPlan): CreateStageResult = plan match {
+  private def createNonResultQueryStages(plan: SparkPlan): CreateStageResult = plan match {
     case e: Exchange =>
       // First have a quick check in the `stageCache` without having to traverse down the node.
       context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@ case class AdaptiveSparkPlanExec(
             newStages = if (isMaterialized) Seq.empty else Seq(stage))
 
         case _ =>
-          val result = createQueryStages(e.child)
+          val result = createNonResultQueryStages(e.child)
           val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange]
           // Create a query stage only when all the child query stages are ready.
           if (result.allChildStagesMaterialized) {
@@ -588,7 +640,7 @@ case class AdaptiveSparkPlanExec(
     if (plan.children.isEmpty) {
       CreateStageResult(newPlan = plan, allChildStagesMaterialized = true, newStages = Seq.empty)
     } else {
-      val results = plan.children.map(createQueryStages)
+      val results = plan.children.map(createNonResultQueryStages)
       CreateStageResult(
         newPlan = plan.withNewChildren(results.map(_.newPlan)),
         allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
@@ -596,6 +648,20 @@ case class AdaptiveSparkPlanExec(
     }
   }
 
+  private def newResultQueryStage(
+      resultHandler: SparkPlan => Any,
+      plan: SparkPlan): ResultQueryStageExec = {
+    // Run the final plan when there are no more unfinished stages.
+    val optimizedRootPlan = applyPhysicalRules(
+      optimizeQueryStage(plan, isFinalStage = true),
+      postStageCreationRules(supportsColumnar),
+      Some((planChangeLogger, "AQE Post Stage Creation")))
+    val resultStage = ResultQueryStageExec(currentStageId, optimizedRootPlan, resultHandler)
+    currentStageId += 1
+    setLogicalLinkForNewQueryStage(resultStage, plan)
+    resultStage
+  }
+
   private def newQueryStage(plan: SparkPlan): QueryStageExec = {
     val queryStage = plan match {
       case e: Exchange =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
index afbb1f0e5a37e..2556edee8d02f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
@@ -129,10 +129,12 @@ trait AdaptiveSparkPlanHelper {
   }
 
   /**
-   * Strip the executePlan of AdaptiveSparkPlanExec leaf node.
+   * Strip the top [[AdaptiveSparkPlanExec]] and [[ResultQueryStageExec]] nodes off
+   * the [[SparkPlan]].
    */
   def stripAQEPlan(p: SparkPlan): SparkPlan = p match {
-    case a: AdaptiveSparkPlanExec => a.executedPlan
+    case a: AdaptiveSparkPlanExec => stripAQEPlan(a.executedPlan)
+    case ResultQueryStageExec(_, plan, _) => plan
     case other => other
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 2391fe740118d..0a5bdefea7bc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.adaptive
 
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.concurrent.ExecutionContext
 import scala.concurrent.Future
+import scala.concurrent.Promise
 
 import org.apache.spark.{MapOutputStatistics, SparkException}
 import org.apache.spark.broadcast.Broadcast
@@ -32,7 +34,10 @@ import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
 
 /**
  * A query stage is an independent subgraph of the query plan. AQE framework will materialize its
@@ -303,3 +308,43 @@ case class TableCacheQueryStageExec(
 
   override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics
 }
+
+case class ResultQueryStageExec(
+    override val id: Int,
+    override val plan: SparkPlan,
+    resultHandler: SparkPlan => Any) extends QueryStageExec {
+
+  override def resetMetrics(): Unit = {
+    plan.resetMetrics()
+  }
+
+  override protected def doMaterialize(): Future[Any] = {
+    val javaFuture = SQLExecution.withThreadLocalCaptured(
+      session,
+      ResultQueryStageExec.executionContext) {
+      resultHandler(plan)
+    }
+    val scalaPromise: Promise[Any] = Promise()
+    javaFuture.whenComplete { (result: Any, exception: Throwable) =>
+      if (exception != null) {
+        scalaPromise.failure(exception match {
+          case completionException: java.util.concurrent.CompletionException =>
+            completionException.getCause
+          case ex => ex
+        })
+      } else {
+        scalaPromise.success(result)
+      }
+    }
+    scalaPromise.future
+  }
+
+  // The result stage can be any SparkPlan, so there are no specific runtime statistics for it.
+  override def getRuntimeStatistics: Statistics = Statistics.DUMMY
+}
+
+object ResultQueryStageExec {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("ResultQueryStageExecution",
+      SQLConf.get.getConf(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD)))
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index f94d7dc7ab4c0..ced4b6224c884 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -106,7 +106,7 @@ object SparkPlanGraph {
           buildSparkPlanGraphNode(
             planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
         }
-      case "TableCacheQueryStage" =>
+      case "TableCacheQueryStage" | "ResultQueryStage" =>
         buildSparkPlanGraphNode(
           planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
       case "Subquery" if subgraph != null =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index ff455d28e8680..e0ad3feda3ac4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -1659,7 +1659,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
         _.nodeName.contains("TableCacheQueryStage"))
       val aqeNode = findNodeInSparkPlanInfo(inMemoryScanNode.get,
         _.nodeName.contains("AdaptiveSparkPlan"))
-      aqeNode.get.children.head.nodeName == "AQEShuffleRead"
+      val aqePlanRoot = findNodeInSparkPlanInfo(inMemoryScanNode.get,
+        _.nodeName.contains("ResultQueryStage"))
+      aqePlanRoot.get.children.head.nodeName == "AQEShuffleRead"
     }
 
     withTempView("t0", "t1", "t2") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 22fdd96ce6bad..9c90e0105a424 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -557,31 +557,33 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
     val testDf = df1.join(df2, "k").groupBy("k").agg(count("v1"), sum("v1"), avg("v2"))
     // trigger the final plan for AQE
     testDf.collect()
-    // AdaptiveSparkPlan (21)
+
+    // AdaptiveSparkPlan (22)
     // +- == Final Plan ==
-    //    * HashAggregate (12)
-    //    +- AQEShuffleRead (11)
-    //       +- ShuffleQueryStage (10)
-    //          +- Exchange (9)
-    //             +- * HashAggregate (8)
-    //                +- * Project (7)
-    //                   +- * BroadcastHashJoin Inner BuildRight (6)
-    //                      :- * LocalTableScan (1)
-    //                      +- BroadcastQueryStage (5)
-    //                         +- BroadcastExchange (4)
-    //                            +- * Project (3)
-    //                               +- * LocalTableScan (2)
+    //    ResultQueryStage (13)
+    //    +- * HashAggregate (12)
+    //       +- AQEShuffleRead (11)
+    //          +- ShuffleQueryStage (10)
+    //             +- Exchange (9)
+    //                +- * HashAggregate (8)
+    //                   +- * Project (7)
+    //                      +- * BroadcastHashJoin Inner BuildRight (6)
+    //                         :- * LocalTableScan (1)
+    //                         +- BroadcastQueryStage (5)
+    //                            +- BroadcastExchange (4)
+    //                               +- * Project (3)
+    //                                  +- * LocalTableScan (2)
     // +- == Initial Plan ==
-    //    HashAggregate (20)
-    //    +- Exchange (19)
-    //       +- HashAggregate (18)
-    //          +- Project (17)
-    //             +- BroadcastHashJoin Inner BuildRight (16)
-    //                :- Project (14)
-    //                :  +- LocalTableScan (13)
-    //                +- BroadcastExchange (15)
-    //                   +- Project (3)
-    //                      +- LocalTableScan (2)
+    //    HashAggregate (21)
+    //    +- Exchange (20)
+    //       +- HashAggregate (19)
+    //          +- Project (18)
+    //             +- BroadcastHashJoin Inner BuildRight (17)
+    //                :- Project (15)
+    //                :  +- LocalTableScan (14)
+    //                +- BroadcastExchange (16)
+    //                   +- Project (3)
+    //                      +- LocalTableScan (2)
     checkKeywordsExistsInExplain(
       testDf,
       FormattedMode,
@@ -599,18 +601,18 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
         |Arguments: coalesced
         |""".stripMargin,
       """
-        |(16) BroadcastHashJoin
+        |(17) BroadcastHashJoin
         |Left keys [1]: [k#x]
        |Right keys [1]: [k#x]
         |Join type: Inner
         |Join condition: None
         |""".stripMargin,
       """
-        |(19) Exchange
+        |(20) Exchange
         |Input [5]: [k#x, count#xL, sum#xL, sum#x, count#xL]
         |""".stripMargin,
       """
-        |(21) AdaptiveSparkPlan
+        |(22) AdaptiveSparkPlan
         |Output [4]: [k#x, count(v1)#xL, sum(v1)#xL, avg(v2)#x]
         |Arguments: isFinalPlan=true
         |""".stripMargin
@@ -656,7 +658,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
         |Output [1]: [id#xL]
         |Arguments: 0""".stripMargin,
       """
-        |(12) AdaptiveSparkPlan
+        |(13) AdaptiveSparkPlan
         |Output [2]: [key#xL, value#xL]
         |Arguments: isFinalPlan=true
         |""".stripMargin,
       """
         |Subquery:1 Hosting operator id = 2 Hosting Expression = Subquery subquery#x, [id=#x]
         |""".stripMargin,
       """
-        |(16) ShuffleQueryStage
+        |(17) ShuffleQueryStage
         |Output [1]: [max#xL]
         |Arguments: 0""".stripMargin,
       """
-        |(20) AdaptiveSparkPlan
+        |(22) AdaptiveSparkPlan
         |Output [1]: [max(id)#xL]
         |Arguments: isFinalPlan=true
         |""".stripMargin
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 920a0872ee4f9..44a0bf79fae9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -244,7 +244,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       columnStats: ListBuffer[AttributeMap[ColumnStat]]): Unit = {
     plan match {
       case a: AdaptiveSparkPlanExec =>
-        findColumnStats(a.executedPlan, columnStats)
+        findColumnStats(stripAQEPlan(a), columnStats)
       case qs: ShuffleQueryStageExec =>
         columnStats += qs.computeStats().get.attributeStats
         findColumnStats(qs.plan, columnStats)
@@ -489,7 +489,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
   test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE Optimizer") {
     def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = {
       assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
-      df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      stripAQEPlan(df.queryExecution.executedPlan)
     }
     val extensions = create { extensions =>
       extensions.injectRuntimeOptimizerRule(_ => AddLimit)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
index 9ed4f1a006b2b..da43b0cfc58be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ArrayImplicits._
 
-class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
+class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper
+  with AdaptiveSparkPlanHelper {
 
   private var originalActiveSparkSession: Option[SparkSession] = _
   private var originalInstantiatedSparkSession: Option[SparkSession] = _
@@ -108,8 +109,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val finalPlan = agg.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(agg.queryExecution.executedPlan)
         val shuffleReads = finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
         }
@@ -154,8 +154,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(join.queryExecution.executedPlan)
         val shuffleReads = finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
         }
@@ -205,8 +204,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(join.queryExecution.executedPlan)
         val shuffleReads = finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
         }
@@ -256,8 +254,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         // Then, let's look at the number of post-shuffle partitions estimated
         // by the ExchangeCoordinator.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(join.queryExecution.executedPlan)
         val shuffleReads = finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
         }
@@ -298,8 +295,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
           expectedAnswer.collect().toImmutableArraySeq)
 
         // Then, let's make sure we do not reduce number of post shuffle partitions.
-        val finalPlan = join.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(join.queryExecution.executedPlan)
         val shuffleReads = finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
         }
@@ -385,8 +381,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
       // ReusedQueryStage 0
       val resultDf = df.join(df, "key").join(df, "key")
       QueryTest.checkAnswer(resultDf, (0 to 5).map(i => Row(i, i, i, i)))
-      val finalPlan = resultDf.queryExecution.executedPlan
-        .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      val finalPlan = stripAQEPlan(resultDf.queryExecution.executedPlan)
       assert(finalPlan.collect {
         case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
       }.length == 2)
@@ -409,8 +404,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
           Row(4, 2) :: Row(5, 2) :: Row(5, 3) :: Row(6, 3) :: Row(6, 4) :: Row(7, 4) ::
           Row(7, 5) :: Row(8, 5) :: Nil)
 
-      val finalPlan2 = resultDf2.queryExecution.executedPlan
-        .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      val finalPlan2 = stripAQEPlan(resultDf2.queryExecution.executedPlan)
 
       // The result stage has 2 children
       val level1Stages = finalPlan2.collect { case q: QueryStageExec => q }
@@ -453,8 +447,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         QueryTest.checkAnswer(resultDf, Seq(0, 1, 2).map(i => Row(i)))
 
-        val finalPlan = resultDf.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(resultDf.queryExecution.executedPlan)
         assert(
           finalPlan.collect {
             case r @ CoalescedShuffleRead() => r
@@ -474,8 +467,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
 
         // Shuffle partition coalescing of the join is performed independent of the non-grouping
         // aggregate on the other side of the union.
-        val finalPlan = resultDf.queryExecution.executedPlan
-          .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+        val finalPlan = stripAQEPlan(resultDf.queryExecution.executedPlan)
         assert(
           finalPlan.collect {
             case r @ CoalescedShuffleRead() => r
@@ -490,8 +482,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper {
       val resultDf = ds.repartition(ds.col("id"))
       resultDf.collect()
 
-      val finalPlan = resultDf.queryExecution.executedPlan
-        .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+      val finalPlan = stripAQEPlan(resultDf.queryExecution.executedPlan)
       assert(
         finalPlan.collect {
           case r @ CoalescedShuffleRead() => r
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index c461f41c9104c..4aa5716ab4836 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -521,7 +521,7 @@ class ExtendedInfo extends ExtendedExplainGenerator {
 
   def getActualPlan(plan: SparkPlan): SparkPlan = {
     plan match {
-      case p : AdaptiveSparkPlanExec => p.executedPlan
+      case p : AdaptiveSparkPlanExec => getActualPlan(p.executedPlan)
       case p : QueryStageExec => p.plan
       case p : WholeStageCodegenExec => p.child
       case p => p
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 8ddbd9af9d534..272be70f9fe5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -99,13 +99,10 @@ class AdaptiveQueryExecSuite
     }
     val planAfter = dfAdaptive.queryExecution.executedPlan
     assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true"))
-    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    val adaptivePlan = stripAQEPlan(planAfter)
     spark.sparkContext.listenerBus.waitUntilEmpty()
-    // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that
-    // exist out of query stages.
-    val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1)
-    assert(finalPlanCnt == expectedFinalPlanCnt)
+    assert(finalPlanCnt == 1)
     spark.sparkContext.removeSparkListener(listener)
 
     val expectedMetrics = findInMemoryTable(planAfter).nonEmpty ||
@@ -1246,7 +1243,9 @@ class AdaptiveQueryExecSuite
       if (enabled) {
         assert(planInfo.nodeName == "AdaptiveSparkPlan")
         assert(planInfo.children.size == 1)
-        assert(planInfo.children.head.nodeName ==
+        assert(planInfo.children.head.nodeName == "ResultQueryStage")
+        assert(planInfo.children.head.children.size == 1)
+        assert(planInfo.children.head.children.head.nodeName ==
           "Execute InsertIntoHadoopFsRelationCommand")
       } else {
         assert(planInfo.nodeName == "Execute InsertIntoHadoopFsRelationCommand")
@@ -3048,6 +3047,34 @@ class AdaptiveQueryExecSuite
     checkAnswer(unionDF.select("id").distinct(), Seq(Row(null)))
   }
 
+  test("Collect twice on the same dataframe with no AQE plan changes") {
+    val df = spark.sql("SELECT * FROM testData join testData2 ON key = a")
+    df.collect()
+    val plan1 = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    df.collect()
+    val plan2 = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    assert(plan1.isInstanceOf[ResultQueryStageExec])
+    assert(plan2.isInstanceOf[ResultQueryStageExec])
+    assert(plan1 ne plan2)
+    assert(plan1.asInstanceOf[ResultQueryStageExec].plan
+      .fastEquals(plan2.asInstanceOf[ResultQueryStageExec].plan))
+  }
+
+  test("Two different collect actions on the same dataframe") {
+    val df = spark.sql("SELECT * FROM testData join testData2 ON key = a")
+    val adaptivePlan = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec]
+    val res1 = adaptivePlan.execute().collect()
+    val plan1 = adaptivePlan.executedPlan
+    val res2 = adaptivePlan.executeTake(1)
+    val plan2 = adaptivePlan.executedPlan
+    assert(res1.length != res2.length)
+    assert(plan1.isInstanceOf[ResultQueryStageExec])
+    assert(plan2.isInstanceOf[ResultQueryStageExec])
+    assert(plan1 ne plan2)
+    assert(plan1.asInstanceOf[ResultQueryStageExec].plan
+      .fastEquals(plan2.asInstanceOf[ResultQueryStageExec].plan))
+  }
+
   test("SPARK-47247: coalesce differently for BNLJ") {
     Seq(true, false).foreach { expectCoalesce =>
       val minPartitionSize = if (expectCoalesce) "64MB" else "1B"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 04a7b4834f4b8..80d771428d909 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -22,13 +22,14 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference,
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
 import org.apache.spark.sql.execution.{QueryExecution, SortExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.tags.SlowSQLTest
 
-trait V1WriteCommandSuiteBase extends SQLTestUtils {
+trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper {
 
   import testImplicits._
 
@@ -218,7 +219,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
           executedPlan.asInstanceOf[WriteFilesExecBase].child
         } else {
           executedPlan.transformDown {
-            case a: AdaptiveSparkPlanExec => a.executedPlan
+            case a: AdaptiveSparkPlanExec => stripAQEPlan(a)
           }
         }
@@ -265,7 +266,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
           executedPlan.asInstanceOf[WriteFilesExecBase].child
         } else {
           executedPlan.transformDown {
-            case a: AdaptiveSparkPlanExec => a.executedPlan
+            case a: AdaptiveSparkPlanExec => stripAQEPlan(a)
          }
        }
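
Note for reviewers: the trickiest plumbing above is the future conversion. `SQLExecution.withThreadLocalCaptured` now returns a Java `CompletableFuture` (built with `CompletableFuture.supplyAsync(..., exec)` so the body still runs on the caller-supplied executor), while `QueryStageExec.materialize()` keeps exposing a Scala `Future`, so `ResultQueryStageExec.doMaterialize` bridges the two with a `Promise` and unwraps the `CompletionException` that `CompletableFuture` uses to wrap failures. A minimal, self-contained sketch of that bridge outside Spark (the object and method names here are illustrative, not part of the patch):

```scala
import java.util.concurrent.{CompletableFuture, CompletionException, Executors}

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

object CompletableFutureBridgeDemo {
  // Bridge a Java CompletableFuture into a Scala Future. CompletableFuture wraps
  // downstream failures in CompletionException, so we unwrap it to surface the
  // original error, the same way doMaterialize does in the patch.
  def toScalaFuture[T](javaFuture: CompletableFuture[T]): Future[T] = {
    val promise = Promise[T]()
    javaFuture.whenComplete { (result: T, exception: Throwable) =>
      if (exception != null) {
        promise.failure(exception match {
          case e: CompletionException => e.getCause
          case e => e
        })
      } else {
        promise.success(result)
      }
    }
    promise.future
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    try {
      // supplyAsync(..., pool) runs the body on the supplied executor, which is what
      // lets the new spark.sql.resultQueryStage.maxThreadThreshold config bound the
      // result stage's thread pool.
      val javaFuture = CompletableFuture.supplyAsync(() => 21 * 2, pool)
      println(Await.result(toScalaFuture(javaFuture), 10.seconds)) // prints 42
    } finally {
      pool.shutdown()
    }
  }
}
```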
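On the user-visible side, the new tests boil down to: every action on an adaptive plan now installs a fresh `ResultQueryStageExec` at the root (each action can carry a different result handler), while the optimized plan underneath is reused. A rough sketch of observing this, assuming this patch is applied and AQE is enabled; the demo object and the tiny dataset are hypothetical, and it pokes at internal APIs the same way the tests above do:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, ResultQueryStageExec}

object ResultStageSemanticsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    import spark.implicits._
    // A grouped aggregation forces a shuffle, so AQE kicks in.
    val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("k", "v").groupBy("k").count()

    df.collect()
    val plan1 = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
    df.collect()
    val plan2 = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan

    // Each action installs a fresh ResultQueryStageExec wrapper, while the
    // optimized plan underneath is reused unchanged.
    assert(plan1.isInstanceOf[ResultQueryStageExec] && plan2.isInstanceOf[ResultQueryStageExec])
    assert(plan1 ne plan2)
    spark.stop()
  }
}
```

This is also why `finalPlanUpdate` no longer needs the subquery special case and the suite now expects exactly one `SparkListenerSQLAdaptiveExecutionUpdate`: the UI update happens once, after the result stage has materialized.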