@@ -156,10 +156,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
156156 assert(error.contains(" within 1 second(s)" ))
157157 }
158158
159-
160- def testBarrierTaskKilled (sc : SparkContext , interruptOnCancel : Boolean ): Unit = {
161- sc.setLocalProperty(SparkContext .SPARK_JOB_INTERRUPT_ON_CANCEL , interruptOnCancel.toString)
162-
159+ def testBarrierTaskKilled (interruptOnKill : Boolean ): Unit = {
163160 withTempDir { dir =>
164161 val killedFlagFile = " barrier.task.killed"
165162 val rdd = sc.makeRDD(Seq (0 , 1 ), 2 )
@@ -181,12 +178,15 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
181178
182179 val listener = new SparkListener {
183180 override def onTaskStart (taskStart : SparkListenerTaskStart ): Unit = {
184- new Thread {
185- override def run : Unit = {
186- Thread .sleep(1000 )
187- sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = false )
188- }
189- }.start()
181+ val partitionId = taskStart.taskInfo.index
182+ if (partitionId == 0 ) {
183+ new Thread {
184+ override def run : Unit = {
185+ Thread .sleep(1000 )
186+ sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = interruptOnKill)
187+ }
188+ }.start()
189+ }
190190 }
191191 }
192192 sc.addSparkListener(listener)
@@ -201,15 +201,25 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
201201 }
202202 }
203203
204- test(" barrier task killed" ) {
204+ test(" barrier task killed, no interrupt" ) {
205+ val conf = new SparkConf ()
206+ .set(" spark.barrier.sync.timeout" , " 1" )
207+ .set(TEST_NO_STAGE_RETRY , true )
208+ .setMaster(" local-cluster[4, 1, 1024]" )
209+ .setAppName(" test-cluster" )
210+ sc = new SparkContext (conf)
211+
212+ testBarrierTaskKilled(interruptOnKill = false )
213+ }
214+
215+ test(" barrier task killed, interrupt" ) {
205216 val conf = new SparkConf ()
206217 .set(" spark.barrier.sync.timeout" , " 1" )
207218 .set(TEST_NO_STAGE_RETRY , true )
208219 .setMaster(" local-cluster[4, 1, 1024]" )
209220 .setAppName(" test-cluster" )
210221 sc = new SparkContext (conf)
211222
212- testBarrierTaskKilled(sc, true )
213- testBarrierTaskKilled(sc, false )
223+ testBarrierTaskKilled(interruptOnKill = true )
214224 }
215225}
0 commit comments