
Commit 04ce488

Reuse or clean-up SparkContext in streaming tests
1 parent 381ef4e commit 04ce488

24 files changed: +235 −151 lines

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 10 additions & 0 deletions

@@ -2350,6 +2350,16 @@ object SparkContext extends Logging {
     }
   }
 
+  private[spark] def getActiveContext(): Option[SparkContext] = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      Option(activeContext.get())
+    }
+  }
+
+  private[spark] def stopActiveContext(): Unit = {
+    getActiveContext().foreach(_.stop())
+  }
+
   /**
    * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
    * raced with this constructor and started.
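
A minimal sketch of how a test might lean on these helpers to start from a clean slate; the master, app name, and batch duration below are illustrative and not part of this commit. Both methods are private[spark], so callers must live under the org.apache.spark package.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Stop whatever SparkContext a previous suite left running, then build a fresh
// one for this suite's StreamingContext.
SparkContext.stopActiveContext()
assert(SparkContext.getActiveContext().isEmpty)

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("example-suite"))
val ssc = new StreamingContext(sc, Seconds(1))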

streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.streaming;
 
 import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext$;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import org.junit.After;
 import org.junit.Before;
@@ -28,6 +29,7 @@ public abstract class LocalJavaStreamingContext {
 
     @Before
     public void setUp() {
+        SparkContext$.MODULE$.stopActiveContext();
         SparkConf conf = new SparkConf()
             .setMaster("local[2]")
             .setAppName("test")

streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala

Lines changed: 5 additions & 6 deletions

@@ -620,7 +620,7 @@ class BasicOperationsSuite extends TestSuiteBase {
   }
 
   test("slice") {
-    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+    withStreamingContext(Seconds(1)) { ssc =>
       val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
       val stream = new TestInputStream[Int](ssc, input, 2)
       stream.foreachRDD(_ => {}) // Dummy output stream
@@ -637,7 +637,7 @@ class BasicOperationsSuite extends TestSuiteBase {
     }
   }
   test("slice - has not been initialized") {
-    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+    withStreamingContext(Seconds(1)) { ssc =>
       val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
       val stream = new TestInputStream[Int](ssc, input, 2)
       val thrown = intercept[SparkException] {
@@ -657,7 +657,7 @@ class BasicOperationsSuite extends TestSuiteBase {
         .window(Seconds(4), Seconds(2))
     }
 
-    val operatedStream = runCleanupTest(conf, operation _,
+    val operatedStream = runCleanupTest(operation _,
       numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3))
     val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]]
     val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]]
@@ -694,7 +694,7 @@ class BasicOperationsSuite extends TestSuiteBase {
       Some(values.sum + state.getOrElse(0))
     }
     val stateStream = runCleanupTest(
-      conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
+      _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
 
     assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2)
     assert(stateStream.generatedRDDs.contains(Time(10000)))
@@ -705,7 +705,7 @@ class BasicOperationsSuite extends TestSuiteBase {
     // Actually receive data over through receiver to create BlockRDDs
 
     withTestServer(new TestServer()) { testServer =>
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         testServer.start()
 
         val batchCounter = new BatchCounter(ssc)
@@ -781,7 +781,6 @@ class BasicOperationsSuite extends TestSuiteBase {
 
   /** Test cleanup of RDDs in DStream metadata */
   def runCleanupTest[T: ClassTag](
-      conf2: SparkConf,
       operation: DStream[Int] => DStream[T],
       numExpectedOutput: Int = cleanupTestInput.size,
       rememberDuration: Duration = null
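
The rewritten call sites rely on withStreamingContext overloads that build the StreamingContext from the suite's shared SparkContext instead of a fresh SparkConf. Their definitions live in TestSuiteBase, outside this excerpt; a sketch consistent with the usage above, assuming the suite exposes sc, batchDuration, and the pre-existing withStreamingContext(ssc) variant, might look like:

  // Hypothetical TestSuiteBase additions: reuse the suite-wide `sc`, defaulting
  // to the suite's batchDuration when no duration is given.
  def withStreamingContext[R](block: StreamingContext => R): R =
    withStreamingContext(batchDuration)(block)

  def withStreamingContext[R](duration: Duration)(block: StreamingContext => R): R =
    withStreamingContext(new StreamingContext(sc, duration))(block)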

streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

Lines changed: 7 additions & 6 deletions

@@ -211,6 +211,8 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
 class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   with ResetSystemProperties {
 
+  override val reuseContext: Boolean = false
+
   var ssc: StreamingContext = null
 
   override def batchDuration: Duration = Milliseconds(500)
@@ -238,8 +240,6 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
 
     assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
 
-    conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
-
     val stateStreamCheckpointInterval = Seconds(1)
     val fs = FileSystem.getLocal(new Configuration())
     // this ensure checkpointing occurs at least once
@@ -571,7 +571,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   }
 
   test("recovery maintains rate controller") {
-    ssc = new StreamingContext(conf, batchDuration)
+    ssc = new StreamingContext(sc, batchDuration)
     ssc.checkpoint(checkpointDir)
 
     val dstream = new RateTestInputDStream(ssc) {
@@ -635,7 +635,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     try {
       // This is a var because it's re-assigned when we restart from a checkpoint
      var clock: ManualClock = null
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext(batchDuration) { ssc =>
        ssc.checkpoint(checkpointDir)
        clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
        val batchCounter = new BatchCounter(ssc)
@@ -760,7 +760,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
   }
 
   test("DStreamCheckpointData.restore invoking times") {
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       ssc.checkpoint(checkpointDir)
       val inputDStream = new CheckpointInputDStream(ssc)
       val checkpointData = inputDStream.checkpointData
@@ -822,7 +822,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     val jobGenerator = mock(classOf[JobGenerator])
     val checkpointDir = Utils.createTempDir().toString
     val checkpointWriter =
-      new CheckpointWriter(jobGenerator, conf, checkpointDir, new Configuration())
+      new CheckpointWriter(jobGenerator, sc.conf, checkpointDir, new Configuration())
     val bytes1 = Array.fill[Byte](10)(1)
     new checkpointWriter.CheckpointWriteHandler(
       Time(2000), bytes1, clearCheckpointDataLater = false).run()
@@ -869,6 +869,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing
     // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break
     // connections between layer 2 and layer 3)
+    stopActiveContext()
     ssc = new StreamingContext(master, framework, batchDuration)
     val batchCounter = new BatchCounter(ssc)
     ssc.checkpoint(checkpointDir)
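
The reuseContext flag is presumably a new TestSuiteBase switch controlling whether the shared SparkContext survives across tests; CheckpointSuite opts out because its recovery tests repeatedly stop contexts and restart them from checkpoints. The stopActiveContext() call before new StreamingContext(master, framework, batchDuration) serves the same goal locally: it clears any active context so the constructor that takes a master string can create its own.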

streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala

Lines changed: 4 additions & 7 deletions

@@ -19,28 +19,25 @@ package org.apache.spark.streaming
 
 import java.io.NotSerializableException
 
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.{HashPartitioner, SparkException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.ReturnStatementInClosureException
 
 /**
  * Test that closures passed to DStream operations are actually cleaned.
  */
-class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
-  private var ssc: StreamingContext = null
+class DStreamClosureSuite extends ReuseableSparkContext {
+  private var ssc: StreamingContext = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sc = new SparkContext("local", "test")
     ssc = new StreamingContext(sc, Seconds(1))
   }
 
   override def afterAll(): Unit = {
     try {
-      ssc.stop(stopSparkContext = true)
+      ssc.stop()
       ssc = null
     } finally {
       super.afterAll()
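
ReuseableSparkContext is a base suite introduced elsewhere in this commit, so its definition is not shown here. A minimal sketch consistent with how DStreamClosureSuite, DStreamScopeSuite, and MapWithStateSuite use it (a shared sc plus an overridable extraSparkConf) might look like the following; the reuse policy, master, and app name are assumptions:

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.scalatest.BeforeAndAfterAll

// Hypothetical reconstruction: reuse the active SparkContext when no extra
// settings are required, otherwise stop it and start a fresh local one.
abstract class ReuseableSparkContext extends SparkFunSuite with BeforeAndAfterAll {

  protected var sc: SparkContext = _

  // Configurations to add to a new or existing spark context.
  def extraSparkConf: Map[String, String] = Map.empty

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = SparkContext.getActiveContext() match {
      case Some(active) if extraSparkConf.isEmpty => active
      case other =>
        other.foreach(_.stop())
        val conf = new SparkConf().setMaster("local[2]").setAppName(getClass.getSimpleName)
        extraSparkConf.foreach { case (k, v) => conf.set(k, v) }
        new SparkContext(conf)
    }
  }

  // The context is intentionally left running in afterAll so later suites can reuse it.
}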

streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala

Lines changed: 10 additions & 7 deletions

@@ -30,20 +30,23 @@ import org.apache.spark.util.ManualClock
 /**
  * Tests whether scope information is passed from DStream operations to RDDs correctly.
  */
-class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
-  private var ssc: StreamingContext = null
-  private val batchDuration: Duration = Seconds(1)
+class DStreamScopeSuite extends ReuseableSparkContext {
+  private var ssc: StreamingContext = _
+
+  // Configurations to add to a new or existing spark context.
+  override def extraSparkConf: Map[String, String] = {
+    // Use a manual clock
+    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
+  }
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val conf = new SparkConf().setMaster("local").setAppName("test")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    ssc = new StreamingContext(new SparkContext(conf), batchDuration)
+    ssc = new StreamingContext(sc, Seconds(1))
   }
 
   override def afterAll(): Unit = {
     try {
-      ssc.stop(stopSparkContext = true)
+      ssc.stop()
     } finally {
       super.afterAll()
    }
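
The string literal "org.apache.spark.util.ManualClock" is exactly what the removed classOf[ManualClock].getName() call evaluated to, so the suite still runs on a manual clock; only the way the setting reaches the SparkConf (through extraSparkConf) changes.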

streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala

Lines changed: 6 additions & 1 deletion

@@ -35,6 +35,11 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {
   private val numBatches = 30
   private var directory: File = null
 
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+    SparkContext.getActiveContext().foreach(_.stop())
+  }
+
   before {
     directory = Utils.createTempDir()
   }
@@ -46,7 +51,7 @@ class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {
     StreamingContext.getActive().foreach { _.stop() }
 
     // Stop SparkContext if active
-    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
+    SparkContext.getActiveContext().foreach(_.stop())
   }
 
   test("multiple failures with map") {

streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala

Lines changed: 9 additions & 9 deletions

@@ -49,7 +49,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
       testServer.start()
 
       // Set up the streaming context and input streams
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         ssc.addStreamingListener(ssc.progressListener)
 
         val input = Seq(1, 2, 3, 4, 5)
@@ -112,7 +112,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     withTestServer(new TestServer()) { testServer =>
       testServer.start()
 
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext { ssc =>
         ssc.addStreamingListener(ssc.progressListener)
 
         val batchCounter = new BatchCounter(ssc)
@@ -149,7 +149,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
      assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000)
 
      // Set up the streaming context and input streams
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext(batchDuration) { ssc =>
        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
        // This `setTime` call ensures that the clock is past the creation time of `existingFile`
        clock.setTime(existingFile.lastModified + batchDuration.milliseconds)
@@ -213,7 +213,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
      val pathWithWildCard = testDir.toString + "/*/"
 
      // Set up the streaming context and input streams
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext(batchDuration) { ssc =>
        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
        clock.setTime(existingFile.lastModified + batchDuration.milliseconds)
        val batchCounter = new BatchCounter(ssc)
@@ -270,7 +270,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     def output: Iterable[Long] = outputQueue.asScala.flatMap(x => x)
 
     // set up the network stream using the test receiver
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val networkStream = ssc.receiverStream[Int](testReceiver)
       val countStream = networkStream.count
 
@@ -305,7 +305,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     def output: Iterable[Seq[String]] = outputQueue.asScala.filter(_.nonEmpty)
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val queue = new mutable.Queue[RDD[String]]()
       val queueStream = ssc.queueStream(queue, oneAtATime = true)
       val outputStream = new TestOutputStream(queueStream, outputQueue)
@@ -350,7 +350,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5"))
 
     // Set up the streaming context and input streams
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
       val queue = new mutable.Queue[RDD[String]]()
       val queueStream = ssc.queueStream(queue, oneAtATime = false)
       val outputStream = new TestOutputStream(queueStream, outputQueue)
@@ -396,7 +396,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
   }
 
   test("test track the number of input stream") {
-    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+    withStreamingContext { ssc =>
 
       class TestInputDStream extends InputDStream[String](ssc) {
         def start() {}
@@ -434,7 +434,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
      assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000)
 
      // Set up the streaming context and input streams
-      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      withStreamingContext(batchDuration) { ssc =>
        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
        // This `setTime` call ensures that the clock is past the creation time of `existingFile`
        clock.setTime(existingFile.lastModified + batchDuration.milliseconds)

streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala

Lines changed: 6 additions & 22 deletions

@@ -23,20 +23,21 @@ import java.util.concurrent.ConcurrentLinkedQueue
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
 import org.scalatest.PrivateMethodTester._
 
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
 import org.apache.spark.util.{ManualClock, Utils}
 
-class MapWithStateSuite extends SparkFunSuite
-  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
+class MapWithStateSuite extends ReuseableSparkContext with DStreamCheckpointTester {
 
-  private var sc: SparkContext = null
   protected var checkpointDir: File = null
   protected val batchDuration = Seconds(1)
 
+  override def extraSparkConf: Map[String, String] = {
+    // Use a manual clock
+    super.extraSparkConf ++ Map("spark.streaming.clock" -> "org.apache.spark.util.ManualClock")
+  }
+
   before {
     StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
     checkpointDir = Utils.createTempDir("checkpoint")
@@ -49,23 +50,6 @@ class MapWithStateSuite
     }
   }
 
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    sc = new SparkContext(conf)
-  }
-
-  override def afterAll(): Unit = {
-    try {
-      if (sc != null) {
-        sc.stop()
-      }
-    } finally {
-      super.afterAll()
-    }
-  }
-
   test("state - get, exists, update, remove, ") {
     var state: StateImpl[Int] = null

streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala

Lines changed: 7 additions & 0 deletions

@@ -75,6 +75,11 @@ class ReceivedBlockHandlerSuite
   var storageLevel: StorageLevel = null
   var tempDirectory: File = null
 
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    SparkContext.getActiveContext().foreach(_.stop())
+  }
+
   before {
     rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
     conf.set("spark.driver.port", rpcEnv.address.port.toString)
@@ -107,6 +112,8 @@ class ReceivedBlockHandlerSuite
     rpcEnv.awaitTermination()
     rpcEnv = null
 
+    sc.stop()
+
     Utils.deleteRecursively(tempDirectory)
   }
