28 changes: 26 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -150,6 +150,8 @@ private[spark] class Executor(
   // Whether to monitor killed / interrupted tasks
   private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)
 
+  private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -648,7 +650,7 @@ private[spark] class Executor(
           plugins.foreach(_.onTaskFailed(reason))
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
-        case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+        case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnFatalErrorDepth) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
           if (!t.isInstanceOf[FetchFailedException]) {
             // there was a fetch failure in the task, but some user code wrapped that exception
@@ -711,7 +713,7 @@ private[spark] class Executor(
 
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
           // stopping other tasks unnecessarily.
-          if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
+          if (Executor.isFatalError(t, killOnFatalErrorDepth)) {
             uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
           }
       } finally {
@@ -997,4 +999,26 @@ private[spark] object Executor {
 
   // Used to store executorSource, for local mode only
   var executorSourceLocalModeOnly: ExecutorSource = null
+
+  /**
+   * Whether a `Throwable` thrown from a task is a fatal error. We will use this to decide whether
+   * to kill the executor.
+   *
+   * @param depthToCheck The max depth of the exception chain we should search for a fatal error. 0
+   *                     means not checking any fatal error (in other words, return false), 1 means
+   *                     checking only the exception but not the cause, and so on. This is to avoid
+   *                     `StackOverflowError` when hitting a cycle in the exception chain.
+   */
+  def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
+    if (depthToCheck <= 0) {
+      false
+    } else {
+      t match {
+        case _: SparkOutOfMemoryError => false
Member:
Just in case: are we sure that an OOM cannot be caused by a fatal error, and that one cannot be present somewhere in the chain?

Member (Author):
This is existing behavior. #20014 added SparkOutOfMemoryError to avoid killing the executor when the error is not thrown by the JVM.
+        case e if Utils.isFatalError(e) => true
+        case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
+        case _ => false
+      }
+    }
+  }
 }
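The recursion above is bounded by depthToCheck instead of walking getCause until it hits null because user code can produce a cycle in the cause chain. Below is a minimal, self-contained sketch of the same pattern, not Spark code: it uses scala.util.control.NonFatal in place of Utils.isFatalError (which additionally treats InterruptedException and a few other types as non-fatal), drops the SparkOutOfMemoryError special case, and the object name is made up.

import scala.util.control.NonFatal

object FatalErrorDepthSketch {
  // Depth-bounded walk of the cause chain, mirroring the patch above.
  def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
    if (depthToCheck <= 0) {
      false
    } else {
      t match {
        case e if !NonFatal(e) => true
        case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
        case _ => false
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // A cycle in the cause chain: e1 -> e2 -> e1 -> ... An unbounded walk of
    // getCause would recurse forever here; the depth bound terminates it.
    val e1 = new Exception("test1")
    val e2 = new Exception("test2")
    e1.initCause(e2)
    e2.initCause(e1)
    println(isFatalError(e1, 5)) // false

    // A fatal error wrapped two levels deep is found once depthToCheck >= 3.
    val wrapped = new RuntimeException("outer",
      new RuntimeException("inner", new OutOfMemoryError("simulated")))
    println(isFatalError(wrapped, 3)) // true
    println(isFatalError(wrapped, 2)) // false: the OutOfMemoryError sits at depth 3
  }
}

With the default depth of 5, a fatal error buried more than five causes deep goes undetected; that trade-off is what the new spark.executor.killOnFatalError.depth setting below exposes.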
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1946,6 +1946,17 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH =
+    ConfigBuilder("spark.executor.killOnFatalError.depth")
+      .doc("The max depth of the exception chain in a failed task Spark will search for a fatal " +
+        "error to check whether it should kill an executor. 0 means not checking any fatal " +
+        "error, 1 means checking only the exception but not the cause, and so on.")
+      .internal()
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0, "needs to be a non-negative value")
+      .createWithDefault(5)
+
   private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
     ConfigBuilder("spark.shuffle.push.enabled")
       .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +
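The key is marked internal, but it can still be set like any other Spark configuration when the default depth of 5 is not enough. A hypothetical sketch, not from the PR; the app name, master, and the value 8 are placeholders:

import org.apache.spark.{SparkConf, SparkContext}

object KillOnFatalErrorDepthExample {
  def main(args: Array[String]): Unit = {
    // Raise the search depth for a job whose user code wraps errors several levels deep.
    val conf = new SparkConf()
      .setAppName("deep-error-chains")
      .setMaster("local[2]")
      .set("spark.executor.killOnFatalError.depth", "8")
    val sc = new SparkContext(conf)
    try {
      // ... run the job ...
    } finally {
      sc.stop()
    }
  }
}

The same thing can be done at submit time with --conf spark.executor.killOnFatalError.depth=8.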
73 changes: 71 additions & 2 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -28,6 +28,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, Map}
 import scala.concurrent.duration._
 
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{inOrder, verify, when}
@@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.memory.TestMemoryManager
+import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceInformation
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockManager, BlockManagerId}
-import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}
 
 class ExecutorSuite extends SparkFunSuite
   with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,74 @@ class ExecutorSuite extends SparkFunSuite
     assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
   }
 
+  test("SPARK-33587: isFatalError") {
+    def errorInThreadPool(e: => Throwable): Throwable = {
Member (Author):
Trying to make this test cover the cases I mentioned in the description.

+      intercept[Throwable] {
+        val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test")
+        try {
+          val f = taskPool.submit(new java.util.concurrent.Callable[String] {
+            override def call(): String = throw e
+          })
+          f.get()
+        } finally {
+          taskPool.shutdown()
+        }
+      }
+    }
+
+    def errorInGuavaCache(e: => Throwable): Throwable = {
+      val cache = CacheBuilder.newBuilder()
+        .build(new CacheLoader[String, String] {
+          override def load(key: String): String = throw e
+        })
+      intercept[Throwable] {
+        cache.get("test")
+      }
+    }
+
+    def testThrowable(
+        e: => Throwable,
+        depthToCheck: Int,
+        isFatal: Boolean): Unit = {
+      import Executor.isFatalError
+      // `e`'s depth is 1 so `depthToCheck` needs to be at least 1 to detect fatal errors.
+      assert(isFatalError(e, depthToCheck) == (depthToCheck >= 1 && isFatal))
+      // `e`'s depth is 2 so `depthToCheck` needs to be at least 2 to detect fatal errors.
+      assert(isFatalError(errorInThreadPool(e), depthToCheck) == (depthToCheck >= 2 && isFatal))
+      assert(isFatalError(errorInGuavaCache(e), depthToCheck) == (depthToCheck >= 2 && isFatal))
+      assert(isFatalError(
+        new SparkException("foo", e),
+        depthToCheck) == (depthToCheck >= 2 && isFatal))
+      // `e`'s depth is 3 so `depthToCheck` needs to be at least 3 to detect fatal errors.
+      assert(isFatalError(
+        errorInThreadPool(errorInGuavaCache(e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+      assert(isFatalError(
+        errorInGuavaCache(errorInThreadPool(e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+      assert(isFatalError(
+        new SparkException("foo", new SparkException("foo", e)),
+        depthToCheck) == (depthToCheck >= 3 && isFatal))
+    }
+
+    for (depthToCheck <- 0 to 5) {
+      testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true)
+      testThrowable(new InterruptedException(), depthToCheck, isFatal = false)
+      testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false)
+      testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false)
+    }
+
+    // Verify we can handle the cycle in the exception chain
+    val e1 = new Exception("test1")
+    val e2 = new Exception("test2")
+    e1.initCause(e2)
+    e2.initCause(e1)
+    for (depthToCheck <- 0 to 5) {
+      testThrowable(e1, depthToCheck, isFatal = false)
+      testThrowable(e2, depthToCheck, isFatal = false)
+    }
+  }

   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
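For readers wondering why the test treats singly wrapped errors as depth 2: Future.get rethrows a task's failure wrapped in an ExecutionException whose cause is the original throwable, and Guava's LoadingCache.get similarly wraps loader failures (for Errors, in an ExecutionError). A minimal, self-contained sketch of the thread-pool case, not part of the PR and with an illustrative object name:

import java.util.concurrent.{Callable, ExecutionException, Executors}

object WrappedErrorDepthDemo {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    try {
      pool.submit(new Callable[String] {
        override def call(): String = throw new OutOfMemoryError("simulated")
      }).get()
    } catch {
      case e: ExecutionException =>
        // Depth 1 is the ExecutionException wrapper, depth 2 is the fatal error
        // itself, so Executor.isFatalError needs depthToCheck >= 2 to find it.
        assert(e.getCause.isInstanceOf[OutOfMemoryError])
    } finally {
      pool.shutdown()
    }
  }
}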