
Commit 4156c03

Kill the executor on nested fatal errors
1 parent cf98a76 commit 4156c03
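
The change addresses fatal errors (such as OutOfMemoryError) that reach the executor wrapped
inside a non-fatal exception, e.g. by a thread pool or a Guava cache, so that checking only the
top-level throwable misses them. A minimal sketch of the wrapping problem (illustrative only,
not part of this commit):

    import java.util.concurrent.{Callable, Executors}
    import scala.util.control.NonFatal

    val pool = Executors.newFixedThreadPool(1)
    try {
      // The task dies with a fatal OutOfMemoryError...
      pool.submit(new Callable[String] {
        override def call(): String = throw new OutOfMemoryError("boom")
      }).get()
    } catch {
      // ...but the caller sees an ExecutionException, which NonFatal happily matches,
      // so a check on the outer throwable alone would treat the failure as non-fatal.
      case NonFatal(e) => println(s"looks non-fatal: $e, cause: ${e.getCause}")
    } finally {
      pool.shutdown()
    }

With this commit the executor can walk the cause chain (up to a small depth) and, if it finds a
fatal error there, hand it to the uncaught exception handler so the executor is killed.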

3 files changed: +104 -4 lines

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 27 additions & 2 deletions
@@ -150,6 +150,8 @@ private[spark] class Executor(
   // Whether to monitor killed / interrupted tasks
   private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)
 
+  private val killOnNestedFatalError = conf.get(EXECUTOR_KILL_ON_NESTED_FATAL_ERROR)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -648,7 +650,7 @@ private[spark] class Executor(
         plugins.foreach(_.onTaskFailed(reason))
         execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
-      case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+      case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnNestedFatalError) =>
         val reason = task.context.fetchFailed.get.toTaskFailedReason
         if (!t.isInstanceOf[FetchFailedException]) {
           // there was a fetch failure in the task, but some user code wrapped that exception
@@ -711,7 +713,7 @@ private[spark] class Executor(
 
         // Don't forcibly exit unless the exception was inherently fatal, to avoid
         // stopping other tasks unnecessarily.
-        if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
+        if (Executor.isFatalError(t, killOnNestedFatalError)) {
           uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
         }
       } finally {
@@ -997,4 +999,27 @@ private[spark] object Executor {
 
   // Used to store executorSource, for local mode only
   var executorSourceLocalModeOnly: ExecutorSource = null
+
+  /**
+   * Whether a `Throwable` thrown from a task is a fatal error. We use this to decide whether to
+   * kill the executor.
+   *
+   * @param shouldDetectNestedFatalError whether to go through the exception chain to check
+   *                                     whether a fatal error exists.
+   * @param depth the current depth of the recursive call. Return `false` when it's greater than 5.
+   *              This is to avoid `StackOverflowError` when hitting a cycle in the exception chain.
+   */
+  def isFatalError(t: Throwable, shouldDetectNestedFatalError: Boolean, depth: Int = 0): Boolean = {
+    if (depth <= 5) {
+      t match {
+        case _: SparkOutOfMemoryError => false
+        case e if Utils.isFatalError(e) => true
+        case e if e.getCause != null && shouldDetectNestedFatalError =>
+          isFatalError(e.getCause, shouldDetectNestedFatalError, depth + 1)
+        case _ => false
+      }
+    } else {
+      false
+    }
+  }
 }
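
As a rough sketch of the expected behavior of the new helper (illustrative only; Executor is
private[spark], so in practice it is exercised like this from Spark's own test suite):

    // A fatal error wrapped once, as a thread pool's ExecutionException would do:
    val nested = new java.util.concurrent.ExecutionException(new OutOfMemoryError("boom"))

    // Walks getCause and finds the OutOfMemoryError:
    Executor.isFatalError(nested, shouldDetectNestedFatalError = true)   // true
    // Only inspects the top-level ExecutionException, which is non-fatal:
    Executor.isFatalError(nested, shouldDetectNestedFatalError = false)  // false
    // Spark's own SparkOutOfMemoryError is deliberately excluded:
    Executor.isFatalError(new SparkOutOfMemoryError("test"), shouldDetectNestedFatalError = true)  // false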

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 7 additions & 0 deletions
@@ -1946,6 +1946,13 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val EXECUTOR_KILL_ON_NESTED_FATAL_ERROR =
+    ConfigBuilder("spark.executor.killOnNestedFatalError")
+      .doc("Whether to kill the executor when a nested fatal error is thrown from a task.")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
     ConfigBuilder("spark.shuffle.push.enabled")
       .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 70 additions & 2 deletions
@@ -28,6 +28,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, Map}
 import scala.concurrent.duration._
 
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{inOrder, verify, when}
@@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.memory.TestMemoryManager
+import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceInformation
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockManager, BlockManagerId}
-import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}
 
 class ExecutorSuite extends SparkFunSuite
   with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,73 @@ class ExecutorSuite extends SparkFunSuite
     assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
   }
 
+  test("SPARK-33587: isFatalError") {
+    def errorInThreadPool(e: => Throwable): Throwable = {
+      intercept[Throwable] {
+        val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test")
+        try {
+          val f = taskPool.submit(new java.util.concurrent.Callable[String] {
+            override def call(): String = throw e
+          })
+          f.get()
+        } finally {
+          taskPool.shutdown()
+        }
+      }
+    }
+
+    def errorInGuavaCache(e: => Throwable): Throwable = {
+      val cache = CacheBuilder.newBuilder()
+        .build(new CacheLoader[String, String] {
+          override def load(key: String): String = throw e
+        })
+      intercept[Throwable] {
+        cache.get("test")
+      }
+    }
+
+    def testThrowable(
+        e: => Throwable,
+        shouldDetectNestedFatalError: Boolean,
+        isFatal: Boolean): Unit = {
+      import Executor.isFatalError
+      assert(isFatalError(e, shouldDetectNestedFatalError) == isFatal)
+      // Now check nested exceptions. We get `true` only if we need to check nested exceptions
+      // (`shouldDetectNestedFatalError` is `true`) and `e` is fatal.
+      val expected = shouldDetectNestedFatalError && isFatal
+      assert(isFatalError(errorInThreadPool(e), shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(errorInGuavaCache(e), shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        errorInThreadPool(errorInGuavaCache(e)),
+        shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        errorInGuavaCache(errorInThreadPool(e)),
+        shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        new SparkException("Task failed while writing rows.", e),
+        shouldDetectNestedFatalError) == expected)
+    }
+
+    for (shouldDetectNestedFatalError <- true :: false :: Nil) {
+      testThrowable(new OutOfMemoryError(), shouldDetectNestedFatalError, isFatal = true)
+      testThrowable(new InterruptedException(), shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(new RuntimeException("test"), shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(
+        new SparkOutOfMemoryError("test"),
+        shouldDetectNestedFatalError,
+        isFatal = false)
+    }
+
+    val e1 = new Exception("test1")
+    val e2 = new Exception("test2")
+    e1.initCause(e2)
+    e2.initCause(e1)
+    for (shouldDetectNestedFatalError <- true :: false :: Nil) {
+      testThrowable(e1, shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(e2, shouldDetectNestedFatalError, isFatal = false)
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
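
The last part of the test exercises the depth cap in isFatalError: e1 and e2 form a cycle through
initCause, so without the `depth <= 5` guard the cause-chain walk would keep following getCause
until it overflowed the stack. With the guard the walk stops once the depth limit is exceeded and
the chain is reported as non-fatal (illustrative only):

    val e1 = new Exception("test1")
    val e2 = new Exception("test2")
    e1.initCause(e2)
    e2.initCause(e1)
    // Cause chain is e1 -> e2 -> e1 -> ..., so the walk is cut off after a few levels.
    Executor.isFatalError(e1, shouldDetectNestedFatalError = true)  // false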
