diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index 8d5f04ac7651a..fc8ac38479932 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -26,13 +26,18 @@ import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 
 class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
 
-  test("global sync by barrier() call") {
+  def initLocalClusterSparkContext(): Unit = {
     val conf = new SparkConf()
       // Init local cluster here so each barrier task runs in a separated process, thus `barrier()`
       // call is actually useful.
       .setMaster("local-cluster[4, 1, 1024]")
       .setAppName("test-cluster")
+      .set(TEST_NO_STAGE_RETRY, true)
     sc = new SparkContext(conf)
+  }
+
+  test("global sync by barrier() call") {
+    initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -48,10 +53,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("support multiple barrier() call within a single task") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -77,12 +79,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception on barrier() call timeout") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -102,12 +100,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception if barrier() call doesn't happen on every task") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -125,12 +119,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception if the number of barrier() calls are not the same on every task") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -156,10 +146,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(error.contains("within 1 second(s)"))
   }
 
-
-  def testBarrierTaskKilled(sc: SparkContext, interruptOnCancel: Boolean): Unit = {
-    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
-
+  def testBarrierTaskKilled(interruptOnKill: Boolean): Unit = {
     withTempDir { dir =>
       val killedFlagFile = "barrier.task.killed"
       val rdd = sc.makeRDD(Seq(0, 1), 2)
@@ -181,12 +168,15 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
 
       val listener = new SparkListener {
         override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-          new Thread {
-            override def run: Unit = {
-              Thread.sleep(1000)
-              sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = false)
-            }
-          }.start()
+          val partitionId = taskStart.taskInfo.index
+          if (partitionId == 0) {
+            new Thread {
+              override def run: Unit = {
+                Thread.sleep(1000)
+                sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = interruptOnKill)
+              }
+            }.start()
+          }
         }
       }
       sc.addSparkListener(listener)
@@ -201,15 +191,13 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  test("barrier task killed") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+  test("barrier task killed, no interrupt") {
+    initLocalClusterSparkContext()
+    testBarrierTaskKilled(interruptOnKill = false)
+  }
 
-    testBarrierTaskKilled(sc, true)
-    testBarrierTaskKilled(sc, false)
+  test("barrier task killed, interrupt") {
+    initLocalClusterSparkContext()
+    testBarrierTaskKilled(interruptOnKill = true)
   }
 }