@@ -469,13 +469,18 @@ def setUp(self):
         self.batachDuration = Milliseconds(500)
         self.sparkHome = "SomeDir"
         self.envPair = {"key": "value"}
+        self.ssc = None
+        self.sc = None

     def tearDown(self):
         # Do not call pyspark.streaming.context.StreamingContext.stop directly because
         # we do not wait for the py4j client to shut down.
         # We need to change this to simply call StreamingContext.stop().
-        self.ssc._jssc.stop()
-        self.ssc._sc.stop()
+        # self.ssc._jssc.stop()
+        if self.ssc is not None:
+            self.ssc.stop()
+        if self.sc is not None:
+            self.sc.stop()
         # Why does it take a long time to terminate StreamingContext and SparkContext?
         # Should we change the sleep time if this depends on machine spec?
         time.sleep(1)
@@ -486,48 +491,67 @@ def tearDownClass(cls):
         SparkContext._gateway._shutdown_callback_server()

     def test_from_no_conf_constructor(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName, duration=batachDuration)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
         # Alternative way to get the master: ssc.sparkContext.master
         # I try to make the code close to Scala.
-        self.assertEqual(ssc.sparkContext._conf.get("spark.master"), self.master)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.app.name"), self.appName)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.master"), self.master)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"), self.appName)

     def test_from_no_conf_plus_spark_home(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName,
-                               sparkHome=self.sparkHome, duration=batachDuration)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+
+    def test_from_no_conf_plus_spark_home_plus_env(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, environment=self.envPair,
+                                    duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.executorEnv.key"), self.envPair["key"])

     def test_from_existing_spark_context(self):
-        sc = SparkContext(master=self.master, appName=self.appName)
-        ssc = StreamingContext(sparkContext=sc)
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)

     def test_existing_spark_context_with_settings(self):
         conf = SparkConf()
         conf.set("spark.cleaner.ttl", "10")
-        sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
-        ssc = StreamingContext(context=sc)
-        self.assertEqual(int(ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
-
-    def _addInputStream(self, s):
-        test_inputs = map(lambda x: range(1, x), range(5, 101))
-        # make sure numSlice is 2 due to a deserializer problem in pyspark
-        s._testInputStream(test_inputs, 2)
-
-    def test_from_no_conf_plus_spark_home_plus_env(self):
-        pass
+        self.sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

     def test_from_conf_with_settings(self):
-        pass
+        conf = SparkConf()
+        conf.set("spark.cleaner.ttl", "10")
+        conf.setMaster(self.master)
+        conf.setAppName(self.appName)
+        self.ssc = StreamingContext(conf=conf, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

     def test_stop_only_streaming_context(self):
-        pass
-
-    def test_await_termination(self):
-        pass
-
-
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop(False)
+        self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

+    def test_stop_multiple_times(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop()
+        self.ssc.stop()

+    def _addInputStream(self, s):
+        # Make sure the length of each input is over 3 and
+        # numSlice is 2, due to a deserializer problem in pyspark.streaming
+        test_inputs = map(lambda x: range(1, x), range(5, 101))
+        test_stream = s._testInputStream(test_inputs, 2)
+        # Register a fake output operation
+        result = list()
+        test_stream._test_output(result)

 if __name__ == "__main__":
     unittest.main()
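
For reference, the context lifecycle these tests converge on (construct a StreamingContext one of several ways, then guard teardown with None checks) can be exercised outside unittest. The following is a minimal sketch, not part of the commit, assuming the in-progress PySpark Streaming API used above: a duration= keyword taking Milliseconds(...), StreamingContext.stop(stopSparkContext) to keep the underlying SparkContext alive, and the pyspark.streaming.duration module path for Milliseconds (an assumption).

    from pyspark import SparkContext
    from pyspark.streaming.context import StreamingContext
    from pyspark.streaming.duration import Milliseconds  # assumed module path

    ssc = None
    sc = None
    try:
        # Construct directly from master/appName, as in test_from_no_conf_constructor.
        ssc = StreamingContext(master="local[2]", appName="lifecycle-sketch",
                               duration=Milliseconds(500))
        sc = ssc.sparkContext
        # Stop only the streaming side, as in test_stop_only_streaming_context;
        # the SparkContext should remain usable afterwards.
        ssc.stop(False)
        assert len(sc.parallelize(range(5), 5).glom().collect()) == 5
    finally:
        # Guarded teardown mirroring tearDown(); stop() is expected to be safe
        # to call more than once (see test_stop_multiple_times).
        if ssc is not None:
            ssc.stop()
        if sc is not None:
            sc.stop()

Initializing both handles to None and checking before stopping means a test that fails mid-construction still tears down cleanly, which is exactly what the new setUp/tearDown pair above buys the suite.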