diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2b3423f9a4d40..44a4eadef630c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -174,31 +174,66 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
 
-  val submittedStageInfos = new HashSet[StageInfo]
-  val successfulStages = new HashSet[Int]
-  val failedStages = new ArrayBuffer[Int]
-  val stageByOrderOfExecution = new ArrayBuffer[Int]
-  val endedTasks = new HashSet[Long]
-  val sparkListener = new SparkListener() {
+  /**
+   * A listener that records events emitted during job execution so they can be verified in
+   * unit tests. The getter methods wait until the listener bus has no pending events and then
+   * return an immutable copy of the recorded values, preventing races with in-flight events.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
     override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-      submittedStageInfos += stageSubmitted.stageInfo
+      _submittedStageInfos += stageSubmitted.stageInfo
     }
 
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
-      stageByOrderOfExecution += stageInfo.stageId
+      _stageByOrderOfExecution += stageInfo.stageId
       if (stageInfo.failureReason.isEmpty) {
-        successfulStages += stageInfo.stageId
+        _successfulStages += stageInfo.stageId
       } else {
-        failedStages += stageInfo.stageId
+        _failedStages += stageInfo.stageId
      }
     }
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      endedTasks += taskEnd.taskInfo.taskId
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = {
+      waitForListeners()
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = {
+      waitForListeners()
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = {
+      waitForListeners()
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = {
+      waitForListeners()
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = {
+      waitForListeners()
+      _endedTasks.toSet
     }
+
+    private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
   }
 
+  var sparkListener: EventInfoRecordingListener = null
+
   var mapOutputTracker: MapOutputTrackerMaster = null
   var broadcastManager: BroadcastManager = null
   var securityMgr: SecurityManager = null
@@ -247,10 +282,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
   private def init(testConf: SparkConf): Unit = {
     sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf)
-    submittedStageInfos.clear()
-    successfulStages.clear()
-    failedStages.clear()
-    endedTasks.clear()
+    sparkListener = new EventInfoRecordingListener
     failure = null
     sc.addSparkListener(sparkListener)
     taskSets.clear()
@@ -373,9 +405,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   }
 
   test("[SPARK-3353] parent stage should have lower stage id") {
-    stageByOrderOfExecution.clear()
     sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+    val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
     assert(stageByOrderOfExecution.length === 2)
     assert(stageByOrderOfExecution(0) < stageByOrderOfExecution(1))
   }
@@ -618,9 +649,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(unserializableRdd, Array(0))
     assert(failure.getMessage.startsWith(
       "Job aborted due to stage failure: Task not serializable:"))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
@@ -628,9 +657,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(new MyRDD(sc, 1, Nil), Array(0))
     failed(taskSets(0), "some failure")
     assert(failure.getMessage === "Job aborted due to stage failure: some failure")
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
@@ -639,9 +666,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     val jobId = submit(rdd, Array(0))
     cancel(jobId)
    assert(failure.getMessage === s"Job $jobId cancelled ")
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(0))
-    assert(failedStages.size === 1)
+    assert(sparkListener.failedStages === Seq(0))
     assertDataStructuresEmpty()
   }
 
@@ -699,9 +724,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     assert(results === Map(0 -> 42))
     assertDataStructuresEmpty()
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.isEmpty)
-    assert(successfulStages.contains(0))
+    assert(sparkListener.failedStages.isEmpty)
+    assert(sparkListener.successfulStages.contains(0))
   }
 
   test("run trivial shuffle") {
@@ -1084,8 +1108,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(1))
+    assert(sparkListener.failedStages.contains(1))
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
     runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
       null))
     // The SparkListener should not receive redundant failure events.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.size == 1)
+    assert(sparkListener.failedStages.size === 1)
   }
 
   test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
@@ -1141,7 +1163,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(0).tasks(1),
       TaskKilled("test"),
       null))
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
     assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))
 
     scheduler.resubmitFailedStages()
@@ -1195,11 +1217,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     complete(taskSets(0), Seq(
@@ -1216,12 +1237,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
       null))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.contains(1))
+    assert(sparkListener.failedStages.contains(1))
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -1238,7 +1257,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     // shouldn't effect anything -- our calling it just makes *SURE* it gets called between the
     // desired event and our check.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
   }
 
@@ -1256,14 +1274,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     submit(reduceRdd, Array(0, 1))
 
     def countSubmittedReduceStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == 1)
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
     }
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == 0)
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // Complete the map stage.
@@ -1272,7 +1289,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       (Success, makeMapStatus("hostB", 2))))
 
     // The reduce stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedReduceStageAttempts() === 1)
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
@@ -1287,7 +1303,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Because the map stage finished, another attempt for the reduce stage should have been
     // submitted, resulting in 2 total attempts for each the map and the reduce stage.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
     assert(countSubmittedReduceStageAttempts() === 2)
 
@@ -1317,10 +1332,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(1), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     // verify stage exists
     assert(scheduler.stageIdToStage.contains(0))
-    assert(endedTasks.size == 2)
+    assert(sparkListener.endedTasks.size === 2)
 
     // finish other 2 tasks
     runEvent(makeCompletionEvent(
@@ -1329,8 +1343,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(3)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 4)
+    assert(sparkListener.endedTasks.size === 4)
 
     // verify the stage is done
     assert(!scheduler.stageIdToStage.contains(0))
@@ -1340,15 +1353,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(5)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 5)
+    assert(sparkListener.endedTasks.size === 5)
 
     // make sure non successful tasks also send out event
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), UnknownReason, 42,
       Seq.empty, Array.empty, createFakeTaskInfoWithId(6)))
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(endedTasks.size == 6)
+    assert(sparkListener.endedTasks.size === 6)
   }
 
   test("ignore late map task completions") {
@@ -1421,8 +1432,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Listener bus should get told about the map stage failing, but not the reduce stage
     // (since the reduce stage hasn't been started yet).
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(failedStages.toSet === Set(0))
+    assert(sparkListener.failedStages.toSet === Set(0))
 
     assertDataStructuresEmpty()
   }
@@ -1665,9 +1675,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(cancelledStages.toSet === Set(0, 2))
 
     // Make sure the listeners got told about both failed stages.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
-    assert(successfulStages.isEmpty)
-    assert(failedStages.toSet === Set(0, 2))
+    assert(sparkListener.successfulStages.isEmpty)
+    assert(sparkListener.failedStages.toSet === Set(0, 2))
 
     assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
     assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
@@ -2641,11 +2650,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // The first map task fails with TaskKilled.
@@ -2653,7 +2661,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(0).tasks(0),
       TaskKilled("test"),
       null))
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
 
     // The second map task fails with TaskKilled.
     runEvent(makeCompletionEvent(
@@ -2663,7 +2671,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -2677,11 +2684,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     val mapStageId = 0
     def countSubmittedMapStageAttempts(): Int = {
-      submittedStageInfos.count(_.stageId == mapStageId)
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
     }
 
     // The map stage should have been submitted.
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 1)
 
     // The first map task fails with TaskKilled.
@@ -2689,11 +2695,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskSets(0).tasks(0),
       TaskKilled("test"),
       null))
-    assert(failedStages === Seq(0))
+    assert(sparkListener.failedStages === Seq(0))
 
     // Trigger resubmission of the failed map stage.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
     // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
     assert(countSubmittedMapStageAttempts() === 2)
@@ -2706,7 +2711,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
 
     // The second map task failure doesn't trigger stage retry.
     runEvent(ResubmitFailedStages)
-    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(countSubmittedMapStageAttempts() === 2)
   }