diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index f7246448959e..efb0b2c26d9a 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -150,6 +150,8 @@ private[spark] class Executor(
   // Whether to monitor killed / interrupted tasks
   private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)
 
+  private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -648,7 +650,7 @@ private[spark] class Executor(
           plugins.foreach(_.onTaskFailed(reason))
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
-        case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+        case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
           if (!t.isInstanceOf[FetchFailedException]) {
             // there was a fetch failure in the task, but some user code wrapped that exception
@@ -711,7 +713,7 @@ private[spark] class Executor(
 
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
           // stopping other tasks unnecessarily.
-          if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
+          if (Executor.isFatalError(t, killOnFatalErrorDepth)) {
             uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
           }
       } finally {
@@ -997,4 +999,26 @@ private[spark] object Executor {
 
   // Used to store executorSource, for local mode only
   var executorSourceLocalModeOnly: ExecutorSource = null
+
+  /**
+   * Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether
+   * to kill the executor.
+   *
+   * @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0
+   *                     means not checking any fatal error (in other words, return false), 1 means
+   *                     checking only the exception but not the cause, and so on. This is to avoid
+   *                     `StackOverflowError` when hitting a cycle in the exception chain.
+   */
+  def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
+    if (depthToCheck <= 0) {
+      false
+    } else {
+      t match {
+        case _: SparkOutOfMemoryError => false
+        case e if Utils.isFatalError(e) => true
+        case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
+        case _ => false
+      }
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index b38d0e5c617b..b8bcb374ef96 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1946,6 +1946,17 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH =
+    ConfigBuilder("spark.executor.killOnFatalError.depth")
+      .doc("The max depth of the exception chain in a failed task Spark will search for a fatal " +
+        "error to check whether it should kill an executor. 0 means not checking any fatal " +
+        "error, 1 means checking only the exception but not the cause, and so on.")
+      .internal()
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0, "needs to be a non-negative value")
+      .createWithDefault(5)
+
   private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
     ConfigBuilder("spark.shuffle.push.enabled")
       .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 31049d104e63..1326ae3c11a0 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -28,6 +28,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, Map}
 import scala.concurrent.duration._
 
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{inOrder, verify, when}
@@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.memory.TestMemoryManager
+import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceInformation
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockManager, BlockManagerId}
-import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}
 
 class ExecutorSuite extends SparkFunSuite
     with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,74 @@ class ExecutorSuite extends SparkFunSuite
     assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
   }
 
+  test("SPARK-33587: isFatalError") {
+    def errorInThreadPool(e: => Throwable): Throwable = {
+      intercept[Throwable] {
+        val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test")
+        try {
+          val f = taskPool.submit(new java.util.concurrent.Callable[String] {
+            override def call(): String = throw e
+          })
+          f.get()
+        } finally {
+          taskPool.shutdown()
+        }
+      }
+    }
+
+    def errorInGuavaCache(e: => Throwable): Throwable = {
+      val cache = CacheBuilder.newBuilder()
+        .build(new CacheLoader[String, String] {
+          override def load(key: String): String = throw e
+        })
+      intercept[Throwable] {
+        cache.get("test")
+      }
+    }
+
+    def testThrowable(
+        e: => Throwable,
+        depthToCheck: Int,
+        isFatal: Boolean): Unit = {
+      import Executor.isFatalError
+      // `e`'s depth is 1 so `depthToCheck` needs to be at least 1 to detect fatal errors.
+      assert(isFatalError(e, depthToCheck) == (depthToCheck >= 1 && isFatal))
+      // `e`'s depth is 2 so `depthToCheck` needs to be at least 2 to detect fatal errors.
+      assert(isFatalError(errorInThreadPool(e), depthToCheck) == (depthToCheck >= 2 && isFatal))
+      assert(isFatalError(errorInGuavaCache(e), depthToCheck) == (depthToCheck >= 2 && isFatal))
+      assert(isFatalError(
+        new SparkException("foo", e),
+        depthToCheck) == (depthToCheck >= 2 && isFatal))
+      // `e`'s depth is 3 so `depthToCheck` needs to be at least 3 to detect fatal errors.
+      assert(isFatalError(
+        errorInThreadPool(errorInGuavaCache(e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+      assert(isFatalError(
+        errorInGuavaCache(errorInThreadPool(e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+      assert(isFatalError(
+        new SparkException("foo", new SparkException("foo", e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+    }
+
+    for (depthToCheck <- 0 to 5) {
+      testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true)
+      testThrowable(new InterruptedException(), depthToCheck, isFatal = false)
+      testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false)
+      testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false)
+    }
+
+    // Verify we can handle the cycle in the exception chain
+    val e1 = new Exception("test1")
+    val e2 = new Exception("test2")
+    e1.initCause(e2)
+    e2.initCause(e1)
+    for (depthToCheck <- 0 to 5) {
+      testThrowable(e1, depthToCheck, isFatal = false)
+      testThrowable(e2, depthToCheck, isFatal = false)
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
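For reference, a short usage sketch (not part of the patch) of the depth semantics added above: `Executor.isFatalError` walks the cause chain at most `depthToCheck` levels (controlled by `spark.executor.killOnFatalError.depth`, default 5), never treats `SparkOutOfMemoryError` as fatal, and reports a fatal error as soon as `Utils.isFatalError` matches an exception in the chain. The `wrapped` value below is purely illustrative, and since `Executor` is `private[spark]` this snippet would only compile from inside the `org.apache.spark` package (for example, in a test):

    import org.apache.spark.SparkException
    import org.apache.spark.executor.Executor

    // An OutOfMemoryError wrapped in two non-fatal SparkExceptions sits at depth 3 in the
    // exception chain; the top-level exception itself counts as depth 1.
    val wrapped = new SparkException("outer", new SparkException("inner", new OutOfMemoryError()))

    Executor.isFatalError(wrapped, 0)  // false: depth 0 disables the check entirely
    Executor.isFatalError(wrapped, 2)  // false: the search stops before reaching the OutOfMemoryError
    Executor.isFatalError(wrapped, 3)  // true: the fatal cause is found at depth 3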