@@ -268,9 +268,11 @@ case class AdaptiveSparkPlanExec(
268268
269269 def finalPhysicalPlan : SparkPlan = withFinalPlanUpdate(identity)
270270
271- private def getFinalPhysicalPlan (): SparkPlan = lock.synchronized {
272- if (isFinalPlan) return currentPhysicalPlan
273-
271+ /**
272+ * Run `fun` on finalized physical plan
273+ */
274+ def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = lock.synchronized {
275+ _isFinalPlan = false
274276 // In case of this adaptive plan being executed out of `withActive` scoped functions, e.g.,
275277 // `plan.queryExecution.rdd`, we need to set active session here as new plan nodes can be
276278 // created in the middle of the execution.
@@ -279,7 +281,7 @@ case class AdaptiveSparkPlanExec(
279281 // Use inputPlan logicalLink here in case some top level physical nodes may be removed
280282 // during `initialPlan`
281283 var currentLogicalPlan = inputPlan.logicalLink.get
282- var result = createQueryStages(currentPhysicalPlan)
284+ var result = createQueryStages(fun, currentPhysicalPlan, firstRun = true )
283285 val events = new LinkedBlockingQueue [StageMaterializationEvent ]()
284286 val errors = new mutable.ArrayBuffer [Throwable ]()
285287 var stagesToReplace = Seq .empty[QueryStageExec ]
@@ -344,56 +346,53 @@ case class AdaptiveSparkPlanExec(
344346 if (errors.nonEmpty) {
345347 cleanUpAndThrowException(errors.toSeq, None )
346348 }
347-
348- // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
349- // than that of the current plan; otherwise keep the current physical plan together with
350- // the current logical plan since the physical plan's logical links point to the logical
351- // plan it has originated from.
352- // Meanwhile, we keep a list of the query stages that have been created since last plan
353- // update, which stands for the "semantic gap" between the current logical and physical
354- // plans. And each time before re-planning, we replace the corresponding nodes in the
355- // current logical plan with logical query stages to make it semantically in sync with
356- // the current physical plan. Once a new plan is adopted and both logical and physical
357- // plans are updated, we can clear the query stage list because at this point the two plans
358- // are semantically and physically in sync again.
359- val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
360- val afterReOptimize = reOptimize(logicalPlan)
361- if (afterReOptimize.isDefined) {
362- val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
363- val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
364- val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
365- if (newCost < origCost ||
366- (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
367- lazy val plans =
368- sideBySide(currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
369- logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
370- cleanUpTempTags(newPhysicalPlan)
371- currentPhysicalPlan = newPhysicalPlan
372- currentLogicalPlan = newLogicalPlan
373- stagesToReplace = Seq .empty[QueryStageExec ]
349+ if (! currentPhysicalPlan.isInstanceOf [ResultQueryStageExec ]) {
350+ // Try re-optimizing and re-planning. Adopt the new plan if its cost is equal to or less
351+ // than that of the current plan; otherwise keep the current physical plan together with
352+ // the current logical plan since the physical plan's logical links point to the logical
353+ // plan it has originated from.
354+ // Meanwhile, we keep a list of the query stages that have been created since last plan
355+ // update, which stands for the "semantic gap" between the current logical and physical
356+ // plans. And each time before re-planning, we replace the corresponding nodes in the
357+ // current logical plan with logical query stages to make it semantically in sync with
358+ // the current physical plan. Once a new plan is adopted and both logical and physical
359+ // plans are updated, we can clear the query stage list because at this point the two
360+ // plans are semantically and physically in sync again.
361+ val logicalPlan = replaceWithQueryStagesInLogicalPlan(currentLogicalPlan, stagesToReplace)
362+ val afterReOptimize = reOptimize(logicalPlan)
363+ if (afterReOptimize.isDefined) {
364+ val (newPhysicalPlan, newLogicalPlan) = afterReOptimize.get
365+ val origCost = costEvaluator.evaluateCost(currentPhysicalPlan)
366+ val newCost = costEvaluator.evaluateCost(newPhysicalPlan)
367+ if (newCost < origCost ||
368+ (newCost == origCost && currentPhysicalPlan != newPhysicalPlan)) {
369+ lazy val plans = sideBySide(
370+ currentPhysicalPlan.treeString, newPhysicalPlan.treeString).mkString(" \n " )
371+ logOnLevel(log " Plan changed: \n ${MDC (QUERY_PLAN , plans)}" )
372+ cleanUpTempTags(newPhysicalPlan)
373+ currentPhysicalPlan = newPhysicalPlan
374+ currentLogicalPlan = newLogicalPlan
375+ stagesToReplace = Seq .empty[QueryStageExec ]
376+ }
374377 }
375378 }
376379 // Now that some stages have finished, we can try creating new stages.
377- result = createQueryStages(currentPhysicalPlan)
380+ result = createQueryStages(fun, currentPhysicalPlan, firstRun = false )
378381 }
379-
380- // Run the final plan when there's no more unfinished stages.
381- currentPhysicalPlan = applyPhysicalRules(
382- optimizeQueryStage(result.newPlan, isFinalStage = true ),
383- postStageCreationRules(supportsColumnar),
384- Some ((planChangeLogger, " AQE Post Stage Creation" )))
385- _isFinalPlan = true
386- executionId.foreach(onUpdatePlan(_, Seq (currentPhysicalPlan)))
387- currentPhysicalPlan
388382 }
383+ _isFinalPlan = true
384+ finalPlanUpdate
385+ // Dereference the result so it can be GCed. After this resultStage.isMaterialized will return
386+ // false, which is expected. If we want to collect result again, we should invoke
387+ // `withFinalPlanUpdate` and pass another result handler and we will create a new result stage.
388+ currentPhysicalPlan.asInstanceOf [ResultQueryStageExec ].resultOption.getAndUpdate(_ => None )
389+ .get.asInstanceOf [T ]
389390 }
390391
391392 // Use a lazy val to avoid this being called more than once.
392393 @ transient private lazy val finalPlanUpdate : Unit = {
393- // Subqueries that don't belong to any query stage of the main query will execute after the
394- // last UI update in `getFinalPhysicalPlan`, so we need to update UI here again to make sure
395- // the newly generated nodes of those subqueries are updated.
396- if (shouldUpdatePlan && currentPhysicalPlan.exists(_.subqueries.nonEmpty)) {
394+ // Do final plan update after result stage has materialized.
395+ if (shouldUpdatePlan) {
397396 getExecutionId.foreach(onUpdatePlan(_, Seq .empty))
398397 }
399398 logOnLevel(log " Final plan: \n ${MDC (QUERY_PLAN , currentPhysicalPlan)}" )
@@ -426,13 +425,6 @@ case class AdaptiveSparkPlanExec(
426425 }
427426 }
428427
429- private def withFinalPlanUpdate [T ](fun : SparkPlan => T ): T = {
430- val plan = getFinalPhysicalPlan()
431- val result = fun(plan)
432- finalPlanUpdate
433- result
434- }
435-
436428 protected override def stringArgs : Iterator [Any ] = Iterator (s " isFinalPlan= $isFinalPlan" )
437429
438430 override def generateTreeString (
@@ -521,6 +513,66 @@ case class AdaptiveSparkPlanExec(
521513 this .inputPlan == obj.asInstanceOf [AdaptiveSparkPlanExec ].inputPlan
522514 }
523515
516+ /**
517+ * We separate stage creation of result and non-result stages because there are several edge cases
518+ * of result stage creation:
519+ * - existing ResultQueryStage created in previous `withFinalPlanUpdate`.
520+ * - the root node is a non-result query stage and we have to create query result stage on top of
521+ * it.
522+ * - we create a non-result query stage as root node and the stage is immediately materialized
523+ * due to stage resue, therefore we have to create a result stage right after.
524+ *
525+ * This method wraps around `createNonResultQueryStages`, the general logic is:
526+ * - Early return if ResultQueryStageExec already created before.
527+ * - Create non result query stage if possible.
528+ * - Try to create result query stage when there is no new non-result query stage created and all
529+ * stages are materialized.
530+ */
531+ private def createQueryStages (
532+ resultHandler : SparkPlan => Any ,
533+ plan : SparkPlan ,
534+ firstRun : Boolean ): CreateStageResult = {
535+ plan match {
536+ // 1. ResultQueryStageExec is already created, no need to create non-result stages
537+ case resultStage @ ResultQueryStageExec (_, optimizedPlan, _) =>
538+ assertStageNotFailed(resultStage)
539+ if (firstRun) {
540+ // There is already an existing ResultQueryStage created in previous `withFinalPlanUpdate`
541+ // e.g, when we do `df.collect` multiple times. Here we create a new result stage to
542+ // execute it again, as the handler function can be different.
543+ val newResultStage = ResultQueryStageExec (currentStageId, optimizedPlan, resultHandler)
544+ currentStageId += 1
545+ setLogicalLinkForNewQueryStage(newResultStage, optimizedPlan)
546+ CreateStageResult (newPlan = newResultStage,
547+ allChildStagesMaterialized = false ,
548+ newStages = Seq (newResultStage))
549+ } else {
550+ // We will hit this branch after we've created result query stage in the AQE loop, we
551+ // should do nothing.
552+ CreateStageResult (newPlan = resultStage,
553+ allChildStagesMaterialized = resultStage.isMaterialized,
554+ newStages = Seq .empty)
555+ }
556+ case _ =>
557+ // 2. Create non result query stage
558+ val result = createNonResultQueryStages(plan)
559+ var allNewStages = result.newStages
560+ var newPlan = result.newPlan
561+ var allChildStagesMaterialized = result.allChildStagesMaterialized
562+ // 3. Create result stage
563+ if (allNewStages.isEmpty && allChildStagesMaterialized) {
564+ val resultStage = newResultQueryStage(resultHandler, newPlan)
565+ newPlan = resultStage
566+ allChildStagesMaterialized = false
567+ allNewStages :+= resultStage
568+ }
569+ CreateStageResult (
570+ newPlan = newPlan,
571+ allChildStagesMaterialized = allChildStagesMaterialized,
572+ newStages = allNewStages)
573+ }
574+ }
575+
524576 /**
525577 * This method is called recursively to traverse the plan tree bottom-up and create a new query
526578 * stage or try reusing an existing stage if the current node is an [[Exchange ]] node and all of
@@ -531,7 +583,7 @@ case class AdaptiveSparkPlanExec(
531583 * 2) Whether the child query stages (if any) of the current node have all been materialized.
532584 * 3) A list of the new query stages that have been created.
533585 */
534- private def createQueryStages (plan : SparkPlan ): CreateStageResult = plan match {
586+ private def createNonResultQueryStages (plan : SparkPlan ): CreateStageResult = plan match {
535587 case e : Exchange =>
536588 // First have a quick check in the `stageCache` without having to traverse down the node.
537589 context.stageCache.get(e.canonicalized) match {
@@ -544,7 +596,7 @@ case class AdaptiveSparkPlanExec(
544596 newStages = if (isMaterialized) Seq .empty else Seq (stage))
545597
546598 case _ =>
547- val result = createQueryStages (e.child)
599+ val result = createNonResultQueryStages (e.child)
548600 val newPlan = e.withNewChildren(Seq (result.newPlan)).asInstanceOf [Exchange ]
549601 // Create a query stage only when all the child query stages are ready.
550602 if (result.allChildStagesMaterialized) {
@@ -588,14 +640,28 @@ case class AdaptiveSparkPlanExec(
588640 if (plan.children.isEmpty) {
589641 CreateStageResult (newPlan = plan, allChildStagesMaterialized = true , newStages = Seq .empty)
590642 } else {
591- val results = plan.children.map(createQueryStages )
643+ val results = plan.children.map(createNonResultQueryStages )
592644 CreateStageResult (
593645 newPlan = plan.withNewChildren(results.map(_.newPlan)),
594646 allChildStagesMaterialized = results.forall(_.allChildStagesMaterialized),
595647 newStages = results.flatMap(_.newStages))
596648 }
597649 }
598650
651+ private def newResultQueryStage (
652+ resultHandler : SparkPlan => Any ,
653+ plan : SparkPlan ): ResultQueryStageExec = {
654+ // Run the final plan when there's no more unfinished stages.
655+ val optimizedRootPlan = applyPhysicalRules(
656+ optimizeQueryStage(plan, isFinalStage = true ),
657+ postStageCreationRules(supportsColumnar),
658+ Some ((planChangeLogger, " AQE Post Stage Creation" )))
659+ val resultStage = ResultQueryStageExec (currentStageId, optimizedRootPlan, resultHandler)
660+ currentStageId += 1
661+ setLogicalLinkForNewQueryStage(resultStage, plan)
662+ resultStage
663+ }
664+
599665 private def newQueryStage (plan : SparkPlan ): QueryStageExec = {
600666 val queryStage = plan match {
601667 case e : Exchange =>
0 commit comments