@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
7575
7676 // TaskSetManagers are not thread safe, so any access to one should be synchronized
7777 // on this class.
78- val activeTaskSets = new HashMap [String , TaskSetManager ]
79- val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
78+ val stageIdToActiveTaskSet = new HashMap [Int , TaskSetManager ]
8079
81- val taskIdToTaskSetId = new HashMap [Long , String ]
80+ val taskIdToStageId = new HashMap [Long , Int ]
8281 val taskIdToExecutorId = new HashMap [Long , String ]
8382
8483 @ volatile private var hasReceivedTask = false
@@ -163,17 +162,13 @@ private[spark] class TaskSchedulerImpl(
163162 logInfo(" Adding task set " + taskSet.id + " with " + tasks.length + " tasks" )
164163 this .synchronized {
165164 val manager = createTaskSetManager(taskSet, maxTaskFailures)
166- activeTaskSets(taskSet.id) = manager
167- val stage = taskSet.stageId
168- val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [Int , TaskSetManager ])
169- stageTaskSets(taskSet.attempt) = manager
170- val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171- ts.taskSet != taskSet && ! ts.isZombie
172- }
173- if (conflictingTaskSet) {
174- throw new IllegalStateException (s " more than one active taskSet for stage $stage: " +
175- s " ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(" ," )}" )
// NOTE(review): the assignment on the next added line stores `manager` under
// taskSet.stageId *before* the existence check a few lines below. Assuming the usual
// mutable HashMap semantics, the following `stageIdToActiveTaskSet.get(stageId)` then
// always finds an entry (the one just inserted), so the IllegalStateException is
// thrown on every call. The assignment after the check already performs the
// registration; this first one looks like a leftover and should be dropped.
165+ stageIdToActiveTaskSet(taskSet.stageId) = manager
166+ val stageId = taskSet.stageId
// Guard: refuse to register a second TaskSetManager while one is already recorded
// for this stage id.
// NOTE(review): `.map` is used here only for its side effect (the throw); `.foreach`
// would state that intent directly.
167+ stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
168+ throw new IllegalStateException (
169+ s " Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}" )
176170 }
// Record `manager` as the task set manager for this stage id.
171+ stageIdToActiveTaskSet(stageId) = manager
177172 schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
178173
179174 if (! isLocal && ! hasReceivedTask) {
@@ -203,7 +198,7 @@ private[spark] class TaskSchedulerImpl(
203198
204199 override def cancelTasks (stageId : Int , interruptThread : Boolean ): Unit = synchronized {
205200 logInfo(" Cancelling stage " + stageId)
206- activeTaskSets.find(_._2. stageId == stageId).foreach { case (_, tsm) =>
201+ stageIdToActiveTaskSet.get( stageId).map { tsm =>
207202 // There are two possible cases here:
208203 // 1. The task set manager has been created and some tasks have been scheduled.
209204 // In this case, send a kill signal to the executors to kill the task and then abort
@@ -225,13 +220,7 @@ private[spark] class TaskSchedulerImpl(
225220 * cleaned up.
226221 */
227222 def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
228- activeTaskSets -= manager.taskSet.id
229- taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230- taskSetsForStage -= manager.taskSet.attempt
231- if (taskSetsForStage.isEmpty) {
232- taskSetsByStage -= manager.taskSet.stageId
233- }
234- }
223+ stageIdToActiveTaskSet -= manager.stageId
235224 manager.parent.removeSchedulable(manager)
236225 logInfo(" Removed TaskSet %s, whose tasks have all completed, from pool %s"
237226 .format(manager.taskSet.id, manager.parent.name))
@@ -252,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
252241 for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253242 tasks(i) += task
254243 val tid = task.taskId
255- taskIdToTaskSetId (tid) = taskSet.taskSet.id
244+ taskIdToStageId (tid) = taskSet.taskSet.stageId
256245 taskIdToExecutorId(tid) = execId
257246 executorsByHost(host) += execId
258247 availableCpus(i) -= CPUS_PER_TASK
@@ -336,13 +325,13 @@ private[spark] class TaskSchedulerImpl(
336325 failedExecutor = Some (execId)
337326 }
338327 }
339- taskIdToTaskSetId .get(tid) match {
340- case Some (taskSetId ) =>
328+ taskIdToStageId .get(tid) match {
329+ case Some (stageId ) =>
341330 if (TaskState .isFinished(state)) {
342- taskIdToTaskSetId .remove(tid)
331+ taskIdToStageId .remove(tid)
343332 taskIdToExecutorId.remove(tid)
344333 }
345- activeTaskSets .get(taskSetId ).foreach { taskSet =>
334+ stageIdToActiveTaskSet .get(stageId ).foreach { taskSet =>
346335 if (state == TaskState .FINISHED ) {
347336 taskSet.removeRunningTask(tid)
348337 taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -380,8 +369,8 @@ private[spark] class TaskSchedulerImpl(
380369
381370 val metricsWithStageIds : Array [(Long , Int , Int , TaskMetrics )] = synchronized {
382371 taskMetrics.flatMap { case (id, metrics) =>
383- taskIdToTaskSetId .get(id)
384- .flatMap(activeTaskSets .get)
372+ taskIdToStageId .get(id)
373+ .flatMap(stageIdToActiveTaskSet .get)
385374 .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
386375 }
387376 }
@@ -414,9 +403,9 @@ private[spark] class TaskSchedulerImpl(
414403
415404 def error (message : String ) {
416405 synchronized {
417- if (activeTaskSets .nonEmpty) {
406+ if (stageIdToActiveTaskSet .nonEmpty) {
418407 // Have each task set throw a SparkException with the error
419- for ((taskSetId , manager) <- activeTaskSets ) {
408+ for ((_ , manager) <- stageIdToActiveTaskSet ) {
420409 try {
421410 manager.abort(message)
422411 } catch {
0 commit comments