@@ -26,13 +26,18 @@ import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 
 class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
 
-  test("global sync by barrier() call") {
+  def initLocalClusterSparkContext(): Unit = {
     val conf = new SparkConf()
+      // Init a local cluster here so that each barrier task runs in a separate process,
+      // which makes the `barrier()` call actually useful.
       .setMaster("local-cluster[4, 1, 1024]")
       .setAppName("test-cluster")
       .set(TEST_NO_STAGE_RETRY, true)
     sc = new SparkContext(conf)
+  }
+
+  test("global sync by barrier() call") {
+    initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
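
For readers new to barrier execution: `local-cluster[4, 1, 1024]` launches four single-core executor processes with 1024 MB each, so the four barrier tasks really do run in separate JVMs and `barrier()` performs a cross-process synchronization. A minimal, self-contained sketch of the pattern this suite exercises (the object name and timing threshold are illustrative, not part of the patch):

import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

object BarrierSyncSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf()
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("barrier-sketch"))
    try {
      val times = sc.makeRDD(1 to 10, 4).barrier().mapPartitions { _ =>
        val context = BarrierTaskContext.get()
        context.barrier() // blocks until all 4 tasks in the stage reach this point
        Iterator.single(System.currentTimeMillis())
      }.collect()
      // Every task leaves the barrier at (almost) the same moment.
      assert(times.max - times.min < 1000)
    } finally {
      sc.stop()
    }
  }
}
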
@@ -48,10 +53,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("support multiple barrier() call within a single task") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -77,12 +79,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception on barrier() call timeout") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
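
Note the config pattern the patch introduces: the shared helper builds the context first, and each test then mutates the live conf with `sc.conf.set(...)`. This works on the assumption that `spark.barrier.sync.timeout` (in seconds) is read when a `barrier()` request is processed, not captured once at context startup. A hypothetical sketch of a timeout test in this style, assuming it sits inside this suite (which supplies `sc`, ScalaTest's `intercept`, and `SparkException`):

initLocalClusterSparkContext()
sc.conf.set("spark.barrier.sync.timeout", "1") // seconds
val error = intercept[SparkException] {
  sc.makeRDD(1 to 10, 4).barrier().mapPartitions { it =>
    val context = BarrierTaskContext.get()
    // Skip barrier() on task 0, so the remaining tasks can never complete
    // the sync and should fail once the 1-second timeout expires.
    if (context.partitionId() != 0) {
      context.barrier()
    }
    it
  }.collect()
}.getMessage
assert(error.contains("within 1 second(s)"))
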
@@ -102,12 +100,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception if barrier() call doesn't happen on every task") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -125,12 +119,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("throw exception if the number of barrier() calls are not the same on every task") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+    initLocalClusterSparkContext()
+    sc.conf.set("spark.barrier.sync.timeout", "1")
     val rdd = sc.makeRDD(1 to 10, 4)
     val rdd2 = rdd.barrier().mapPartitions { it =>
       val context = BarrierTaskContext.get()
@@ -156,10 +146,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(error.contains("within 1 second(s)"))
   }
 
-
-  def testBarrierTaskKilled(sc: SparkContext, interruptOnCancel: Boolean): Unit = {
-    sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, interruptOnCancel.toString)
-
+  def testBarrierTaskKilled(interruptOnKill: Boolean): Unit = {
     withTempDir { dir =>
       val killedFlagFile = "barrier.task.killed"
       val rdd = sc.makeRDD(Seq(0, 1), 2)
@@ -181,12 +168,15 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
 
       val listener = new SparkListener {
         override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
-          new Thread {
-            override def run: Unit = {
-              Thread.sleep(1000)
-              sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = false)
-            }
-          }.start()
+          val partitionId = taskStart.taskInfo.index
+          if (partitionId == 0) {
+            new Thread {
+              override def run: Unit = {
+                Thread.sleep(1000)
+                sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = interruptOnKill)
+              }
+            }.start()
+          }
         }
       }
       sc.addSparkListener(listener)
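
Two behavioral fixes land in this listener. First, the kill is now gated on `taskInfo.index == 0`, so exactly one of the two barrier tasks is killed and the `killedFlagFile` assertion cannot race against both tasks dying. Second, `interruptThread` now honors the `interruptOnKill` parameter instead of the removed job-level `SPARK_JOB_INTERRUPT_ON_CANCEL` local property, which governs job cancellation rather than individual task kills. For reference, the signature being called, as it appears in Spark's public `SparkContext` API:

// Returns true if the task was successfully killed. With interruptThread =
// false the task must poll its kill flag; with true its thread is also
// interrupted.
def killTaskAttempt(
    taskId: Long,
    interruptThread: Boolean = false,
    reason: String = "killed via SparkContext.killTaskAttempt"): Boolean
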
@@ -201,15 +191,13 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  test("barrier task killed") {
-    val conf = new SparkConf()
-      .set("spark.barrier.sync.timeout", "1")
-      .set(TEST_NO_STAGE_RETRY, true)
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
+  test("barrier task killed, no interrupt") {
+    initLocalClusterSparkContext()
+    testBarrierTaskKilled(interruptOnKill = false)
+  }
 
-    testBarrierTaskKilled(sc, true)
-    testBarrierTaskKilled(sc, false)
+  test("barrier task killed, interrupt") {
+    initLocalClusterSparkContext()
+    testBarrierTaskKilled(interruptOnKill = true)
   }
 }
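
Splitting the original single test into "no interrupt" and "interrupt" variants means each scenario runs against a fresh local cluster: the `LocalSparkContext` mixin stops `sc` after every test, so a task killed in one mode cannot leak state (listeners, killed-flag files, pending stages) into the other. A rough sketch of what such a mixin does (a hypothetical re-implementation for illustration, not Spark's actual test utility):

import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfterEach, Suite}

trait FreshSparkContextPerTest extends BeforeAndAfterEach { self: Suite =>
  @transient var sc: SparkContext = _

  override def afterEach(): Unit = {
    try {
      if (sc != null) {
        sc.stop() // tears down the local-cluster executors started by the test
      }
      sc = null
    } finally {
      super.afterEach()
    }
  }
}
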