@@ -28,6 +28,7 @@ import scala.collection.immutable
 import scala.collection.mutable.{ArrayBuffer, Map}
 import scala.concurrent.duration._
 
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.mockito.ArgumentCaptor
 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{inOrder, verify, when}
@@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.UI._
-import org.apache.spark.memory.TestMemoryManager
+import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.RDD
 import org.apache.spark.resource.ResourceInformation
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockManager, BlockManagerId}
-import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}
 
 class ExecutorSuite extends SparkFunSuite
   with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,73 @@ class ExecutorSuite extends SparkFunSuite
     assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
   }
 
+  test("SPARK-33587: isFatalError") {
+    def errorInThreadPool(e: => Throwable): Throwable = {
+      intercept[Throwable] {
+        val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test")
+        try {
+          val f = taskPool.submit(new java.util.concurrent.Callable[String] {
+            override def call(): String = throw e
+          })
+          f.get()
+        } finally {
+          taskPool.shutdown()
+        }
+      }
+    }
+
+    def errorInGuavaCache(e: => Throwable): Throwable = {
+      val cache = CacheBuilder.newBuilder()
+        .build(new CacheLoader[String, String] {
+          override def load(key: String): String = throw e
+        })
+      intercept[Throwable] {
+        cache.get("test")
+      }
+    }
+
+    def testThrowable(
+        e: => Throwable,
+        shouldDetectNestedFatalError: Boolean,
+        isFatal: Boolean): Unit = {
+      import Executor.isFatalError
+      assert(isFatalError(e, shouldDetectNestedFatalError) == isFatal)
+      // Now check nested exceptions. We get `true` only if we need to check nested exceptions
+      // (`shouldDetectNestedFatalError` is `true`) and `e` is fatal.
+      val expected = shouldDetectNestedFatalError && isFatal
+      assert(isFatalError(errorInThreadPool(e), shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(errorInGuavaCache(e), shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        errorInThreadPool(errorInGuavaCache(e)),
+        shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        errorInGuavaCache(errorInThreadPool(e)),
+        shouldDetectNestedFatalError) == expected)
+      assert(isFatalError(
+        new SparkException("Task failed while writing rows.", e),
+        shouldDetectNestedFatalError) == expected)
+    }
+
+    for (shouldDetectNestedFatalError <- true :: false :: Nil) {
+      testThrowable(new OutOfMemoryError(), shouldDetectNestedFatalError, isFatal = true)
+      testThrowable(new InterruptedException(), shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(new RuntimeException("test"), shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(
+        new SparkOutOfMemoryError("test"),
+        shouldDetectNestedFatalError,
+        isFatal = false)
+    }
+
+    val e1 = new Exception("test1")
+    val e2 = new Exception("test2")
+    e1.initCause(e2)
+    e2.initCause(e1)
+    for (shouldDetectNestedFatalError <- true :: false :: Nil) {
+      testThrowable(e1, shouldDetectNestedFatalError, isFatal = false)
+      testThrowable(e2, shouldDetectNestedFatalError, isFatal = false)
+    }
+  }
+
   private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
     val mockEnv = mock[SparkEnv]
     val mockRpcEnv = mock[RpcEnv]
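
The diff above only shows the test side of the change. For context, below is a minimal sketch of an `isFatalError` helper that would satisfy these assertions. It is an illustration under stated assumptions, not the actual body of `Executor.isFatalError` from this commit: the wrapper object `IsFatalErrorSketch` is hypothetical, while `SparkOutOfMemoryError` and Spark's internal `Utils.isFatalError` (which deliberately treats `InterruptedException` as non-fatal) are existing Spark APIs.

// Hypothetical sketch, not Spark's actual implementation of Executor.isFatalError.
import org.apache.spark.memory.SparkOutOfMemoryError
import org.apache.spark.util.Utils

object IsFatalErrorSketch {
  def isFatalError(t: Throwable, shouldDetectNestedFatalError: Boolean): Boolean = {
    // Track visited throwables so the cyclic cause chain built at the end of
    // the test (e1 -> e2 -> e1) terminates instead of recursing forever.
    def loop(t: Throwable, seen: Set[Throwable]): Boolean = t match {
      case null => false
      case _ if seen.contains(t) => false
      // SparkOutOfMemoryError extends OutOfMemoryError but is thrown and
      // handled by Spark itself, so the test expects it to be non-fatal.
      case _: SparkOutOfMemoryError => false
      // Utils.isFatalError flags VM errors such as OutOfMemoryError while
      // excluding InterruptedException, matching the assertions above.
      case e if Utils.isFatalError(e) => true
      // Only walk the cause chain when nested detection is requested; this
      // is what makes `expected = shouldDetectNestedFatalError && isFatal`.
      case e if shouldDetectNestedFatalError => loop(e.getCause, seen + e)
      case _ => false
    }
    loop(t, Set.empty)
  }
}

With this shape, a fatal error wrapped by a thread pool's ExecutionException or Guava's ExecutionError is still detected when nested detection is on, because both wrappers are themselves non-fatal and the walk continues into getCause.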