From d886ba3840ab06cd3a5d9dea7d47a8e156d5eb72 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Thu, 5 Apr 2018 11:29:01 -0500
Subject: [PATCH 1/4] [SPARK-23816][CORE] Killed tasks should ignore FetchFailures.

SPARK-19276 ensured that FetchFailures do not get swallowed by other
layers of exception handling, but it also meant that a killed task could
look like a fetch failure.  This is particularly a problem with
speculative execution, where we expect to kill tasks as they are reading
shuffle data.  The fix is to ensure that we always check for killed
tasks first.

Added a new unit test which fails before the fix, and ran it 1k times to
check for flakiness.  Full suite of tests on jenkins.
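To make the ordering concrete, here is a standalone sketch (made-up
names -- classify and its parameters are illustration only, not the
Executor code; the real change is the reordering in the diff below):

    class TaskKilledException(val reason: String) extends RuntimeException(reason)

    def classify(run: () => Unit, reasonIfKilled: Option[String], fetchFailed: Boolean): String = {
      try { run(); "FINISHED" }
      catch {
        // killed-task cases must come first ...
        case t: TaskKilledException => s"KILLED: ${t.reason}"
        case _: InterruptedException if reasonIfKilled.isDefined =>
          s"KILLED: ${reasonIfKilled.get}"
        // ... and only then is a pending fetch failure treated as the real cause
        case _: Throwable if fetchFailed => "FETCH_FAILED"
        case t: Throwable => s"FAILED: $t"
      }
    }

    // a speculative task interrupted while reading shuffle data must be
    // reported as killed, not as a fetch failure:
    assert(classify(() => throw new InterruptedException,
      Some("speculation"), fetchFailed = true) == "KILLED: speculation")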
---
 .../org/apache/spark/executor/Executor.scala  | 26 +++----
 .../apache/spark/executor/ExecutorSuite.scala | 76 +++++++++++++++----
 2 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index dcec3ec21b546..c325222b764b8 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -480,6 +480,19 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

       } catch {
+        case t: TaskKilledException =>
+          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
+
+        case _: InterruptedException | NonFatal(_) if
+            task != null && task.reasonIfKilled.isDefined =>
+          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(
+            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
+
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
           if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +507,6 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

-        case t: TaskKilledException =>
-          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
-
-        case _: InterruptedException | NonFatal(_) if
-            task != null && task.reasonIfKilled.isDefined =>
-          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
-          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(
-            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
-
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 105a178f2d94e..627f044ab4166 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -139,7 +139,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure.  The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,8 +173,26 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }

   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test(s"SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  def testFetchFailureHandling(oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
@@ -183,7 +201,13 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
     // the fetch failure as a false positive, and just do normal OOM handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +224,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)

-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }

   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,19 +274,32 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }

   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }

   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    var killingThread: Thread = null
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            ExecutorSuiteHelper.latches.latch1.await()
+            // now we can kill the task
+            executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
@@ -282,8 +312,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
@@ -321,7 +352,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +362,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await()
+          throw new IllegalStateException("impossible")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +393,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }

+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {

From b387552f7c2a546ac7290be6da007678875814d7 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Fri, 6 Apr 2018 09:32:52 -0500
Subject: [PATCH 2/4] cleanup

---
 .../org/apache/spark/executor/ExecutorSuite.scala | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 627f044ab4166..c0b71d456bef9 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -188,7 +188,14 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     assert(failReason.isInstanceOf[TaskKilled])
   }

-  def testFetchFailureHandling(oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
+  /**
+   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
+   * superseded by another error, either an OOM or intentionally killing a task.
+   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+   *            FetchFailure
+   */
+  private def testFetchFailureHandling(
+      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
     // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
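The pattern behind this cleanup commit, as a standalone REPL-style
sketch (assumed names, not the suite code): a bounded await turns a lost
signal into a visible test failure instead of a hung build.

    import java.util.concurrent.{CountDownLatch, TimeUnit}

    val ready = new CountDownLatch(1)
    new Thread(() => ready.countDown(), "signaler").start()
    // await() with no timeout would block forever if the signaling thread
    // died before counting down; the bounded form returns false instead,
    // so the waiter can fail loudly rather than hang.
    val signaled = ready.await(10, TimeUnit.SECONDS)
    assert(signaled, "timed out waiting for the signal")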
@@ -198,8 +205,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

-    // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
-    // the fetch failure as a false positive, and just do normal OOM handling.
+    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
     if (!oom) {
       // we are trying to set up a case where a task is killed after a fetch failure -- this

From b1997a7e9df56d48c28f825cfb30c02fe61de21d Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Sat, 7 Apr 2018 19:25:02 -0500
Subject: [PATCH 3/4] don't wait on latches indefinitely

---
 .../apache/spark/executor/ExecutorSuite.scala | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index c0b71d456bef9..bd1b93c925691 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -181,7 +181,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
   }

-  test(s"SPARK-23816: interrupts are not masked by a FetchFailure") {
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
     // If killing the task causes a fetch failure, we still treat it as a task that was killed,
     // as the fetch failure could easily be caused by interrupting the thread.
     val (failReason, _) = testFetchFailureHandling(false)
@@ -290,19 +290,22 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
-    var killingThread: Thread = null
+    var timedOut = false
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
       if (killTask) {
-        killingThread = new Thread("kill-task") {
+        val killingThread = new Thread("kill-task") {
           override def run(): Unit = {
             // wait to kill the task until it has thrown a fetch failure
-            ExecutorSuiteHelper.latches.latch1.await()
-            // now we can kill the task
-            executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+              // now we can kill the task
+              executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+            } else {
+              timedOut = true
+            }
           }
         }
         killingThread.start()
@@ -310,6 +313,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
+      assert(!timedOut, "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()
@@ -376,8 +380,8 @@ class FetchFailureHidingRDD(
         ExecutorSuiteHelper.latches.latch1.countDown()
         // then wait for another thread in the test to kill the task -- this latch
         // is never actually decremented, we just wait to get killed.
-        ExecutorSuiteHelper.latches.latch2.await()
-        throw new IllegalStateException("impossible")
+        ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+        throw new IllegalStateException("timed out waiting to be interrupted")
       } else {
         throw new RuntimeException("User Exception that hides the original exception", t)
       }

From 3860aeda33be3cf1cad5999bfd6fbb274ac0bc24 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Sun, 8 Apr 2018 10:14:03 -0500
Subject: [PATCH 4/4] atomic
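Why AtomicBoolean, as a minimal standalone sketch (assumed names, not
the suite code): a plain var written on one thread and read on another
has no visibility guarantee, while an AtomicBoolean write is always
safely published to the reading thread.

    import java.util.concurrent.atomic.AtomicBoolean

    val timedOut = new AtomicBoolean(false)
    val writer = new Thread(() => timedOut.set(true), "writer")
    writer.start()
    writer.join()
    // with a plain `var timedOut = false` the read below could race the
    // write (absent the join); the atomic flag removes that hazard.
    assert(timedOut.get())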
---
 .../scala/org/apache/spark/executor/ExecutorSuite.scala | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index bd1b93c925691..1a7bebe2c53cd 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean

 import scala.collection.mutable.Map
 import scala.concurrent.duration._
@@ -290,7 +291,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
-    var timedOut = false
+    val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
@@ -304,7 +305,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
               // now we can kill the task
               executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
             } else {
-              timedOut = true
+              timedOut.set(true)
             }
           }
         }
@@ -313,7 +314,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
-      assert(!timedOut, "timed out waiting to be ready to kill tasks")
+      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()