From 4156c034324632c59835066d2eae61532daa375f Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 28 Nov 2020 10:19:33 -0800 Subject: [PATCH 1/4] Kill the executor on nested fatal errors --- .../org/apache/spark/executor/Executor.scala | 29 +++++++- .../spark/internal/config/package.scala | 7 ++ .../apache/spark/executor/ExecutorSuite.scala | 72 ++++++++++++++++++- 3 files changed, 104 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f7246448959e..a15a87d05c8a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -150,6 +150,8 @@ private[spark] class Executor( // Whether to monitor killed / interrupted tasks private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED) + private val killOnNestedFatalError = conf.get(EXECUTOR_KILL_ON_NESTED_FATAL_ERROR) + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -648,7 +650,7 @@ private[spark] class Executor( plugins.foreach(_.onTaskFailed(reason)) execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) - case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => + case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnNestedFatalError) => val reason = task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { // there was a fetch failure in the task, but some user code wrapped that exception @@ -711,7 +713,7 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. - if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) { + if (Executor.isFatalError(t, killOnNestedFatalError)) { uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { @@ -997,4 +999,27 @@ private[spark] object Executor { // Used to store executorSource, for local mode only var executorSourceLocalModeOnly: ExecutorSource = null + + /** + * Whether a `Throwable` thrown from a task is a fatal error. We use this to decide whether to + * kill the executor. + * + * @param shouldDetectNestedFatalError whether to go through the exception chain to check whether + * exists a fatal error. + * @param depth the current depth of the recursive call. Return `false` when it's greater than 5. + * This is to avoid `StackOverflowError` when hitting a cycle in the exception chain. 
+ */ + def isFatalError(t: Throwable, shouldDetectNestedFatalError: Boolean, depth: Int = 0): Boolean = { + if (depth <= 5) { + t match { + case _: SparkOutOfMemoryError => false + case e if Utils.isFatalError(e) => true + case e if e.getCause != null && shouldDetectNestedFatalError => + isFatalError(e.getCause, shouldDetectNestedFatalError, depth + 1) + case _ => false + } + } else { + false + } + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index b38d0e5c617b..e3067c1c3ff0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1946,6 +1946,13 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val EXECUTOR_KILL_ON_NESTED_FATAL_ERROR = + ConfigBuilder("spark.executor.killOnNestedFatalError") + .doc("Whether to kill the executor when a nested fatal error is thrown from a task.") + .internal() + .booleanConf + .createWithDefault(true) + private[spark] val PUSH_BASED_SHUFFLE_ENABLED = ConfigBuilder("spark.shuffle.push.enabled") .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " + diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 31049d104e63..6b5a4cc9d41e 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -28,6 +28,7 @@ import scala.collection.immutable import scala.collection.mutable.{ArrayBuffer, Map} import scala.concurrent.duration._ +import com.google.common.cache.{CacheBuilder, CacheLoader} import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{inOrder, verify, when} @@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.config._ import org.apache.spark.internal.config.UI._ -import org.apache.spark.memory.TestMemoryManager +import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rdd.RDD import org.apache.spark.resource.ResourceInformation @@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task, import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManager, BlockManagerId} -import org.apache.spark.util.{LongAccumulator, UninterruptibleThread} +import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread} class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester { @@ -402,6 +403,73 @@ class ExecutorSuite extends SparkFunSuite assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0) } + test("SPARK-33587: isFatalError") { + def errorInThreadPool(e: => Throwable): Throwable = { + intercept[Throwable] { + val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test") + try { + val f = taskPool.submit(new java.util.concurrent.Callable[String] { + override def call(): String = throw e + }) + f.get() + } finally { + taskPool.shutdown() + } + } + } + + def errorInGuavaCache(e: => Throwable): Throwable = { + val cache = 
CacheBuilder.newBuilder() + .build(new CacheLoader[String, String] { + override def load(key: String): String = throw e + }) + intercept[Throwable] { + cache.get("test") + } + } + + def testThrowable( + e: => Throwable, + shouldDetectNestedFatalError: Boolean, + isFatal: Boolean): Unit = { + import Executor.isFatalError + assert(isFatalError(e, shouldDetectNestedFatalError) == isFatal) + // Now check nested exceptions. We get `true` only if we need to check nested exceptions + // (`shouldDetectNestedFatalError` is `true`) and `e` is fatal. + val expected = shouldDetectNestedFatalError && isFatal + assert(isFatalError(errorInThreadPool(e), shouldDetectNestedFatalError) == expected) + assert(isFatalError(errorInGuavaCache(e), shouldDetectNestedFatalError) == expected) + assert(isFatalError( + errorInThreadPool(errorInGuavaCache(e)), + shouldDetectNestedFatalError) == expected) + assert(isFatalError( + errorInGuavaCache(errorInThreadPool(e)), + shouldDetectNestedFatalError) == expected) + assert(isFatalError( + new SparkException("Task failed while writing rows.", e), + shouldDetectNestedFatalError) == expected) + } + + for (shouldDetectNestedFatalError <- true :: false :: Nil) { + testThrowable(new OutOfMemoryError(), shouldDetectNestedFatalError, isFatal = true) + testThrowable(new InterruptedException(), shouldDetectNestedFatalError, isFatal = false) + testThrowable(new RuntimeException("test"), shouldDetectNestedFatalError, isFatal = false) + testThrowable( + new SparkOutOfMemoryError("test"), + shouldDetectNestedFatalError, + isFatal = false) + } + + val e1 = new Exception("test1") + val e2 = new Exception("test2") + e1.initCause(e2) + e2.initCause(e1) + for (shouldDetectNestedFatalError <- true :: false :: Nil) { + testThrowable(e1, shouldDetectNestedFatalError, isFatal = false) + testThrowable(e2, shouldDetectNestedFatalError, isFatal = false) + } + } + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { val mockEnv = mock[SparkEnv] val mockRpcEnv = mock[RpcEnv] From 2720b602e1b263d7c0115eab4ecc0b23d18962fc Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 28 Nov 2020 15:22:31 -0800 Subject: [PATCH 2/4] add a config for depth --- .../org/apache/spark/executor/Executor.scala | 29 ++++++----- .../spark/internal/config/package.scala | 13 +++-- .../apache/spark/executor/ExecutorSuite.scala | 49 ++++++++++--------- 3 files changed, 47 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a15a87d05c8a..efb0b2c26d9a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -150,7 +150,7 @@ private[spark] class Executor( // Whether to monitor killed / interrupted tasks private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED) - private val killOnNestedFatalError = conf.get(EXECUTOR_KILL_ON_NESTED_FATAL_ERROR) + private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -650,7 +650,7 @@ private[spark] class Executor( plugins.foreach(_.onTaskFailed(reason)) execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason)) - case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnNestedFatalError) => + case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) => val reason = 
task.context.fetchFailed.get.toTaskFailedReason if (!t.isInstanceOf[FetchFailedException]) { // there was a fetch failure in the task, but some user code wrapped that exception @@ -713,7 +713,7 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. - if (Executor.isFatalError(t, killOnNestedFatalError)) { + if (Executor.isFatalError(t, killOnFatalErrorDepth)) { uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { @@ -1001,25 +1001,24 @@ private[spark] object Executor { var executorSourceLocalModeOnly: ExecutorSource = null /** - * Whether a `Throwable` thrown from a task is a fatal error. We use this to decide whether to - * kill the executor. + * Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether + * to kill the executor. * - * @param shouldDetectNestedFatalError whether to go through the exception chain to check whether - * exists a fatal error. - * @param depth the current depth of the recursive call. Return `false` when it's greater than 5. - * This is to avoid `StackOverflowError` when hitting a cycle in the exception chain. + * @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0 + * means not checking any fatal error (in other words, return false), 1 means + * checking only the exception but not the cause, and so on. This is to avoid + * `StackOverflowError` when hitting a cycle in the exception chain. */ - def isFatalError(t: Throwable, shouldDetectNestedFatalError: Boolean, depth: Int = 0): Boolean = { - if (depth <= 5) { + def isFatalError(t: Throwable, depthToCheck: Int): Boolean = { + if (depthToCheck <= 0) { + false + } else { t match { case _: SparkOutOfMemoryError => false case e if Utils.isFatalError(e) => true - case e if e.getCause != null && shouldDetectNestedFatalError => - isFatalError(e.getCause, shouldDetectNestedFatalError, depth + 1) + case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1) case _ => false } - } else { - false } } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e3067c1c3ff0..5b9ee9e0f77a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1946,12 +1946,15 @@ package object config { .booleanConf .createWithDefault(false) - private[spark] val EXECUTOR_KILL_ON_NESTED_FATAL_ERROR = - ConfigBuilder("spark.executor.killOnNestedFatalError") - .doc("Whether to kill the executor when a nested fatal error is thrown from a task.") + private[spark] val EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH = + ConfigBuilder("spark.executor.killOnFatalError.depth") + .doc("The max depth of the exception chain in a failed task Spark will search for a fatal " + + "error to check whether it should kill an executor. 
0 means not checking any fatal " + "error, 1 means checking only the exception but not the cause, and so on.") .internal() - .booleanConf - .createWithDefault(true) + .intConf + .checkValue(_ >= 0, "needs to be a non-negative value") + .createWithDefault(5) private[spark] val PUSH_BASED_SHUFFLE_ENABLED = ConfigBuilder("spark.shuffle.push.enabled") diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 6b5a4cc9d41e..555682ca6a37 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -429,44 +429,45 @@ class ExecutorSuite extends SparkFunSuite } def testThrowable( - e: => Throwable, - shouldDetectNestedFatalError: Boolean, - isFatal: Boolean): Unit = { + e: => Throwable, + depthToCheck: Int, + isFatal: Boolean): Unit = { import Executor.isFatalError - assert(isFatalError(e, shouldDetectNestedFatalError) == isFatal) - // Now check nested exceptions. We get `true` only if we need to check nested exceptions - // (`shouldDetectNestedFatalError` is `true`) and `e` is fatal. - val expected = shouldDetectNestedFatalError && isFatal - assert(isFatalError(errorInThreadPool(e), shouldDetectNestedFatalError) == expected) - assert(isFatalError(errorInGuavaCache(e), shouldDetectNestedFatalError) == expected) + // `e`'s depth is 1 so `depthToCheck` needs to be at least 1 to detect fatal errors. + assert(isFatalError(e, depthToCheck) == (depthToCheck >= 1)) + // `e`'s depth is 2 so `depthToCheck` needs to be at least 2 to detect fatal errors. + assert(isFatalError(errorInThreadPool(e), depthToCheck) == (depthToCheck >= 2 && isFatal)) + assert(isFatalError(errorInGuavaCache(e), depthToCheck) == (depthToCheck >= 2 && isFatal)) + assert(isFatalError( + new SparkException("foo", e), + depthToCheck) == (depthToCheck >= 2 && isFatal)) + // `e`'s depth is 3 so `depthToCheck` needs to be at least 3 to detect fatal errors.
assert(isFatalError( errorInThreadPool(errorInGuavaCache(e)), - shouldDetectNestedFatalError) == expected) + depthToCheck) == (depthToCheck >= 3 && isFatal)) assert(isFatalError( errorInGuavaCache(errorInThreadPool(e)), - shouldDetectNestedFatalError) == expected) + depthToCheck) == (depthToCheck >= 3 && isFatal)) assert(isFatalError( - new SparkException("Task failed while writing rows.", e), - shouldDetectNestedFatalError) == expected) + new SparkException("foo", new SparkException("foo", e)), + depthToCheck) == (depthToCheck >= 3 && isFatal)) } - for (shouldDetectNestedFatalError <- true :: false :: Nil) { - testThrowable(new OutOfMemoryError(), shouldDetectNestedFatalError, isFatal = true) - testThrowable(new InterruptedException(), shouldDetectNestedFatalError, isFatal = false) - testThrowable(new RuntimeException("test"), shouldDetectNestedFatalError, isFatal = false) - testThrowable( - new SparkOutOfMemoryError("test"), - shouldDetectNestedFatalError, - isFatal = false) + for (depthToCheck <- 0 to 5) { + testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true) + testThrowable(new InterruptedException(), depthToCheck, isFatal = false) + testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false) + testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false) } + // Verify we can handle the cycle in the exception chain val e1 = new Exception("test1") val e2 = new Exception("test2") e1.initCause(e2) e2.initCause(e1) - for (shouldDetectNestedFatalError <- true :: false :: Nil) { - testThrowable(e1, shouldDetectNestedFatalError, isFatal = false) - testThrowable(e2, shouldDetectNestedFatalError, isFatal = false) + for (depthToCheck <- 0 to 5) { + testThrowable(e1, depthToCheck, isFatal = false) + testThrowable(e2, depthToCheck, isFatal = false) } } From 1ec1c1dc4baea6235d8a1e0e6d2b72790c950f51 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 28 Nov 2020 15:37:55 -0800 Subject: [PATCH 3/4] fix --- .../test/scala/org/apache/spark/executor/ExecutorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 555682ca6a37..1326ae3c11a0 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -434,7 +434,7 @@ class ExecutorSuite extends SparkFunSuite isFatal: Boolean): Unit = { import Executor.isFatalError // `e`'s depth is 1 so `depthToCheck` needs to be at least 1 to detect fatal errors. - assert(isFatalError(e, depthToCheck) == (depthToCheck >= 1)) + assert(isFatalError(e, depthToCheck) == (depthToCheck >= 1 && isFatal)) // `e`'s depth is 2 so `depthToCheck` needs to be at least 2 to detect fatal errors.
assert(isFatalError(errorInThreadPool(e), depthToCheck) == (depthToCheck >= 2 && isFatal)) assert(isFatalError(errorInGuavaCache(e), depthToCheck) == (depthToCheck >= 2 && isFatal)) From 312f0422f7c6379747762c0b2eaf523e76c96a9b Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 28 Nov 2020 16:59:35 -0800 Subject: [PATCH 4/4] version --- .../main/scala/org/apache/spark/internal/config/package.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5b9ee9e0f77a..b8bcb374ef96 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1952,6 +1952,7 @@ package object config { "error to check whether it should kill an executor. 0 means not checking any fatal " + "error, 1 means checking only the exception but not the cause, and so on.") .internal() + .version("3.1.0") .intConf .checkValue(_ >= 0, "needs to be a non-negative value") .createWithDefault(5)
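
Editor's note (illustrative, not part of the patch series): the standalone Scala sketch below shows why a fatal error can reach the executor wrapped inside another exception, and how a depth-bounded walk over the cause chain, shaped like the Executor.isFatalError added above, detects it. It only approximates the patched logic: the object name NestedFatalErrorSketch, the helper isFatalLike, and the use of scala.util.control.NonFatal as a stand-in for Spark's Utils.isFatalError are assumptions made for this example, and the SparkOutOfMemoryError special case is omitted.

import java.util.concurrent.{Callable, ExecutionException, Executors}

import scala.util.control.NonFatal

object NestedFatalErrorSketch {

  // Stand-in for Spark's Utils.isFatalError (an assumption for this sketch):
  // treat anything that NonFatal would not catch, e.g. OutOfMemoryError, as fatal.
  private def isFatalLike(t: Throwable): Boolean = !NonFatal(t)

  // Depth-bounded walk of the cause chain, mirroring the shape of the patched
  // Executor.isFatalError: depthToCheck = 0 checks nothing, 1 checks only `t`,
  // 2 also checks t.getCause, and so on. The bound also keeps a cyclic cause
  // chain (as exercised by the e1/e2 test above) from looping forever.
  def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
    if (depthToCheck <= 0) false
    else if (isFatalLike(t)) true
    else if (t.getCause != null) isFatalError(t.getCause, depthToCheck - 1)
    else false
  }

  def main(args: Array[String]): Unit = {
    // User code inside a task often hands work to its own thread pool; a fatal
    // error thrown there comes back from Future.get() wrapped in an ExecutionException.
    val pool = Executors.newFixedThreadPool(1)
    val wrapped: Throwable =
      try {
        val f = pool.submit(new Callable[String] {
          override def call(): String = throw new OutOfMemoryError("simulated")
        })
        try {
          f.get()
          null
        } catch {
          case e: ExecutionException => e
        }
      } finally {
        pool.shutdown()
      }

    println(isFatalError(wrapped, depthToCheck = 1)) // false: only the ExecutionException wrapper is examined
    println(isFatalError(wrapped, depthToCheck = 2)) // true: the nested OutOfMemoryError is found one level down
  }
}

In the patched executor the depth comes from the internal setting introduced in PATCH 2/4, e.g. sparkConf.set("spark.executor.killOnFatalError.depth", "10"); 0 disables the check entirely and the default is 5.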