
Commit 9689763

Port of #15986 to master branch.
1 parent: 8b1609b

3 files changed: +121 -36 lines changed


core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 49 additions & 33 deletions
@@ -93,10 +93,12 @@ private[spark] class TaskSchedulerImpl(
   // Incrementing task IDs
   val nextTaskId = new AtomicLong(0)
 
-  // Number of tasks running on each executor
-  private val executorIdToTaskCount = new HashMap[String, Int]
+  // IDs of the tasks running on each executor
+  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
-  def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap
+  def runningTasksByExecutors(): Map[String, Int] = {
+    executorIdToRunningTaskIds.toMap.mapValues(_.size)
+  }
 
   // The set of executors we have on each host; this is used to compute hostsAlive, which
   // in turn is used to decide when we can attain data locality on a given host
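The core change of this commit is visible in the hunk above: the per-executor Int counter becomes a set of running task IDs, and the count is derived on demand. Below is a minimal, self-contained sketch of that bookkeeping, not code from the commit (the object name and main method are invented for illustration):

import scala.collection.mutable.{HashMap, HashSet}

object RunningTaskIdsSketch {
  // Executor ID -> IDs of the tasks currently running on it, as in the diff above.
  private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

  // The count is derived from the set, so it cannot drift out of sync with
  // the tracked task IDs the way an independently incremented counter could.
  def runningTasksByExecutors(): Map[String, Int] =
    executorIdToRunningTaskIds.toMap.mapValues(_.size)

  def main(args: Array[String]): Unit = {
    executorIdToRunningTaskIds("executor0") = HashSet(1L, 2L)
    executorIdToRunningTaskIds("executor1") = HashSet[Long]()
    println(runningTasksByExecutors()) // Map(executor0 -> 2, executor1 -> 0)
  }
}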
@@ -264,7 +266,7 @@ private[spark] class TaskSchedulerImpl(
           val tid = task.taskId
           taskIdToTaskSetManager(tid) = taskSet
           taskIdToExecutorId(tid) = execId
-          executorIdToTaskCount(execId) += 1
+          executorIdToRunningTaskIds(execId).add(tid)
           availableCpus(i) -= CPUS_PER_TASK
           assert(availableCpus(i) >= 0)
           launchedTask = true
@@ -294,11 +296,11 @@
       if (!hostToExecutors.contains(o.host)) {
         hostToExecutors(o.host) = new HashSet[String]()
       }
-      if (!executorIdToTaskCount.contains(o.executorId)) {
+      if (!executorIdToRunningTaskIds.contains(o.executorId)) {
         hostToExecutors(o.host) += o.executorId
         executorAdded(o.executorId, o.host)
         executorIdToHost(o.executorId) = o.host
-        executorIdToTaskCount(o.executorId) = 0
+        executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
         newExecAvail = true
       }
       for (rack <- getRackForHost(o.host)) {
@@ -349,38 +351,34 @@
     var reason: Option[ExecutorLossReason] = None
     synchronized {
       try {
-        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
-          // We lost this entire executor, so remember that it's gone
-          val execId = taskIdToExecutorId(tid)
-
-          if (executorIdToTaskCount.contains(execId)) {
-            reason = Some(
-              SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
-            removeExecutor(execId, reason.get)
-            failedExecutor = Some(execId)
-          }
-        }
         taskIdToTaskSetManager.get(tid) match {
           case Some(taskSet) =>
-            if (TaskState.isFinished(state)) {
-              taskIdToTaskSetManager.remove(tid)
-              taskIdToExecutorId.remove(tid).foreach { execId =>
-                if (executorIdToTaskCount.contains(execId)) {
-                  executorIdToTaskCount(execId) -= 1
-                }
+            if (state == TaskState.LOST) {
+              // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode,
+              // where each executor corresponds to a single task, so mark the executor as failed.
+              val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException(
+                "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)"))
+              if (executorIdToRunningTaskIds.contains(execId)) {
+                reason = Some(
+                  SlaveLost(s"Task $tid was lost, so marking the executor as lost as well."))
+                removeExecutor(execId, reason.get)
+                failedExecutor = Some(execId)
              }
            }
-            if (state == TaskState.FINISHED) {
-              taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+            if (TaskState.isFinished(state)) {
+              cleanupTaskState(tid)
               taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+              if (state == TaskState.FINISHED) {
+                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+              }
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-                "likely the result of receiving duplicate task finished status updates)")
+                "likely the result of receiving duplicate task finished status updates) or its " +
+                "executor has been marked as failed.")
              .format(state, tid))
        }
      } catch {
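The restructured statusUpdate above handles TaskState.LOST inside the Some(taskSet) branch and then funnels every terminal state through TaskState.isFinished, so cleanupTaskState now runs for lost tasks as well. A rough, hypothetical model of that predicate (Spark's real TaskState is a separate Enumeration; this toy version only illustrates the gating logic):

object TaskStateModel extends Enumeration {
  val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value

  private val finishedStates = Set(FINISHED, FAILED, KILLED, LOST)

  // LOST counts as a terminal state, so the restructured statusUpdate reaches
  // cleanupTaskState(tid) for lost tasks too -- the heart of the SPARK-18553 fix.
  def isFinished(state: Value): Boolean = finishedStates.contains(state)
}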
@@ -491,7 +489,7 @@
     var failedExecutor: Option[String] = None
 
     synchronized {
-      if (executorIdToTaskCount.contains(executorId)) {
+      if (executorIdToRunningTaskIds.contains(executorId)) {
         val hostPort = executorIdToHost(executorId)
         logExecutorLoss(executorId, hostPort, reason)
         removeExecutor(executorId, reason)
@@ -533,13 +531,31 @@
     logError(s"Lost executor $executorId on $hostPort: $reason")
   }
 
+  /**
+   * Cleans up the TaskScheduler's state for tracking the given task.
+   */
+  private def cleanupTaskState(tid: Long): Unit = {
+    taskIdToTaskSetManager.remove(tid)
+    taskIdToExecutorId.remove(tid).foreach { executorId =>
+      executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) }
+    }
+  }
+
   /**
    * Remove an executor from all our data structures and mark it as lost. If the executor's loss
    * reason is not yet known, do not yet remove its association with its host nor update the status
    * of any running tasks, since the loss reason defines whether we'll fail those tasks.
   */
   private def removeExecutor(executorId: String, reason: ExecutorLossReason) {
-    executorIdToTaskCount -= executorId
+    // The tasks on the lost executor may not send any more status updates (because the executor
+    // has been lost), so they should be cleaned up here.
+    executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
+      logDebug("Cleaning up TaskScheduler state for tasks " +
+        s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId")
+      // We do not notify the TaskSetManager of the task failures because that will
+      // happen below in the rootPool.executorLost() call.
+      taskIds.foreach(cleanupTaskState)
+    }
 
     val host = executorIdToHost(executorId)
     val execs = hostToExecutors.getOrElse(host, new HashSet)
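A subtlety in the new removeExecutor: executorIdToRunningTaskIds.remove(executorId) detaches the task ID set from the map before the loop runs, so the cleanupTaskState calls it triggers cannot mutate the set being traversed. A hypothetical standalone model of that interplay (the Sketch object and main method are not part of the commit, and taskIdToTaskSetManager is omitted):

import scala.collection.mutable.{HashMap, HashSet}

object ExecutorLossCleanupSketch {
  val taskIdToExecutorId = new HashMap[Long, String]
  val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

  def cleanupTaskState(tid: Long): Unit = {
    taskIdToExecutorId.remove(tid).foreach { executorId =>
      // No-op when the executor's entry has already been removed.
      executorIdToRunningTaskIds.get(executorId).foreach(_.remove(tid))
    }
  }

  def removeExecutor(executorId: String): Unit = {
    // remove() returns the detached set, so iterating it is safe even
    // though cleanupTaskState also touches executorIdToRunningTaskIds.
    executorIdToRunningTaskIds.remove(executorId).foreach { taskIds =>
      taskIds.foreach(cleanupTaskState)
    }
  }

  def main(args: Array[String]): Unit = {
    taskIdToExecutorId ++= Seq(1L -> "executor0", 2L -> "executor0")
    executorIdToRunningTaskIds("executor0") = HashSet(1L, 2L)
    removeExecutor("executor0")
    assert(taskIdToExecutorId.isEmpty) // per-task state is cleaned with the executor
  }
}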
@@ -577,11 +593,11 @@
   }
 
   def isExecutorAlive(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.contains(execId)
+    executorIdToRunningTaskIds.contains(execId)
   }
 
   def isExecutorBusy(execId: String): Boolean = synchronized {
-    executorIdToTaskCount.getOrElse(execId, -1) > 0
+    executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty)
   }
 
   // By default, rack is unknown
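The rewritten isExecutorBusy relies on Option.exists, which returns false for an absent key, so an unknown executor reports "not busy" exactly as the old getOrElse(execId, -1) > 0 did. A small sketch of that equivalence (standalone, not from the commit):

import scala.collection.mutable.{HashMap, HashSet}

object IsExecutorBusySketch {
  val running = new HashMap[String, HashSet[Long]]

  // Option.exists(_.nonEmpty) is false both for an unknown executor (None)
  // and for a known executor with no running tasks (empty set).
  def isExecutorBusy(execId: String): Boolean =
    running.get(execId).exists(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    running("executor0") = HashSet(7L)
    running("executor1") = HashSet[Long]()
    assert(isExecutorBusy("executor0"))
    assert(!isExecutorBusy("executor1"))
    assert(!isExecutorBusy("executor2")) // never registered
  }
}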

core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala

Lines changed: 4 additions & 3 deletions
@@ -433,10 +433,11 @@ class StandaloneDynamicAllocationSuite
     assert(executors.size === 2)
 
     // simulate running a task on the executor
-    val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount)
+    val getMap =
+      PrivateMethod[mutable.HashMap[String, mutable.HashSet[Long]]]('executorIdToRunningTaskIds)
     val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
-    val executorIdToTaskCount = taskScheduler invokePrivate getMap()
-    executorIdToTaskCount(executors.head) = 1
+    val executorIdToRunningTaskIds = taskScheduler invokePrivate getMap()
+    executorIdToRunningTaskIds(executors.head) = mutable.HashSet(1L)
     // kill the busy executor without force; this should fail
     assert(killExecutor(sc, executors.head, force = false).isEmpty)
     apps = getApplications()
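This test reaches the private map through ScalaTest's PrivateMethodTester, the same PrivateMethod/invokePrivate pattern shown above. A minimal self-contained example of that pattern (the Counter class and sketch object are hypothetical):

import org.scalatest.PrivateMethodTester

class Counter {
  private def current: Int = 42
}

object PrivateMethodSketch extends PrivateMethodTester {
  def main(args: Array[String]): Unit = {
    // PrivateMethod captures the method name; invokePrivate calls it reflectively.
    val current = PrivateMethod[Int]('current)
    val value = new Counter invokePrivate current()
    assert(value == 42)
  }
}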

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 68 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import java.nio.ByteBuffer
+
 import scala.collection.mutable.HashMap
 
 import org.mockito.Matchers.{anyInt, anyString, eq => meq}
@@ -648,4 +650,70 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3")))
   }
 
+  test("if an executor is lost then the state for its running tasks is cleaned up (SPARK-18553)") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
+    val attempt1 = FakeTask.createTaskSet(1)
+
+    // submit attempt 1, offer resources, task gets scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // mark executor0 as dead
+    taskScheduler.executorLost("executor0", SlaveLost())
+    assert(!taskScheduler.isExecutorAlive("executor0"))
+    assert(!taskScheduler.hasExecutorsAliveOnHost("host0"))
+    assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty)
+
+    // Check that state associated with the lost task attempt is cleaned up:
+    assert(taskScheduler.taskIdToExecutorId.isEmpty)
+    assert(taskScheduler.taskIdToTaskSetManager.isEmpty)
+    assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty)
+  }
+
+  test("if a task finishes with TaskState.LOST its executor is marked as dead") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+    new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+
+    val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1))
+    val attempt1 = FakeTask.createTaskSet(1)
+
+    // submit attempt 1, offer resources, task gets scheduled
+    taskScheduler.submitTasks(attempt1)
+    val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten
+    assert(1 === taskDescriptions.length)
+
+    // Report the task as failed with TaskState.LOST
+    taskScheduler.statusUpdate(
+      tid = taskDescriptions.head.taskId,
+      state = TaskState.LOST,
+      serializedData = ByteBuffer.allocate(0)
+    )
+
+    // Check that state associated with the lost task attempt is cleaned up:
+    assert(taskScheduler.taskIdToExecutorId.isEmpty)
+    assert(taskScheduler.taskIdToTaskSetManager.isEmpty)
+    assert(taskScheduler.runningTasksByExecutors().get("executor0").isEmpty)
+
+    // Check that the executor has been marked as dead
+    assert(!taskScheduler.isExecutorAlive("executor0"))
+    assert(!taskScheduler.hasExecutorsAliveOnHost("host0"))
+    assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty)
+  }
 }
