@@ -50,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
5050 * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
5151 * a small number of times before cancelling the whole stage.
5252 *
53+ * Here's a checklist to use when making or reviewing changes to this class:
54+ *
55+ * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to
56+ * include the new structure. This will help to catch memory leaks.
5357 */
5458private [spark]
5559class DAGScheduler (
@@ -111,6 +115,8 @@ class DAGScheduler(
111115 // stray messages to detect.
112116 private val failedEpoch = new HashMap [String , Long ]
113117
118+ private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator
119+
114120 // A closure serializer that we reuse.
115121 // This is only safe because DAGScheduler runs in a single thread.
116122 private val closureSerializer = SparkEnv .get.closureSerializer.newInstance()
@@ -128,8 +134,6 @@ class DAGScheduler(
128134 private [scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop (this )
129135 taskScheduler.setDAGScheduler(this )
130136
131- private val outputCommitCoordinator = env.outputCommitCoordinator
132-
133137 // Called by TaskScheduler to report task's starting.
134138 def taskStarted (task : Task [_], taskInfo : TaskInfo ) {
135139 eventProcessLoop.post(BeginEvent (task, taskInfo))
@@ -641,13 +645,13 @@ class DAGScheduler(
641645 val split = rdd.partitions(job.partitions(0 ))
642646 val taskContext = new TaskContextImpl (job.finalStage.id, job.partitions(0 ), taskAttemptId = 0 ,
643647 attemptNumber = 0 , runningLocally = true )
644- TaskContextHelper .setTaskContext(taskContext)
648+ TaskContext .setTaskContext(taskContext)
645649 try {
646650 val result = job.func(taskContext, rdd.iterator(split, taskContext))
647651 job.listener.taskSucceeded(0 , result)
648652 } finally {
649653 taskContext.markTaskCompleted()
650- TaskContextHelper .unset()
654+ TaskContext .unset()
651655 }
652656 } catch {
653657 case e : Exception =>
@@ -710,9 +714,10 @@ class DAGScheduler(
710714 // cancelling the stages because if the DAG scheduler is stopped, the entire application
711715 // is in the process of getting stopped.
712716 val stageFailedMessage = " Stage cancelled because SparkContext was shut down"
713- runningStages.foreach { stage =>
714- stage.latestInfo.stageFailed(stageFailedMessage)
715- listenerBus.post(SparkListenerStageCompleted (stage.latestInfo))
717+ // The `toArray` here is necessary so that we don't iterate over `runningStages` while
718+ // mutating it.
719+ runningStages.toArray.foreach { stage =>
720+ markStageAsFinished(stage, Some (stageFailedMessage))
716721 }
717722 listenerBus.post(SparkListenerJobEnd (job.jobId, clock.getTimeMillis(), JobFailed (error)))
718723 }
@@ -887,10 +892,9 @@ class DAGScheduler(
887892 new TaskSet (tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
888893 stage.latestInfo.submissionTime = Some (clock.getTimeMillis())
889894 } else {
890- // Because we posted SparkListenerStageSubmitted earlier, we should post
891- // SparkListenerStageCompleted here in case there are no tasks to run.
892- outputCommitCoordinator.stageEnd(stage.id)
893- listenerBus.post(SparkListenerStageCompleted (stage.latestInfo))
895+ // Because we posted SparkListenerStageSubmitted earlier, we should mark
896+ // the stage as completed here in case there are no tasks to run
897+ markStageAsFinished(stage, None )
894898
895899 val debugString = stage match {
896900 case stage : ShuffleMapStage =>
@@ -902,7 +906,6 @@ class DAGScheduler(
902906 s " Stage ${stage} is actually done; (partitions: ${stage.numPartitions}) "
903907 }
904908 logDebug(debugString)
905- runningStages -= stage
906909 }
907910 }
908911
@@ -968,22 +971,6 @@ class DAGScheduler(
968971 }
969972
970973 val stage = stageIdToStage(task.stageId)
971-
972- def markStageAsFinished (stage : Stage , errorMessage : Option [String ] = None ): Unit = {
973- val serviceTime = stage.latestInfo.submissionTime match {
974- case Some (t) => " %.03f" .format((clock.getTimeMillis() - t) / 1000.0 )
975- case _ => " Unknown"
976- }
977- if (errorMessage.isEmpty) {
978- logInfo(" %s (%s) finished in %s s" .format(stage, stage.name, serviceTime))
979- stage.latestInfo.completionTime = Some (clock.getTimeMillis())
980- } else {
981- stage.latestInfo.stageFailed(errorMessage.get)
982- logInfo(" %s (%s) failed in %s s" .format(stage, stage.name, serviceTime))
983- }
984- listenerBus.post(SparkListenerStageCompleted (stage.latestInfo))
985- runningStages -= stage
986- }
987974 event.reason match {
988975 case Success =>
989976 listenerBus.post(SparkListenerTaskEnd (stageId, stage.latestInfo.attemptId, taskType,
@@ -1099,7 +1086,6 @@ class DAGScheduler(
10991086 logInfo(s " Marking $failedStage ( ${failedStage.name}) as failed " +
11001087 s " due to a fetch failure from $mapStage ( ${mapStage.name}) " )
11011088 markStageAsFinished(failedStage, Some (failureMessage))
1102- runningStages -= failedStage
11031089 }
11041090
11051091 if (disallowStageRetryForTest) {
@@ -1215,6 +1201,26 @@ class DAGScheduler(
12151201 submitWaitingStages()
12161202 }
12171203
1204+ /**
1205+ * Marks a stage as finished and removes it from the list of running stages.
1206+ */
1207+ private def markStageAsFinished (stage : Stage , errorMessage : Option [String ] = None ): Unit = {
1208+ val serviceTime = stage.latestInfo.submissionTime match {
1209+ case Some (t) => " %.03f" .format((clock.getTimeMillis() - t) / 1000.0 )
1210+ case _ => " Unknown"
1211+ }
1212+ if (errorMessage.isEmpty) {
1213+ logInfo(" %s (%s) finished in %s s" .format(stage, stage.name, serviceTime))
1214+ stage.latestInfo.completionTime = Some (clock.getTimeMillis())
1215+ } else {
1216+ stage.latestInfo.stageFailed(errorMessage.get)
1217+ logInfo(" %s (%s) failed in %s s" .format(stage, stage.name, serviceTime))
1218+ }
1219+ outputCommitCoordinator.stageEnd(stage.id)
1220+ listenerBus.post(SparkListenerStageCompleted (stage.latestInfo))
1221+ runningStages -= stage
1222+ }
1223+
12181224 /**
12191225 * Aborts all jobs depending on a particular Stage. This is called in response to a task set
12201226 * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -1264,8 +1270,7 @@ class DAGScheduler(
12641270 if (runningStages.contains(stage)) {
12651271 try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
12661272 taskScheduler.cancelTasks(stageId, shouldInterruptThread)
1267- stage.latestInfo.stageFailed(failureReason)
1268- listenerBus.post(SparkListenerStageCompleted (stage.latestInfo))
1273+ markStageAsFinished(stage, Some (failureReason))
12691274 } catch {
12701275 case e : UnsupportedOperationException =>
12711276 logInfo(s " Could not cancel tasks for stage $stageId" , e)
0 commit comments