@@ -89,9 +89,11 @@ private[spark] class TaskSchedulerImpl(
   val nextTaskId = new AtomicLong(0)
 
   // Number of tasks running on each executor
-  private val executorIdToTaskCount = new HashMap[String, Int]
+  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
-  def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap
+  def runningTasksByExecutors(): Map[String, Int] = synchronized {
+    executorIdToRunningTaskIds.toMap.mapValues(_.size)
+  }
 
   // The set of executors we have on each host; this is used to compute hostsAlive, which
   // in turn is used to decide when we can attain data locality on a given host
@@ -259,7 +261,7 @@ private[spark] class TaskSchedulerImpl(
             val tid = task.taskId
             taskIdToTaskSetManager(tid) = taskSet
             taskIdToExecutorId(tid) = execId
-            executorIdToTaskCount(execId) += 1
+            executorIdToRunningTaskIds(execId).add(tid)
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
             assert(availableCpus(i) >= 0)
@@ -288,7 +290,7 @@ private[spark] class TaskSchedulerImpl(
     var newExecAvail = false
     for (o <- offers) {
       executorIdToHost(o.executorId) = o.host
-      executorIdToTaskCount.getOrElseUpdate(o.executorId, 0)
+      executorIdToRunningTaskIds.getOrElseUpdate(o.executorId, HashSet[Long]())
       if (!executorsByHost.contains(o.host)) {
         executorsByHost(o.host) = new HashSet[String]()
         executorAdded(o.executorId, o.host)
@@ -339,7 +341,7 @@ private[spark] class TaskSchedulerImpl(
           // We lost this entire executor, so remember that it's gone
           val execId = taskIdToExecutorId(tid)
 
-          if (executorIdToTaskCount.contains(execId)) {
+          if (executorIdToRunningTaskIds.contains(execId)) {
             reason = Some(
               SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
             removeExecutor(execId, reason.get)
@@ -351,9 +353,7 @@ private[spark] class TaskSchedulerImpl(
             if (TaskState.isFinished(state)) {
               taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid).foreach { execId =>
-                if (executorIdToTaskCount.contains(execId)) {
-                  executorIdToTaskCount(execId) -= 1
-                }
+                executorIdToRunningTaskIds.get(execId).foreach { _.remove(tid) }
               }
             }
             if (state == TaskState.FINISHED) {
@@ -477,7 +477,7 @@ private[spark] class TaskSchedulerImpl(
     var failedExecutor: Option[String] = None
 
     synchronized {
-      if (executorIdToTaskCount.contains(executorId)) {
+      if (executorIdToRunningTaskIds.contains(executorId)) {
         val hostPort = executorIdToHost(executorId)
         logExecutorLoss(executorId, hostPort, reason)
         removeExecutor(executorId, reason)
@@ -525,7 +525,12 @@ private[spark] class TaskSchedulerImpl(
    * of any running tasks, since the loss reason defines whether we'll fail those tasks.
    */
   private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
-    executorIdToTaskCount -= executorId
+    executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
+      taskIds.foreach { tid =>
+        taskIdToExecutorId.remove(tid)
+        taskIdToTaskSetManager.remove(tid)
+      }
+    }
 
     val host = executorIdToHost(executorId)
     val execs = executorsByHost.getOrElse(host, new HashSet)
@@ -563,11 +568,11 @@ private[spark] class TaskSchedulerImpl(
   }
 
   def isExecutorAlive(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.contains(execId)
+    executorIdToRunningTaskIds.contains(execId)
   }
 
   def isExecutorBusy(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.getOrElse(execId, -1) > 0
+    executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)
   }
 
   // By default, rack is unknown
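
For readers skimming the diff, here is a minimal standalone sketch of the bookkeeping pattern the patch adopts: track the set of running task IDs per executor so that per-task state can be cleaned up precisely, whether a single task finishes or an entire executor is lost. The object and method names below (RunningTaskBook, launch, finish, loseExecutor) are illustrative only and are not part of TaskSchedulerImpl.

import scala.collection.mutable.{HashMap, HashSet}

// Hypothetical, simplified model of the maps touched by this diff.
object RunningTaskBook {
  // executor -> IDs of tasks currently running on it (mirrors executorIdToRunningTaskIds)
  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
  // task -> executor it runs on (mirrors taskIdToExecutorId)
  private val taskIdToExecutorId = new HashMap[Long, String]

  def launch(execId: String, tid: Long): Unit = synchronized {
    executorIdToRunningTaskIds.getOrElseUpdate(execId, HashSet[Long]()).add(tid)
    taskIdToExecutorId(tid) = execId
  }

  def finish(tid: Long): Unit = synchronized {
    // Drop only this task's ID; the executor entry stays, so it still counts as alive.
    taskIdToExecutorId.remove(tid).foreach { execId =>
      executorIdToRunningTaskIds.get(execId).foreach(_.remove(tid))
    }
  }

  def loseExecutor(execId: String): Unit = synchronized {
    // Dropping the executor also cleans up every task still recorded against it,
    // which is the per-task leak the removeExecutor change addresses.
    executorIdToRunningTaskIds.remove(execId).foreach { taskIds =>
      taskIds.foreach(taskIdToExecutorId.remove)
    }
  }

  // Counts are derived from the ID sets, as in the new runningTasksByExecutors().
  def runningTasksByExecutors(): Map[String, Int] = synchronized {
    executorIdToRunningTaskIds.map { case (id, tids) => id -> tids.size }.toMap
  }
}

Deriving counts from the ID sets keeps runningTasksByExecutors() and isExecutorBusy() consistent with the same per-task records that removeExecutor() uses for cleanup, instead of maintaining a separate counter that can drift.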