@@ -33,8 +33,12 @@ import org.apache.spark.storage.StorageLevel
3333import org .apache .spark .streaming .dstream .DStream
3434import org .apache .spark .streaming .receiver .Receiver
3535import org .apache .spark .util .Utils
36- import org .apache .spark .{Logging , SparkConf , SparkContext , SparkException , SparkFunSuite }
37-
36+ import org .apache .spark .{Logging , SparkConf , SparkContext , SparkFunSuite }
37+ import org .apache .spark .metrics .MetricsSystem
38+ import org .apache .spark .metrics .source .Source
39+ import org .scalatest .{PrivateMethodTester , Assertions , BeforeAndAfter }
40+ import org .apache .spark .{Logging , SparkConf , SparkContext , SparkFunSuite }
41+ import scala .collection .mutable .ArrayBuffer
3842
3943class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging {
4044
@@ -299,6 +303,26 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
299303 Thread .sleep(100 )
300304 }
301305
306+ test (" registering and de-registering of streamingSource" ) {
307+ val conf = new SparkConf ().setMaster(master).setAppName(appName)
308+ ssc = new StreamingContext (conf, batchDuration)
309+ assert(ssc.getState() === StreamingContextState .INITIALIZED )
310+ addInputStream(ssc).register()
311+ ssc.start()
312+
313+ val sources = StreamingContextSuite .getSources(ssc.env.metricsSystem)
314+ val streamingSource = StreamingContextSuite .getStreamingSource(ssc)
315+ assert(sources.contains(streamingSource))
316+ assert(ssc.getState() === StreamingContextState .ACTIVE )
317+ Thread .sleep(100 )
318+
319+ ssc.stop()
320+ val sourcesAfterStop = StreamingContextSuite .getSources(ssc.env.metricsSystem)
321+ val streamingSourceAfterStop = StreamingContextSuite .getStreamingSource(ssc)
322+ assert(ssc.getState() === StreamingContextState .STOPPED )
323+ assert(! sourcesAfterStop.contains(streamingSourceAfterStop))
324+ }
325+
302326 test(" awaitTermination" ) {
303327 ssc = new StreamingContext (master, appName, batchDuration)
304328 val inputStream = addInputStream(ssc)
@@ -811,3 +835,19 @@ package object testPackage extends Assertions {
811835 }
812836 }
813837}
838+
839+ /**
840+ * Helper methods for testing StreamingContextSuite
841+ * This includes methods to access private methods and fields in StreamingContext and MetricsSystem
842+ */
843+
844+ private object StreamingContextSuite extends PrivateMethodTester {
845+ private val _sources = PrivateMethod [ArrayBuffer [Source ]](' sources )
846+ private def getSources (metricsSystem : MetricsSystem ): ArrayBuffer [Source ] = {
847+ metricsSystem.invokePrivate(_sources())
848+ }
849+ private val _streamingSource = PrivateMethod [StreamingSource ](' streamingSource )
850+ private def getStreamingSource (streamingContext : StreamingContext ): StreamingSource = {
851+ streamingContext.invokePrivate(_streamingSource())
852+ }
853+ }
0 commit comments