@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.Map
 import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure. The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  /**
+   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
+   * superseded by another error, either an OOM or intentionally killing a task.
+   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+   *            FetchFailure
+   */
+  private def testFetchFailureHandling(
+      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive. And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
 
-    // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
-    // the fetch failure as a false positive, and just do normal OOM handling.
+    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt. We
+    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)
 
-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }
 
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }
 
   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        val killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+              // now we can kill the task
+              executor.killAllTasks(true, "Killed task, e.g. because of speculative execution")
+            } else {
+              timedOut.set(true)
+            }
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
+      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is set up correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+          throw new IllegalStateException("timed out waiting to be interrupted")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }
 
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {