Skip to content

Commit 4a9eb93

Browse files
committed
remove check and add tests
1 parent 8c7e19a commit 4a9eb93

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

python/pyspark/serializers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ class PairDeserializer(Serializer):
333333
Deserializes the JavaRDD zip() of two PythonRDDs.
334334
Due to pyspark batching we cannot simply use the result of the Java RDD zip,
335335
we additionally need to do the zip within each pair of batches.
336+
337+
It is the responsibility of the user of this class to ensure the batch sizes of the key and
338+
value serializer are the same size. If they are not this will give incorrect results.
336339
"""
337340

338341
def __init__(self, key_ser, val_ser):
@@ -343,9 +346,6 @@ def _load_stream_without_unbatching(self, stream):
343346
key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
344347
val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
345348
for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
346-
if len(key_batch) != len(val_batch):
347-
raise ValueError("Can not deserialize PairRDD with different number of items"
348-
" in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
349349
# for correctness with repeated cartesian/zip this must be returned as one batch
350350
yield zip(key_batch, val_batch)
351351

python/pyspark/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,18 @@ def test_cartesian_chaining(self):
644644
set([(x, (y, y)) for x in range(10) for y in range(10)])
645645
)
646646

647+
def test_zip_chaining(self):
648+
# Tests for SPARK-21985
649+
rdd = self.sc.parallelize(range(10), 2)
650+
self.assertSetEqual(
651+
set(rdd.zip(rdd).zip(rdd).collect()),
652+
set([((x, x), x) for x in range(10)])
653+
)
654+
self.assertSetEqual(
655+
set(rdd.zip(rdd.zip(rdd)).collect()),
656+
set([((x, (x, x)) for x in range(10)])
657+
)
658+
647659
def test_deleting_input_files(self):
648660
# Regression test for SPARK-1025
649661
tempFile = tempfile.NamedTemporaryFile(delete=False)

0 commit comments

Comments
 (0)