@@ -101,9 +101,16 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
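+    // Records each submitted StageInfo so tests can count submission attempts per stage.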
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +157,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +555,136 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }
 
+  /** This tests the case where another FetchFailed arrives while the map stage is being
+   * re-run. */
559+ test(" late fetch failures don't cause multiple concurrent attempts for the same map stage" ) {
560+ val shuffleMapRdd = new MyRDD (sc, 2 , Nil )
561+ val shuffleDep = new ShuffleDependency (shuffleMapRdd, null )
562+ val shuffleId = shuffleDep.shuffleId
563+ val reduceRdd = new MyRDD (sc, 2 , List (shuffleDep))
564+ submit(reduceRdd, Array (0 , 1 ))
565+
566+ val mapStageId = 0
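+    // Counts how many attempts of the map stage the listener has seen so far.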
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
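+    // Complete both map tasks; the finished map stage triggers submission of the reduce stage.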
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // Note: the actual ResubmitFailedStages event may be posted at any time during this test and
+    // shouldn't affect anything; posting it explicitly just guarantees that it is processed
+    // between the late fetch failure above and the attempt-count check.
+  }
+
+  /** This tests the case where a late FetchFailed arrives after the map stage has finished
+   * being retried and a new reduce stage has started running.
+   */
622+ test(" extremely late fetch failures don't cause multiple concurrent attempts for the same stage" ) {
623+ val shuffleMapRdd = new MyRDD (sc, 2 , Nil )
624+ val shuffleDep = new ShuffleDependency (shuffleMapRdd, null )
625+ val shuffleId = shuffleDep.shuffleId
626+ val reduceRdd = new MyRDD (sc, 2 , List (shuffleDep))
627+ submit(reduceRdd, Array (0 , 1 ))
628+
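+    // Count submitted attempts for the reduce stage (id 1) and the map stage (id 0).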
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+
+    // The reduce stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and reduce stages.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not cause the map stage to be resubmitted.
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)