@@ -42,6 +42,13 @@ def setUp(self):
def tearDown(self):
    """Stop the StreamingContext after each test so contexts don't leak."""
    self.ssc.stop()
4444
def wait_for(self, result, n):
    """Poll until `result` holds at least `n` items or `self.timeout` elapses.

    Polls every 10 ms. On timeout it prints a diagnostic instead of
    raising, so the calling test's own assertion reports the mismatch.
    """
    start_time = time.time()
    while len(result) < n and time.time() - start_time < self.timeout:
        time.sleep(0.01)
    if len(result) < n:
        # Single pre-formatted string so the print behaves identically on
        # Python 2 and Python 3 (the original py2 print statement is a
        # syntax error on py3).
        print("timeout after %s" % self.timeout)
51+
4552 def _take (self , dstream , n ):
4653 """
4754 Return the first `n` elements in the stream (will start and stop).
@@ -55,12 +62,10 @@ def take(_, rdd):
5562 dstream .foreachRDD (take )
5663
5764 self .ssc .start ()
58- while len (results ) < n :
59- time .sleep (0.01 )
60- self .ssc .stop (False , True )
65+ self .wait_for (results , n )
6166 return results
6267
63- def _collect (self , dstream ):
68+ def _collect (self , dstream , n , block = True ):
6469 """
6570 Collect each RDDs into the returned list.
6671
@@ -69,10 +74,18 @@ def _collect(self, dstream):
6974 result = []
7075
7176 def get_output (_ , rdd ):
72- r = rdd .collect ()
73- if r :
74- result .append (r )
77+ if rdd and len (result ) < n :
78+ r = rdd .collect ()
79+ if r :
80+ result .append (r )
81+
7582 dstream .foreachRDD (get_output )
83+
84+ if not block :
85+ return result
86+
87+ self .ssc .start ()
88+ self .wait_for (result , n )
7689 return result
7790
7891 def _test_func (self , input , func , expected , sort = False , input2 = None ):
@@ -94,23 +107,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
94107 else :
95108 stream = func (input_stream )
96109
97- result = self ._collect (stream )
98- self .ssc .start ()
99-
100- start_time = time .time ()
101- # Loop until get the expected the number of the result from the stream.
102- while True :
103- current_time = time .time ()
104- # Check time out.
105- if (current_time - start_time ) > self .timeout :
106- print "timeout after" , self .timeout
107- break
108- # StreamingContext.awaitTermination is not used to wait because
109- # if py4j server is called every 50 milliseconds, it gets an error.
110- time .sleep (0.05 )
111- # Check if the output is the same length of expected output.
112- if len (expected ) == len (result ):
113- break
110+ result = self ._collect (stream , len (expected ))
114111 if sort :
115112 self ._sort_result_based_on_key (result )
116113 self ._sort_result_based_on_key (expected )
@@ -424,55 +421,50 @@ class TestStreamingContext(PySparkStreamingTestCase):
424421
425422 duration = 0.1
426423
424+ def _add_input_stream (self ):
425+ inputs = map (lambda x : range (1 , x ), range (101 ))
426+ stream = self .ssc .queueStream (inputs )
427+ self ._collect (stream , 1 , block = False )
428+
def test_stop_only_streaming_context(self):
    """Stopping only the streaming context must leave the SparkContext usable."""
    self._add_input_stream()
    self.ssc.start()
    self.ssc.stop(False)
    # The underlying SparkContext should still run jobs: 5 partitions back.
    glommed = self.sc.parallelize(range(5), 5).glom().collect()
    self.assertEqual(len(glommed), 5)
432434
def test_stop_multiple_times(self):
    """A second stop() on an already-stopped context must not raise."""
    self._add_input_stream()
    self.ssc.start()
    self.ssc.stop()
    self.ssc.stop()
438440
439- def _addInputStream (self ):
440- # Make sure each length of input is over 3
441- inputs = map (lambda x : range (1 , x ), range (5 , 101 ))
442- stream = self .ssc .queueStream (inputs )
443- self ._collect (stream )
444-
445- def test_queueStream (self ):
446- input = [range (i ) for i in range (3 )]
def test_queue_stream(self):
    """Each queued batch should come back as exactly one collected RDD."""
    batches = [range(i + 1) for i in range(3)]
    dstream = self.ssc.queueStream(batches)
    collected = self._collect(dstream, 3)
    self.assertEqual(batches, collected)
452446
def test_text_file_stream(self):
    """Files written into a watched directory should arrive as batches."""
    watched_dir = tempfile.mkdtemp()
    self.ssc = StreamingContext(self.sc, self.duration)
    int_stream = self.ssc.textFileStream(watched_dir).map(int)
    result = self._collect(int_stream, 2, block=False)
    self.ssc.start()
    for name in ('a', 'b'):
        # Sleep first so each file's mtime is newer than the stream's
        # start (and than the previous file), otherwise it is ignored.
        time.sleep(1)
        with open(os.path.join(watched_dir, name), "w") as f:
            f.writelines(["%d\n" % i for i in range(10)])
    self.wait_for(result, 2)
    self.assertEqual([range(10), range(10)], result)
465459
def test_union(self):
    """Union of two identical queue streams doubles each batch."""
    batches = [range(i + 1) for i in range(3)]
    stream_a = self.ssc.queueStream(batches)
    stream_b = self.ssc.queueStream(batches)
    unioned = self.ssc.union(stream_a, stream_b)
    collected = self._collect(unioned, 3)
    expected = [batch * 2 for batch in batches]
    self.assertEqual(expected, collected)
476468
477469 def test_transform (self ):
478470 dstream1 = self .ssc .queueStream ([[1 ]])
@@ -497,34 +489,62 @@ def tearDown(self):
497489 pass
498490
def test_get_or_create(self):
    """Checkpointed word-count state must survive a getOrCreate restart."""
    inputd = tempfile.mkdtemp()
    outputd = tempfile.mkdtemp() + "/"

    def updater(it):
        # updateStateByKey reducer: running sum per key (state starts None).
        for k, vs, s in it:
            yield (k, sum(vs, s or 0))

    def setup():
        # Fresh context: watch inputd, keep a running count per line,
        # write "key,count" files under outputd, checkpoint every 0.2s.
        conf = SparkConf().set("spark.default.parallelism", 1)
        sc = SparkContext(conf=conf)
        ssc = StreamingContext(sc, 0.2)
        dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
        wc = dstream.updateStateByKey(updater)
        wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
        wc.checkpoint(0.2)
        return ssc

    cpd = tempfile.mkdtemp("test_streaming_cps")
    ssc = StreamingContext.getOrCreate(cpd, setup)
    ssc.start()

    def check_output(n):
        # Wait for the stream to produce its first output directory.
        while not os.listdir(outputd):
            time.sleep(0.1)
        time.sleep(1)  # make sure mtime is larger than the previous one
        with open(os.path.join(inputd, str(n)), 'w') as f:
            f.writelines(["%d\n" % i for i in range(10)])

        # Poll the newest output directory until it holds the expected
        # counts: 10 identical values, each equal to n.
        while True:
            latest = os.path.join(outputd, max(os.listdir(outputd)))
            if '_SUCCESS' not in os.listdir(latest):
                # not finished
                time.sleep(0.01)
                continue
            pairs_rdd = ssc.sparkContext.textFile(latest).map(
                lambda line: line.split(","))
            counts = pairs_rdd.values().map(int).collect()
            if not counts:
                time.sleep(0.01)
                continue
            self.assertEqual(10, len(counts))
            distinct = set(counts)
            self.assertEqual(1, len(distinct))
            value = distinct.pop()
            if n > value:
                # Output not caught up to this round yet; keep polling.
                continue
            self.assertEqual(n, value)
            break

    check_output(1)
    check_output(2)
    ssc.stop(True, True)

    time.sleep(1)
    # Recover from the checkpoint: counting must resume at 3, not 1.
    ssc = StreamingContext.getOrCreate(cpd, setup)
    ssc.start()
    check_output(3)
    ssc.stop(True, True)
529549
530550
0 commit comments