diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a41b059fa7dec..a26546a5a7c3f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -110,7 +110,6 @@ private[spark] class TaskSetManager(
   // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
   // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
   // state in order to continue to track and account for the running tasks.
-  // TODO: We should kill any running task attempts when the task set manager becomes a zombie.
   private[scheduler] var isZombie = false
 
   // Set of pending tasks for each executor. These collections are actually
@@ -768,6 +767,19 @@ private[spark] class TaskSetManager(
       s" executor ${info.executorId}): ${reason.toErrorString}"
     val failureException: Option[Throwable] = reason match {
       case fetchFailed: FetchFailed =>
+        if (!isZombie) {
+          // Only on the first occurrence of a fetch failure, kill all the other
+          // running task attempts in the task set.
+          for (i <- 0 until numTasks if i != index) {
+            for (attemptInfo <- taskAttempts(i) if attemptInfo.running) {
+              sched.backend.killTask(
+                attemptInfo.taskId,
+                attemptInfo.executorId,
+                interruptThread = true,
+                reason = "another task in the task set hit a fetch failure")
+            }
+          }
+        }
         logWarning(failureReason)
         if (!successful(index)) {
           successful(index) = true
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 9ca6b8b0fe635..782e8ec7313ae 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -417,6 +417,53 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     }
   }
 
+  test("Running tasks should be killed after first fetch failure") {
+    val rescheduleDelay = 300L
+    val conf = new SparkConf().
+      set("spark.scheduler.executorTaskBlacklistTime", rescheduleDelay.toString).
+      // don't wait to jump locality levels in this test
+      set("spark.locality.wait", "0")
+
+    val killedTasks = new ArrayBuffer[Long]
+    sc = new SparkContext("local", "test", conf)
+    // Two executors on the same host, and one on a different host.
+    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"),
+      ("exec1.1", "host1"), ("exec2", "host2"))
+    sched.initialize(new FakeSchedulerBackend() {
+      override def killTask(
+          taskId: Long,
+          executorId: String,
+          interruptThread: Boolean,
+          reason: String): Unit = {
+        killedTasks += taskId
+      }
+    })
+    // A task set of 4 tasks with no locality preferences.
+    val taskSet = FakeTask.createTaskSet(4)
+    val clock = new ManualClock
+    clock.advance(1)
+    val manager = new TaskSetManager(sched, taskSet, 4, None, clock)
+
+    val offerResult1 = manager.resourceOffer("exec1", "host1", ANY)
+    assert(offerResult1.isDefined, "Expect resource offer to return a task")
+
+    assert(offerResult1.get.index === 0)
+    assert(offerResult1.get.executorId === "exec1")
+
+    val offerResult2 = manager.resourceOffer("exec2", "host2", ANY)
+    assert(offerResult2.isDefined, "Expect resource offer to return a task")
+
+    assert(offerResult2.get.index === 1)
+    assert(offerResult2.get.executorId === "exec2")
+    // At this point, we have 2 tasks running and 2 pending. The first fetch failure
+    // should cause the other running task attempt to be killed.
+    assert(killedTasks.isEmpty)
+    manager.handleFailedTask(offerResult1.get.taskId, TaskState.FINISHED,
+      FetchFailed(BlockManagerId("exec-host2", "host2", 12345), 0, 0, 0, "ignored"))
+    assert(killedTasks.size === 1)
+    assert(killedTasks(0) === offerResult2.get.taskId)
+  }
+
   test("executors should be blacklisted after task failure, in spite of locality preferences") {
     val rescheduleDelay = 300L
     val conf = new SparkConf().
@@ -1107,6 +1154,13 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
       set(config.BLACKLIST_ENABLED, true)
     sc = new SparkContext("local", "test", conf)
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+    sched.initialize(new FakeSchedulerBackend() {
+      override def killTask(
+          taskId: Long,
+          executorId: String,
+          interruptThread: Boolean,
+          reason: String): Unit = {}
+    })
     val taskSet = FakeTask.createTaskSet(4)
     val tsm = new TaskSetManager(sched, taskSet, 4)
     // we need a spy so we can attach our mock blacklist