
Commit 54bd92b

improve tests
1 parent c2b31cb commit 54bd92b

5 files changed, +151 -108 lines changed


python/pyspark/streaming/context.py

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@ def _ensure_initialized(cls):
         # it happens before creating SparkContext when loading from checkpointing
         cls._transformerSerializer = TransformFunctionSerializer(
             SparkContext._active_spark_context, CloudPickleSerializer(), gw)
-        gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)

     @classmethod
     def getOrCreate(cls, path, setupFunc):
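
Note: the registerSerializer call removed here is not dropped from the code base; it moves into TransformFunctionSerializer.__init__ (see the util.py hunk further down), so the serializer registers itself when it is constructed. A paraphrased sketch of the resulting flow, not verbatim source:

    class TransformFunctionSerializer(object):
        def __init__(self, ctx, serializer, gateway=None):
            self.ctx = ctx
            self.serializer = serializer
            self.gateway = gateway or self.ctx._gateway
            # registration now happens as a side effect of construction,
            # so _ensure_initialized no longer needs a separate call
            self.gateway.jvm.PythonDStream.registerSerializer(self)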

python/pyspark/streaming/dstream.py

Lines changed: 21 additions & 8 deletions
@@ -20,6 +20,8 @@
 import time
 from datetime import datetime

+from py4j.protocol import Py4JJavaError
+
 from pyspark import RDD
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.util import rddToFileName, TransformFunction
@@ -249,19 +251,31 @@ def saveAsTextFiles(self, prefix, suffix=None):
         Save each RDD in this DStream as at text file, using string
         representation of elements.
         """
-        def saveAsTextFile(time, rdd):
-            path = rddToFileName(prefix, suffix, time)
-            rdd.saveAsTextFile(path)
+        def saveAsTextFile(t, rdd):
+            path = rddToFileName(prefix, suffix, t)
+            try:
+                rdd.saveAsTextFile(path)
+            except Py4JJavaError as e:
+                # after recovered from checkpointing, the foreachRDD may
+                # be called twice
+                if 'FileAlreadyExistsException' not in str(e):
+                    raise
         return self.foreachRDD(saveAsTextFile)

     def _saveAsPickleFiles(self, prefix, suffix=None):
         """
         Save each RDD in this DStream as at binary file, the elements are
         serialized by pickle.
         """
-        def saveAsPickleFile(time, rdd):
-            path = rddToFileName(prefix, suffix, time)
-            rdd.saveAsPickleFile(path)
+        def saveAsPickleFile(t, rdd):
+            path = rddToFileName(prefix, suffix, t)
+            try:
+                rdd.saveAsPickleFile(path)
+            except Py4JJavaError as e:
+                # after recovered from checkpointing, the foreachRDD may
+                # be called twice
+                if 'FileAlreadyExistsException' not in str(e):
+                    raise
         return self.foreachRDD(saveAsPickleFile)

     def transform(self, func):
@@ -608,8 +622,7 @@ def _jdstream(self):
         if self._jdstream_val is not None:
             return self._jdstream_val

-        func = self.func
-        jfunc = TransformFunction(self.ctx, func, self.prev._jrdd_deserializer)
+        jfunc = TransformFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
         jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
                                                           jfunc, self.reuse).asJavaDStream()
         self._jdstream_val = jdstream
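
The try/except added above makes the save actions idempotent: after recovering from a checkpoint the same batch can be emitted again, so an output directory that already exists for that batch time is treated as already written. A standalone sketch of the same pattern (the helper name and call site are illustrative, not part of this commit):

    from py4j.protocol import Py4JJavaError

    def save_batch_once(rdd, path):
        # illustrative helper: swallow only the "directory already exists"
        # failure that a replayed batch can trigger, re-raise anything else
        try:
            rdd.saveAsTextFile(path)
        except Py4JJavaError as e:
            if 'FileAlreadyExistsException' not in str(e):
                raise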

python/pyspark/streaming/tests.py

Lines changed: 85 additions & 65 deletions
@@ -42,6 +42,13 @@ def setUp(self):
     def tearDown(self):
         self.ssc.stop()

+    def wait_for(self, result, n):
+        start_time = time.time()
+        while len(result) < n and time.time() - start_time < self.timeout:
+            time.sleep(0.01)
+        if len(result) < n:
+            print "timeout after", self.timeout
+
     def _take(self, dstream, n):
         """
         Return the first `n` elements in the stream (will start and stop).
@@ -55,12 +62,10 @@ def take(_, rdd):
         dstream.foreachRDD(take)

         self.ssc.start()
-        while len(results) < n:
-            time.sleep(0.01)
-        self.ssc.stop(False, True)
+        self.wait_for(results, n)
         return results

-    def _collect(self, dstream):
+    def _collect(self, dstream, n, block=True):
         """
         Collect each RDDs into the returned list.

@@ -69,10 +74,18 @@ def _collect(self, dstream):
         result = []

         def get_output(_, rdd):
-            r = rdd.collect()
-            if r:
-                result.append(r)
+            if rdd and len(result) < n:
+                r = rdd.collect()
+                if r:
+                    result.append(r)
+
         dstream.foreachRDD(get_output)
+
+        if not block:
+            return result
+
+        self.ssc.start()
+        self.wait_for(result, n)
         return result

     def _test_func(self, input, func, expected, sort=False, input2=None):
@@ -94,23 +107,7 @@ def _test_func(self, input, func, expected, sort=False, input2=None):
         else:
             stream = func(input_stream)

-        result = self._collect(stream)
-        self.ssc.start()
-
-        start_time = time.time()
-        # Loop until get the expected the number of the result from the stream.
-        while True:
-            current_time = time.time()
-            # Check time out.
-            if (current_time - start_time) > self.timeout:
-                print "timeout after", self.timeout
-                break
-            # StreamingContext.awaitTermination is not used to wait because
-            # if py4j server is called every 50 milliseconds, it gets an error.
-            time.sleep(0.05)
-            # Check if the output is the same length of expected output.
-            if len(expected) == len(result):
-                break
+        result = self._collect(stream, len(expected))
         if sort:
             self._sort_result_based_on_key(result)
             self._sort_result_based_on_key(expected)
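
The per-test sleep/poll loops above are consolidated into the shared wait_for helper, which polls until enough output has arrived or the class-level timeout expires. A rough standalone equivalent (the timeout parameter here is an assumption; the test class reads it from self.timeout):

    import time

    def wait_for(result, n, timeout=10):
        # poll until `result` holds at least `n` items or the deadline passes
        deadline = time.time() + timeout
        while len(result) < n and time.time() < deadline:
            time.sleep(0.01)
        return len(result) >= n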
@@ -424,55 +421,50 @@ class TestStreamingContext(PySparkStreamingTestCase):

     duration = 0.1

+    def _add_input_stream(self):
+        inputs = map(lambda x: range(1, x), range(101))
+        stream = self.ssc.queueStream(inputs)
+        self._collect(stream, 1, block=False)
+
     def test_stop_only_streaming_context(self):
-        self._addInputStream()
+        self._add_input_stream()
         self.ssc.start()
         self.ssc.stop(False)
         self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

     def test_stop_multiple_times(self):
-        self._addInputStream()
+        self._add_input_stream()
         self.ssc.start()
         self.ssc.stop()
         self.ssc.stop()

-    def _addInputStream(self):
-        # Make sure each length of input is over 3
-        inputs = map(lambda x: range(1, x), range(5, 101))
-        stream = self.ssc.queueStream(inputs)
-        self._collect(stream)
-
-    def test_queueStream(self):
-        input = [range(i) for i in range(3)]
+    def test_queue_stream(self):
+        input = [range(i + 1) for i in range(3)]
         dstream = self.ssc.queueStream(input)
-        result = self._collect(dstream)
-        self.ssc.start()
-        time.sleep(1)
-        self.assertEqual(input, result[:3])
+        result = self._collect(dstream, 3)
+        self.assertEqual(input, result)

-    def test_textFileStream(self):
+    def test_text_file_stream(self):
         d = tempfile.mkdtemp()
         self.ssc = StreamingContext(self.sc, self.duration)
         dstream2 = self.ssc.textFileStream(d).map(int)
-        result = self._collect(dstream2)
+        result = self._collect(dstream2, 2, block=False)
         self.ssc.start()
-        time.sleep(1)
         for name in ('a', 'b'):
+            time.sleep(1)
             with open(os.path.join(d, name), "w") as f:
                 f.writelines(["%d\n" % i for i in range(10)])
-        time.sleep(2)
-        self.assertEqual([range(10) * 2], result[:3])
+        self.wait_for(result, 2)
+        self.assertEqual([range(10), range(10)], result)

     def test_union(self):
-        input = [range(i) for i in range(3)]
+        input = [range(i + 1) for i in range(3)]
         dstream = self.ssc.queueStream(input)
         dstream2 = self.ssc.queueStream(input)
         dstream3 = self.ssc.union(dstream, dstream2)
-        result = self._collect(dstream3)
-        self.ssc.start()
-        time.sleep(1)
+        result = self._collect(dstream3, 3)
         expected = [i * 2 for i in input]
-        self.assertEqual(expected, result[:3])
+        self.assertEqual(expected, result)

     def test_transform(self):
         dstream1 = self.ssc.queueStream([[1]])
@@ -497,34 +489,62 @@ def tearDown(self):
         pass

     def test_get_or_create(self):
-        result = [0]
         inputd = tempfile.mkdtemp()
+        outputd = tempfile.mkdtemp() + "/"
+
+        def updater(it):
+            for k, vs, s in it:
+                yield (k, sum(vs, s or 0))

         def setup():
             conf = SparkConf().set("spark.default.parallelism", 1)
             sc = SparkContext(conf=conf)
-            ssc = StreamingContext(sc, .2)
-            dstream = ssc.textFileStream(inputd)
-            result[0] = self._collect(dstream.count())
+            ssc = StreamingContext(sc, 0.2)
+            dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
+            wc = dstream.updateStateByKey(updater)
+            wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
+            wc.checkpoint(.2)
             return ssc

-        tmpd = tempfile.mkdtemp("test_streaming_cps")
-        ssc = StreamingContext.getOrCreate(tmpd, setup)
+        cpd = tempfile.mkdtemp("test_streaming_cps")
+        ssc = StreamingContext.getOrCreate(cpd, setup)
         ssc.start()
-        time.sleep(1)
-        with open(os.path.join(inputd, "1"), 'w') as f:
-            f.writelines(["%d\n" % i for i in range(10)])
-        ssc.awaitTermination(4)
+
+        def check_output(n):
+            while not os.listdir(outputd):
+                time.sleep(0.1)
+            time.sleep(1)  # make sure mtime is larger than the previous one
+            with open(os.path.join(inputd, str(n)), 'w') as f:
+                f.writelines(["%d\n" % i for i in range(10)])
+
+            while True:
+                p = os.path.join(outputd, max(os.listdir(outputd)))
+                if '_SUCCESS' not in os.listdir(p):
+                    # not finished
+                    time.sleep(0.01)
+                    continue
+                ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
+                d = ordd.values().map(int).collect()
+                if not d:
+                    time.sleep(0.01)
+                    continue
+                self.assertEqual(10, len(d))
+                s = set(d)
+                self.assertEqual(1, len(s))
+                m = s.pop()
+                if n > m:
+                    continue
+                self.assertEqual(n, m)
+                break
+
+        check_output(1)
+        check_output(2)
         ssc.stop(True, True)
-        expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5
-        self.assertEqual([[10]], result[0][:1])

-        ssc = StreamingContext.getOrCreate(tmpd, setup)
-        ssc.start()
         time.sleep(1)
-        with open(os.path.join(inputd, "1"), 'w') as f:
-            f.writelines(["%d\n" % i for i in range(10)])
-        ssc.awaitTermination(2)
+        ssc = StreamingContext.getOrCreate(cpd, setup)
+        ssc.start()
+        check_output(3)
         ssc.stop(True, True)
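
For reference, the updater passed to updateStateByKey above keeps a running count per key: it consumes (key, new_values, previous_state) tuples and sums them. A tiny illustration with made-up data, not part of the commit:

    def updater(it):
        for k, vs, s in it:
            yield (k, sum(vs, s or 0))

    batch = [("a", [1, 1], None), ("b", [1], 3)]
    print list(updater(batch))  # [('a', 2), ('b', 4)]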

python/pyspark/streaming/util.py

Lines changed: 12 additions & 11 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 #

+import time
 from datetime import datetime
 import traceback

@@ -32,23 +33,20 @@ def __init__(self, ctx, func, *deserializers):
         self.func = func
         self.deserializers = deserializers

-    @property
-    def emptyRDD(self):
-        if self._emptyRDD is None and self.ctx:
-            self._emptyRDD = self.ctx.parallelize([]).cache()
-        return self._emptyRDD
-
     def call(self, milliseconds, jrdds):
         try:
             if self.ctx is None:
                 self.ctx = SparkContext._active_spark_context
+            if not self.ctx or not self.ctx._jsc:
+                # stopped
+                return

             # extend deserializers with the first one
             sers = self.deserializers
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))

-            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD
+            rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
@@ -69,6 +67,7 @@ def __init__(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
         self.gateway = gateway or self.ctx._gateway
+        self.gateway.jvm.PythonDStream.registerSerializer(self)

     def dumps(self, id):
         try:
@@ -91,20 +90,22 @@ class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']


-def rddToFileName(prefix, suffix, time):
+def rddToFileName(prefix, suffix, timestamp):
     """
     Return string prefix-time(.suffix)

     >>> rddToFileName("spark", None, 12345678910)
     'spark-12345678910'
     >>> rddToFileName("spark", "tmp", 12345678910)
     'spark-12345678910.tmp'
-
     """
+    if isinstance(timestamp, datetime):
+        seconds = time.mktime(timestamp.timetuple())
+        timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
     if suffix is None:
-        return prefix + "-" + str(time)
+        return prefix + "-" + str(timestamp)
     else:
-        return prefix + "-" + str(time) + "." + suffix
+        return prefix + "-" + str(timestamp) + "." + suffix


 if __name__ == "__main__":
if __name__ == "__main__":
