
Commit 3ae7ab8

Authored by Andrew Ray (aray) and committed by HyukjinKwon
[SPARK-21985][PYSPARK] PairDeserializer is broken for double-zipped RDDs
## What changes were proposed in this pull request?

(edited) Fixes a bug introduced in #16121. In PairDeserializer, convert each batch of keys and values to lists (if they do not have `__len__` already) so that we can check that they are the same size. Normally they already are lists, so this should not have a performance impact, but the conversion is needed when repeated `zip`s are done.

## How was this patch tested?

Additional unit test.

Author: Andrew Ray <[email protected]>

Closes #19226 from aray/SPARK-21985.

(cherry picked from commit 6adf67d)
Signed-off-by: hyukjinkwon <[email protected]>
1 parent e49c997 commit 3ae7ab8
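For context, a minimal standalone sketch (not Spark code; the `pair_batches` helper is hypothetical) of why double-zipping broke the size check: a nested PairDeserializer yields its batches as generators, and generators have no `__len__`, so the existing `len(key_batch) != len(val_batch)` comparison raised a `TypeError`:

```python
def pair_batches(key_batches, val_batches):
    # Mimics the shape of PairDeserializer._load_stream_without_unbatching:
    # each emitted batch is a lazy generator of (key, value) pairs, not a list.
    for kb, vb in zip(key_batches, val_batches):
        yield (pair for pair in zip(kb, vb))

batches = list(pair_batches([[1, 2]], [['a', 'b']]))
first = batches[0]
assert not hasattr(first, '__len__')  # len(first) would raise TypeError

# The normalization applied in this commit: materialize before len() checks.
first = first if hasattr(first, '__len__') else list(first)
assert len(first) == 2
assert first == [(1, 'a'), (2, 'b')]
```

Plain lists pass the `hasattr` test unchanged, which is why the common single-zip path pays no conversion cost.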

File tree

2 files changed: +17 −1 lines changed


python/pyspark/serializers.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -97,7 +97,7 @@ def load_stream(self, stream):

     def _load_stream_without_unbatching(self, stream):
         """
-        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        Return an iterator of deserialized batches (iterable) of objects from the input stream.
         if the serializer does not operate on batches the default implementation returns an
         iterator of single element lists.
         """
@@ -326,6 +326,10 @@ def _load_stream_without_unbatching(self, stream):
         key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
         val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
         for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # For double-zipped RDDs, the batches can be iterators from other PairDeserializer,
+            # instead of lists. We need to convert them to lists if needed.
+            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
+            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
             if len(key_batch) != len(val_batch):
                 raise ValueError("Can not deserialize PairRDD with different number of items"
                                  " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
```
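The patched logic can be exercised outside Spark; a hedged sketch (`check_pair_batch` is a hypothetical helper mirroring the lines added above) showing that lists and iterators are both accepted after normalization, while mismatched batch sizes still raise as before:

```python
def check_pair_batch(key_batch, val_batch):
    # Iterators (e.g. batches produced by a nested PairDeserializer) have no
    # __len__, so materialize them before comparing sizes.
    key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
    val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
    if len(key_batch) != len(val_batch):
        raise ValueError("Can not deserialize PairRDD with different number of items"
                         " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
    return list(zip(key_batch, val_batch))

# A plain list and an iterator both work after normalization:
assert check_pair_batch([1, 2], iter(['a', 'b'])) == [(1, 'a'), (2, 'b')]

# Mismatched batch sizes still raise ValueError, preserving the old check:
mismatch_detected = False
try:
    check_pair_batch([1, 2, 3], ['a'])
except ValueError:
    mismatch_detected = True
assert mismatch_detected
```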

python/pyspark/tests.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -579,6 +579,18 @@ def test_cartesian_chaining(self):
             set([(x, (y, y)) for x in range(10) for y in range(10)])
         )

+    def test_zip_chaining(self):
+        # Tests for SPARK-21985
+        rdd = self.sc.parallelize('abc', 2)
+        self.assertSetEqual(
+            set(rdd.zip(rdd).zip(rdd).collect()),
+            set([((x, x), x) for x in 'abc'])
+        )
+        self.assertSetEqual(
+            set(rdd.zip(rdd.zip(rdd)).collect()),
+            set([(x, (x, x)) for x in 'abc'])
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)
```
