138 changes: 71 additions & 67 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -174,31 +174,66 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000

val submittedStageInfos = new HashSet[StageInfo]
val successfulStages = new HashSet[Int]
val failedStages = new ArrayBuffer[Int]
val stageByOrderOfExecution = new ArrayBuffer[Int]
val endedTasks = new HashSet[Long]
val sparkListener = new SparkListener() {
/**
 * Listener which records information to verify in unit tests. Getter methods in this class
 * only return a value after all pending listener events have been processed, and the
 * returned value is immutable, which prevents flaky results caused by race conditions.
 */
class EventInfoRecordingListener extends SparkListener {
private val _submittedStageInfos = new HashSet[StageInfo]
private val _successfulStages = new HashSet[Int]
private val _failedStages = new ArrayBuffer[Int]
private val _stageByOrderOfExecution = new ArrayBuffer[Int]
private val _endedTasks = new HashSet[Long]

override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
submittedStageInfos += stageSubmitted.stageInfo
_submittedStageInfos += stageSubmitted.stageInfo
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
stageByOrderOfExecution += stageInfo.stageId
_stageByOrderOfExecution += stageInfo.stageId
if (stageInfo.failureReason.isEmpty) {
successfulStages += stageInfo.stageId
_successfulStages += stageInfo.stageId
} else {
failedStages += stageInfo.stageId
_failedStages += stageInfo.stageId
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
endedTasks += taskEnd.taskInfo.taskId
_endedTasks += taskEnd.taskInfo.taskId
}

def submittedStageInfos: Set[StageInfo] = {
waitForListeners()
_submittedStageInfos.toSet
}

def successfulStages: Set[Int] = {
waitForListeners()
_successfulStages.toSet
}

def failedStages: List[Int] = {
waitForListeners()
_failedStages.toList
}

def stageByOrderOfExecution: List[Int] = {
waitForListeners()
_stageByOrderOfExecution.toList
}

def endedTasks: Set[Long] = {
waitForListeners()
_endedTasks.toSet
}

private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
}

var sparkListener: EventInfoRecordingListener = null

var mapOutputTracker: MapOutputTrackerMaster = null
var broadcastManager: BroadcastManager = null
var securityMgr: SecurityManager = null
@@ -247,10 +282,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

private def init(testConf: SparkConf): Unit = {
sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf)
submittedStageInfos.clear()
successfulStages.clear()
failedStages.clear()
endedTasks.clear()
sparkListener = new EventInfoRecordingListener
failure = null
sc.addSparkListener(sparkListener)
taskSets.clear()
@@ -373,9 +405,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}

test("[SPARK-3353] parent stage should have lower stage id") {
stageByOrderOfExecution.clear()
sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
assert(stageByOrderOfExecution.length === 2)
assert(stageByOrderOfExecution(0) < stageByOrderOfExecution(1))
}
@@ -618,19 +649,15 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(unserializableRdd, Array(0))
assert(failure.getMessage.startsWith(
"Job aborted due to stage failure: Task not serializable:"))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.contains(0))
assert(failedStages.size === 1)
assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}

test("trivial job failure") {
submit(new MyRDD(sc, 1, Nil), Array(0))
failed(taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted due to stage failure: some failure")
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.contains(0))
assert(failedStages.size === 1)
assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}

@@ -639,9 +666,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val jobId = submit(rdd, Array(0))
cancel(jobId)
assert(failure.getMessage === s"Job $jobId cancelled ")
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.contains(0))
assert(failedStages.size === 1)
assert(sparkListener.failedStages === Seq(0))
assertDataStructuresEmpty()
}

@@ -699,9 +724,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(results === Map(0 -> 42))
assertDataStructuresEmpty()

sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.isEmpty)
assert(successfulStages.contains(0))
assert(sparkListener.failedStages.isEmpty)
assert(sparkListener.successfulStages.contains(0))
}

test("run trivial shuffle") {
@@ -1084,17 +1108,15 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.contains(1))
assert(sparkListener.failedStages.contains(1))

// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(makeCompletionEvent(
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
null))
// The SparkListener should not receive redundant failure events.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.size == 1)
assert(sparkListener.failedStages.size === 1)
}

test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") {
@@ -1141,7 +1163,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(0).tasks(1),
TaskKilled("test"),
null))
assert(failedStages === Seq(0))
assert(sparkListener.failedStages === Seq(0))
Contributor Author:

This is the actual place of the bug. The code is newly added ("newly" meaning it was added after the sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) pattern was applied elsewhere) and doesn't follow the existing pattern. That is easily missed, and here we can enforce it by disallowing direct access to the variables.

Contributor Author:

There are some other places accessing failedStages without waiting; search for assert(failedStages === Seq(0)).
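
As an illustration only (not part of the patch), here is a minimal sketch of the getter pattern described above. The class name FailedStageRecorder is hypothetical; the sketch assumes a SparkContext is available and that the caller sits under the org.apache.spark package so the private[spark] listenerBus is reachable, as is the case for this suite.

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

class FailedStageRecorder(sc: SparkContext) extends SparkListener {
  private val _failedStages = new ArrayBuffer[Int]

  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    if (stageCompleted.stageInfo.failureReason.isDefined) {
      _failedStages += stageCompleted.stageInfo.stageId
    }
  }

  // Racy: an assertion can read this before the listener bus has delivered
  // the completion event, so it may observe a stale (often empty) buffer.
  def failedStagesUnsafe: ArrayBuffer[Int] = _failedStages

  // Safe: drain the listener bus first, then return an immutable snapshot.
  def failedStages: List[Int] = {
    sc.listenerBus.waitUntilEmpty(10000)
    _failedStages.toList
  }
}

A test would register the listener with sc.addSparkListener(...) and assert only through the waiting getter, which is what EventInfoRecordingListener enforces by keeping its buffers private.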

assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1)))

scheduler.resubmitFailedStages()
@@ -1195,11 +1217,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
submittedStageInfos.count(_.stageId == mapStageId)
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

complete(taskSets(0), Seq(
@@ -1216,12 +1237,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.contains(1))
assert(sparkListener.failedStages.contains(1))

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -1238,7 +1257,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
// shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
// desired event and our check.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)

}
@@ -1256,14 +1274,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
submit(reduceRdd, Array(0, 1))

def countSubmittedReduceStageAttempts(): Int = {
submittedStageInfos.count(_.stageId == 1)
sparkListener.submittedStageInfos.count(_.stageId == 1)
}
def countSubmittedMapStageAttempts(): Int = {
submittedStageInfos.count(_.stageId == 0)
sparkListener.submittedStageInfos.count(_.stageId == 0)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// Complete the map stage.
@@ -1272,7 +1289,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
(Success, makeMapStatus("hostB", 2))))

// The reduce stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedReduceStageAttempts() === 1)

// The first result task fails, with a fetch failure for the output from the first mapper.
@@ -1287,7 +1303,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

// Because the map stage finished, another attempt for the reduce stage should have been
// submitted, resulting in 2 total attempts for each the map and the reduce stage.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
assert(countSubmittedReduceStageAttempts() === 2)

@@ -1317,10 +1332,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(1), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// verify stage exists
assert(scheduler.stageIdToStage.contains(0))
assert(endedTasks.size == 2)
assert(sparkListener.endedTasks.size === 2)

// finish other 2 tasks
runEvent(makeCompletionEvent(
@@ -1329,8 +1343,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(3)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(endedTasks.size == 4)
assert(sparkListener.endedTasks.size === 4)

// verify the stage is done
assert(!scheduler.stageIdToStage.contains(0))
@@ -1340,15 +1353,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(5)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(endedTasks.size == 5)
assert(sparkListener.endedTasks.size === 5)

// make sure unsuccessful tasks also send out events
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), UnknownReason, 42,
Seq.empty, Array.empty, createFakeTaskInfoWithId(6)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(endedTasks.size == 6)
assert(sparkListener.endedTasks.size === 6)
}

test("ignore late map task completions") {
@@ -1421,8 +1432,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

// Listener bus should get told about the map stage failing, but not the reduce stage
// (since the reduce stage hasn't been started yet).
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(failedStages.toSet === Set(0))
assert(sparkListener.failedStages.toSet === Set(0))

assertDataStructuresEmpty()
}
@@ -1665,9 +1675,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assert(cancelledStages.toSet === Set(0, 2))

// Make sure the listeners got told about both failed stages.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(successfulStages.isEmpty)
assert(failedStages.toSet === Set(0, 2))
assert(sparkListener.successfulStages.isEmpty)
assert(sparkListener.failedStages.toSet === Set(0, 2))

assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
@@ -2641,19 +2650,18 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
submittedStageInfos.count(_.stageId == mapStageId)
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// The first map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(0),
TaskKilled("test"),
null))
assert(failedStages === Seq(0))
assert(sparkListener.failedStages === Seq(0))

// The second map task fails with TaskKilled.
runEvent(makeCompletionEvent(
Expand All @@ -2663,7 +2671,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -2677,23 +2684,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
submittedStageInfos.count(_.stageId == mapStageId)
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// The first map task fails with TaskKilled.
runEvent(makeCompletionEvent(
taskSets(0).tasks(0),
TaskKilled("test"),
null))
assert(failedStages === Seq(0))
assert(sparkListener.failedStages === Seq(0))

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)
@@ -2706,7 +2711,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi

// The second map task failure doesn't trigger stage retry.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
}
