@@ -75,7 +75,7 @@ private[spark] class TaskSchedulerImpl(
7575
7676 // TaskSetManagers are not thread safe, so any access to one should be synchronized
7777 // on this class.
78-  val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
78+  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
7979
8080  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
8181  val taskIdToExecutorId = new HashMap[Long, String]
@@ -163,7 +163,8 @@ private[spark] class TaskSchedulerImpl(
163163 this .synchronized {
164164 val manager = createTaskSetManager(taskSet, maxTaskFailures)
165165 val stage = taskSet.stageId
166-      val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
166+      val stageTaskSets =
167+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
167168 stageTaskSets(taskSet.stageAttemptId) = manager
168169 val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
169170        ts.taskSet != taskSet && !ts.isZombie
@@ -201,7 +202,7 @@ private[spark] class TaskSchedulerImpl(
201202
202203  override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
203204    logInfo("Cancelling stage " + stageId)
204-    taskSetsByStage.get(stageId).foreach { attempts =>
205+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
205206 attempts.foreach { case (_, tsm) =>
206207 // There are two possible cases here:
207208 // 1. The task set manager has been created and some tasks have been scheduled.
@@ -225,10 +226,10 @@ private[spark] class TaskSchedulerImpl(
225226 * cleaned up.
226227 */
227228  def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
228-    taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
229+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
229230 taskSetsForStage -= manager.taskSet.stageAttemptId
230231 if (taskSetsForStage.isEmpty) {
231- taskSetsByStage -= manager.taskSet.stageId
232+ taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
232233 }
233234 }
234235 manager.parent.removeSchedulable(manager)
@@ -380,7 +381,7 @@ private[spark] class TaskSchedulerImpl(
380381 taskMetrics.flatMap { case (id, metrics) =>
381382 for {
382383 (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
383-        attempts <- taskSetsByStage.get(stageId)
384+        attempts <- taskSetsByStageIdAndAttempt.get(stageId)
384385 taskSetMgr <- attempts.get(stageAttemptId)
385386 } yield {
386387 (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
@@ -416,10 +417,10 @@ private[spark] class TaskSchedulerImpl(
416417
417418  def error(message: String) {
418419    synchronized {
419-      if (taskSetsByStage.nonEmpty) {
420+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
420421        // Have each task set throw a SparkException with the error
421422        for {
422-          attempts <- taskSetsByStage.values
423+          attempts <- taskSetsByStageIdAndAttempt.values
423424 manager <- attempts.values
424425 } {
425426 try {
@@ -552,7 +553,7 @@ private[spark] class TaskSchedulerImpl(
552553      stageId: Int,
553554      stageAttemptId: Int): Option[TaskSetManager] = {
554555    for {
555-      attempts <- taskSetsByStage.get(stageId)
556+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
556557 manager <- attempts.get(stageAttemptId)
557558 } yield {
558559 manager
0 commit comments