@@ -174,31 +174,72 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
 
-  val submittedStageInfos = new HashSet[StageInfo]
-  val successfulStages = new HashSet[Int]
-  val failedStages = new ArrayBuffer[Int]
-  val stageByOrderOfExecution = new ArrayBuffer[Int]
-  val endedTasks = new HashSet[Long]
-  val sparkListener = new SparkListener() {
+  /**
+   * Listener which records some information to verify in UTs. Getter-kind methods in this class
+   * ensure the value is returned only after there is no event left to process, and that the
+   * returned value is immutable: this prevents odd results caused by race conditions.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
     override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-      submittedStageInfos += stageSubmitted.stageInfo
+      _submittedStageInfos += stageSubmitted.stageInfo
     }
 
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
-      stageByOrderOfExecution += stageInfo.stageId
+      _stageByOrderOfExecution += stageInfo.stageId
       if (stageInfo.failureReason.isEmpty) {
-        successfulStages += stageInfo.stageId
+        _successfulStages += stageInfo.stageId
       } else {
-        failedStages += stageInfo.stageId
+        _failedStages += stageInfo.stageId
       }
     }
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      endedTasks += taskEnd.taskInfo.taskId
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = withWaitingListenerUntilEmpty {
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = withWaitingListenerUntilEmpty {
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = withWaitingListenerUntilEmpty {
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = withWaitingListenerUntilEmpty {
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = withWaitingListenerUntilEmpty {
+      _endedTasks.toSet
+    }
+
+    def clear(): Unit = {
+      _submittedStageInfos.clear()
+      _successfulStages.clear()
+      _failedStages.clear()
+      _stageByOrderOfExecution.clear()
+      _endedTasks.clear()
+    }
+
+    private def withWaitingListenerUntilEmpty[T](fn: => T): T = {
+      sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+      fn
     }
   }
 
+  val sparkListener = new EventInfoRecordingListener()
+
   var mapOutputTracker: MapOutputTrackerMaster = null
   var broadcastManager: BroadcastManager = null
   var securityMgr: SecurityManager = null
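
For context, the reading pattern this listener enables: every getter drains the listener bus first and then returns an immutable snapshot, so individual tests no longer call waitUntilEmpty themselves. A minimal sketch of the intended usage (the job below is illustrative, not taken from the suite):

    sc.parallelize(1 to 10).count()                     // run any job to completion
    assert(sparkListener.failedStages.isEmpty)          // getter waits for the bus, then snapshots a List[Int]
    assert(sparkListener.successfulStages.contains(0))  // immutable Set[Int], safe from concurrent mutation
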
@@ -247,10 +288,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
   private def init(testConf: SparkConf): Unit = {
     sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf)
-    submittedStageInfos.clear()
-    successfulStages.clear()
-    failedStages.clear()
-    endedTasks.clear()
+    sparkListener.clear()
     failure = null
     sc.addSparkListener(sparkListener)
     taskSets.clear()
@@ -373,9 +411,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   }
 
   test("[SPARK-3353] parent stage should have lower stage id") {
-    stageByOrderOfExecution.clear()
     sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
     assert(stageByOrderOfExecution.length === 2)
     assert(stageByOrderOfExecution(0) < stageByOrderOfExecution(1))
   }
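
Note the pattern in the refactored test above: the getter's result is captured once into a local val, so both assertions read the same immutable List rather than a shared ArrayBuffer the listener thread might still be appending to. Sketched under the assumption that a two-stage job was just run:

    val stages = sparkListener.stageByOrderOfExecution  // bus drained, snapshot taken
    assert(stages === stages.sorted)                    // parent stages carry lower ids (SPARK-3353)
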
@@ -618,19 +655,15 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(unserializableRdd, Array(0))
     assert(failure.getMessage.startsWith(
       "Job aborted due to stage failure: Task not serializable:"))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
   test("trivial job failure") {
     submit(new MyRDD(sc, 1, Nil), Array(0))
     failed(taskSets(0), "some failure")
     assert(failure.getMessage === "Job aborted due to stage failure: some failure")
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
@@ -639,9 +672,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     val jobId = submit(rdd, Array(0))
     cancel(jobId)
     assert(failure.getMessage === s"Job $jobId cancelled ")
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
@@ -699,9 +730,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(results === Map(0 -> 42))
     assertDataStructuresEmpty()
 
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.isEmpty)
-    assert(successfulStages.contains(0))
+    assert(sparkListener.failedStages.isEmpty)
+    assert(sparkListener.successfulStages.contains(0))
   }
 
   test("run trivial shuffle") {
@@ -1084,17 +1114,15 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(1))
+    assert(sparkListener.failedStages.contains(1))
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
     runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
       null))
     // The SparkListener should not receive redundant failure events.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.size == 1)
+    assert(sparkListener.failedStages.size === 1)
   }
 
   test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
@@ -1141,8 +1169,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(0).tasks(1),
       TaskKilled("test"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
     assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
 
     scheduler.resubmitFailedStages()
@@ -1196,11 +1223,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
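Counting attempts via submittedStageInfos works because each resubmission posts a fresh StageInfo carrying the same stageId, so counting by id counts attempts. A simplified, hypothetical illustration (Info is a stand-in for StageInfo):

    case class Info(stageId: Int, attempt: Int)              // stand-in for StageInfo
    val submitted = Set(Info(0, 0), Info(0, 1), Info(1, 0))
    assert(submitted.count(_.stageId == 0) == 2)             // two attempts of the map stage
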
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     complete(taskSets(0), Seq(
@@ -1217,12 +1243,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(1))
+    assert(sparkListener.failedStages.contains(1))
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -1239,7 +1263,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
     // desired event and our check.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
 
   }
@@ -1257,14 +1280,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(reduceRdd, Array(0, 1))
 
     def countSubmittedReduceStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == 1)
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
     }
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == 0)
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // Complete the map stage.
@@ -1273,7 +1295,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       (Success, makeMapStatus("hostB", 2))))
 
     // The reduce stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedReduceStageAttempts() === 1)
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
@@ -1288,7 +1309,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Because the map stage finished, another attempt for the reduce stage should have been
     // submitted, resulting in 2 total attempts for each of the map and the reduce stage.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
     assert(countSubmittedReduceStageAttempts() === 2)
 
@@ -1318,10 +1338,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(1), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     // verify stage exists
     assert(scheduler.stageIdToStage.contains(0))
-    assert(endedTasks.size == 2)
+    assert(sparkListener.endedTasks.size === 2)
 
     // finish other 2 tasks
     runEvent(makeCompletionEvent(
@@ -1330,8 +1349,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(3)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 4)
+    assert(sparkListener.endedTasks.size === 4)
 
     // verify the stage is done
     assert(!scheduler.stageIdToStage.contains(0))
@@ -1341,15 +1359,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(5)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 5)
+    assert(sparkListener.endedTasks.size === 5)
 
     // make sure non successful tasks also send out event
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), UnknownReason, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(6)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 6)
+    assert(sparkListener.endedTasks.size === 6)
   }
 
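A side note on the task-count assertions above: _endedTasks is a HashSet[Long], so a completion event that reuses an already-recorded task id would not grow the set. That is why the test fabricates fresh ids (5, then 6) via createFakeTaskInfoWithId to make each extra completion observable. Illustrative:

    import scala.collection.mutable.HashSet
    val seen = new HashSet[Long]
    seen += 3L; seen += 3L    // duplicate id: the set still holds one element
    assert(seen.size == 1)
    seen += 5L                // a fresh id, as createFakeTaskInfoWithId(5) supplies
    assert(seen.size == 2)
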
   test("ignore late map task completions") {
@@ -1422,8 +1438,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Listener bus should get told about the map stage failing, but not the reduce stage
     // (since the reduce stage hasn't been started yet).
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.toSet === Set(0))
+    assert(sparkListener.failedStages.toSet === Set(0))
 
     assertDataStructuresEmpty()
   }
@@ -1666,9 +1681,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(cancelledStages.toSet === Set(0, 2))
 
     // Make sure the listeners got told about both failed stages.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(successfulStages.isEmpty)
-    assert(failedStages.toSet === Set(0, 2))
+    assert(sparkListener.successfulStages.isEmpty)
+    assert(sparkListener.failedStages.toSet === Set(0, 2))
 
     assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
     assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
@@ -2642,20 +2656,18 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // The first map task fails with TaskKilled.
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(0),
       TaskKilled("test"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
 
     // The second map task fails with TaskKilled.
     runEvent(makeCompletionEvent(
@@ -2665,7 +2677,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -2679,24 +2690,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // The first map task fails with TaskKilled.
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(0),
       TaskKilled("test"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -2709,7 +2717,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // The second map task failure doesn't trigger stage retry.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
   }
 