@@ -75,9 +75,10 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]
+  val activeTaskSets = new HashMap[String, TaskSetManager]
+  val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToStageId = new HashMap[Long, Int]
+  val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
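
The two maps above replace the single stageIdToActiveTaskSet: every attempt of a stage gets its own TaskSetManager, addressable both by the task set's string id and by (stageId, attempt). A minimal standalone sketch of that bookkeeping, using plain mutable HashMaps and made-up stand-ins (SimpleTaskSet, SimpleManager, and an assumed "stageId.attempt" id format), not the actual Spark classes:

    import scala.collection.mutable.HashMap

    // Illustrative stand-ins for TaskSet and TaskSetManager; not the Spark classes.
    case class SimpleTaskSet(stageId: Int, attempt: Int) {
      val id: String = s"$stageId.$attempt"   // assumed "stageId.attempt" id format
    }
    class SimpleManager(val taskSet: SimpleTaskSet, var isZombie: Boolean = false)

    object Bookkeeping {
      // taskSet.id -> manager: every attempt of a stage stays addressable
      val activeTaskSets = new HashMap[String, SimpleManager]
      // stageId -> (attempt -> manager)
      val taskSetsByStage = new HashMap[Int, HashMap[Int, SimpleManager]]

      def register(taskSet: SimpleTaskSet): SimpleManager = {
        val manager = new SimpleManager(taskSet)
        activeTaskSets(taskSet.id) = manager
        val stageTaskSets =
          taskSetsByStage.getOrElseUpdate(taskSet.stageId, new HashMap[Int, SimpleManager])
        stageTaskSets(taskSet.attempt) = manager
        // At most one non-zombie attempt may be active for a stage at any time.
        val conflicting = stageTaskSets.exists { case (_, m) => m.taskSet != taskSet && !m.isZombie }
        if (conflicting) {
          throw new IllegalStateException(
            s"more than one active taskSet for stage ${taskSet.stageId}: " +
              stageTaskSets.values.map(_.taskSet.id).mkString(","))
        }
        manager
      }

      def main(args: Array[String]): Unit = {
        val first = register(SimpleTaskSet(stageId = 3, attempt = 0))
        first.isZombie = true                              // first attempt superseded, e.g. after a fetch failure
        register(SimpleTaskSet(stageId = 3, attempt = 1))  // allowed: the earlier attempt is a zombie
        println(activeTaskSets.keySet)                     // both "3.0" and "3.1" are tracked
      }
    }

Registering a second non-zombie attempt for the same stage would raise the IllegalStateException, mirroring the conflict check added in submitTasks below.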
@@ -162,13 +163,17 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      stageIdToActiveTaskSet(taskSet.stageId) = manager
-      val stageId = taskSet.stageId
-      stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
-        throw new IllegalStateException(
-          s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
+      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.attempt) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
       }
-      stageIdToActiveTaskSet(stageId) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -198,7 +203,7 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    stageIdToActiveTaskSet.get(stageId).map { tsm =>
+    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
       // There are two possible cases here:
       // 1. The task set manager has been created and some tasks have been scheduled.
       //    In this case, send a kill signal to the executors to kill the task and then abort
@@ -220,7 +225,13 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    stageIdToActiveTaskSet -= manager.stageId
+    activeTaskSets -= manager.taskSet.id
+    taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.attempt
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStage -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
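
A quick sketch of the cleanup path in taskSetFinished above, with plain nested HashMaps and strings standing in for the real TaskSetManagers (illustrative stage/attempt values only): removing the last attempt of a stage also drops the stage's entry from the outer map.

    import scala.collection.mutable.HashMap

    object CleanupSketch {
      def main(args: Array[String]): Unit = {
        // stageId -> (attempt -> task set id); strings stand in for TaskSetManagers
        val taskSetsByStage = new HashMap[Int, HashMap[Int, String]]
        taskSetsByStage.getOrElseUpdate(3, new HashMap[Int, String]) += (0 -> "3.0", 1 -> "3.1")

        def finish(stageId: Int, attempt: Int): Unit = {
          taskSetsByStage.get(stageId).foreach { forStage =>
            forStage -= attempt
            if (forStage.isEmpty) {
              taskSetsByStage -= stageId   // last attempt gone: drop the stage entry too
            }
          }
        }

        finish(3, 0)
        println(taskSetsByStage.contains(3))   // true: attempt 1 is still tracked
        finish(3, 1)
        println(taskSetsByStage.contains(3))   // false: stage entry removed with its last attempt
      }
    }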
@@ -241,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToStageId(tid) = taskSet.taskSet.stageId
+            taskIdToTaskSetId(tid) = taskSet.taskSet.id
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -325,13 +336,13 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToStageId.get(tid) match {
-          case Some(stageId) =>
+        taskIdToTaskSetId.get(tid) match {
+          case Some(taskSetId) =>
             if (TaskState.isFinished(state)) {
-              taskIdToStageId.remove(tid)
+              taskIdToTaskSetId.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
+            activeTaskSets.get(taskSetId).foreach { taskSet =>
               if (state == TaskState.FINISHED) {
                 taskSet.removeRunningTask(tid)
                 taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -369,8 +380,8 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToStageId.get(id)
-          .flatMap(stageIdToActiveTaskSet.get)
+        taskIdToTaskSetId.get(id)
+          .flatMap(activeTaskSets.get)
           .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
       }
     }
@@ -403,9 +414,9 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (stageIdToActiveTaskSet.nonEmpty) {
+      if (activeTaskSets.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((_, manager) <- stageIdToActiveTaskSet) {
+        for ((taskSetId, manager) <- activeTaskSets) {
           try {
             manager.abort(message)
           } catch {