Skip to content

Commit 0703b67

Browse files
committed
Returns an iterator of lists
1 parent 17edfec commit 0703b67

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

python/pyspark/serializers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,9 @@ def _load_stream_without_unbatching(self, stream):
316316
key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
317317
val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
318318
for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
319-
# for correctness with repeated cartesian/zip this must be returned as one batch
320-
yield product(key_batch, val_batch)
319+
# for correctness with repeated cartesian/zip this must be returned as
320+
# one batch (a list)
321+
yield list(product(key_batch, val_batch))
321322

322323
def load_stream(self, stream):
323324
return chain.from_iterable(self._load_stream_without_unbatching(stream))
@@ -346,8 +347,9 @@ def _load_stream_without_unbatching(self, stream):
346347
if len(key_batch) != len(val_batch):
347348
raise ValueError("Can not deserialize PairRDD with different number of items"
348349
" in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
349-
# for correctness with repeated cartesian/zip this must be returned as one batch
350-
yield zip(key_batch, val_batch)
350+
# for correctness with repeated cartesian/zip this must be returned as
351+
# one batch (a list)
352+
yield list(zip(key_batch, val_batch))
351353

352354
def load_stream(self, stream):
353355
return chain.from_iterable(self._load_stream_without_unbatching(stream))

python/pyspark/tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,19 @@ def test_cartesian_chaining(self):
644644
set([(x, (y, y)) for x in range(10) for y in range(10)])
645645
)
646646

647+
def test_zip_chaining(self):
648+
# Tests for SPARK-21985
649+
rdd = self.sc.parallelize(range(3), 2)
650+
self.assertSetEqual(
651+
set(rdd.zip(rdd).zip(rdd).collect()),
652+
set(zip(zip(range(3), range(3)), range(3)))
653+
)
654+
655+
self.assertSetEqual(
656+
set(rdd.zip(rdd.zip(rdd)).collect()),
657+
set(zip(range(3), zip(range(3), range(3))))
658+
)
659+
647660
def test_deleting_input_files(self):
648661
# Regression test for SPARK-1025
649662
tempFile = tempfile.NamedTemporaryFile(delete=False)

0 commit comments

Comments (0)