@@ -96,6 +96,9 @@ private[spark] class TaskSchedulerImpl(
9696 private [scheduler] val taskIdToTaskSetManager = new ConcurrentHashMap [Long , TaskSetManager ]
9797 val taskIdToExecutorId = new HashMap [Long , String ]
9898
99+ // Partition ids already finished for each stage, keyed by stage id. Protected by `this`
100+ private [scheduler] val stageIdToFinishedPartitions = new HashMap [Int , HashSet [Int ]]
101+
99102 @ volatile private var hasReceivedTask = false
100103 @ volatile private var hasLaunchedTask = false
101104 private val starvationTimer = new Timer (true )
@@ -208,7 +211,7 @@ private[spark] class TaskSchedulerImpl(
208211 taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
209212 stageTaskSets(taskSet.stageAttemptId) = manager
210213 val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
211- ts.taskSet != taskSet && ! ts.isZombie
214+ ts.taskSet != manager. taskSet && ! ts.isZombie
212215 }
213216 if (conflictingTaskSet) {
214217 throw new IllegalStateException (s " more than one active taskSet for stage $stage: " +
@@ -238,7 +241,13 @@ private[spark] class TaskSchedulerImpl(
/**
 * Builds the TaskSetManager for `taskSet`, leaving out every task whose partition is already
 * recorded in `stageIdToFinishedPartitions` for the task set's stage.
 *
 * @param taskSet the task set submitted for a stage attempt
 * @param maxTaskFailures forwarded unchanged to the TaskSetManager constructor
 * @return a TaskSetManager over a copy of `taskSet` that holds only the remaining tasks
 */
private[scheduler] def createTaskSetManager(
    taskSet: TaskSet,
    maxTaskFailures: Int): TaskSetManager = {
  // Look up the partitions already finished for this stage; an empty set is registered
  // in the map when the stage has no entry yet.
  val alreadyDone =
    stageIdToFinishedPartitions.getOrElseUpdate(taskSet.stageId, new HashSet[Int])
  // Keep only the tasks whose partition has not been finished by a previous attempt.
  val remainingTasks = taskSet.tasks.filter(task => !alreadyDone.contains(task.partitionId))
  val remainingTaskSet = new TaskSet(
    remainingTasks,
    taskSet.stageId,
    taskSet.stageAttemptId,
    taskSet.priority,
    taskSet.properties)
  new TaskSetManager(this, remainingTaskSet, maxTaskFailures, blacklistTrackerOpt)
}
243252
244253 override def cancelTasks (stageId : Int , interruptThread : Boolean ): Unit = synchronized {
@@ -297,6 +306,7 @@ private[spark] class TaskSchedulerImpl(
297306 taskSetsForStage -= manager.taskSet.stageAttemptId
298307 if (taskSetsForStage.isEmpty) {
299308 taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
309+ stageIdToFinishedPartitions -= manager.taskSet.stageId
300310 }
301311 }
302312 manager.parent.removeSchedulable(manager)
@@ -837,17 +847,31 @@ private[spark] class TaskSchedulerImpl(
837847 }
838848
839849 /**
840- * Marks the task has completed in all TaskSetManagers for the given stage.
850+ * Marks the task as completed in all TaskSetManagers (active / zombie) for the given stage.
841851 *
842852 * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
843853 * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
844854 * do not also submit those same tasks. That also means that a task completion from an earlier
845855 * attempt can lead to the entire stage getting marked as successful.
856+ * There is also a situation in which the active TaskSetManager for the stage may not have
857+ * been created yet at the time this method is called. This is possible because calling this
858+ * method and creating the active TaskSetManager happen on two different threads,
859+ * "task-result-getter" and "dag-scheduler-event-loop" respectively. In that situation, the
860+ * active TaskSetManager that is created later cannot learn about the finished partitions and
861+ * keeps launching duplicate tasks, which may lead to job failure in some severe cases; see
862+ * SPARK-25250 for details. To avoid the problem, we record the finished partitions for the
863+ * stage here and exclude the already finished tasks when we create the active TaskSetManager
864+ * later, by looking into stageIdToFinishedPartitions. Thus, the active TaskSetManager is
865+ * always notified about the finished partitions, whether or not it has been created at the
866+ * time we call this method.
846867 */
847868 private [scheduler] def markPartitionCompletedInAllTaskSets (
848869 stageId : Int ,
849870 partitionId : Int ,
850871 taskInfo : TaskInfo ) = {
872+ val finishedPartitions =
873+ stageIdToFinishedPartitions.getOrElseUpdate(stageId, new HashSet [Int ])
874+ finishedPartitions += partitionId
851875 taskSetsByStageIdAndAttempt.getOrElse(stageId, Map ()).values.foreach { tsm =>
852876 tsm.markPartitionCompleted(partitionId, taskInfo)
853877 }
0 commit comments