
Commit e0173f1

aray authored and davies committed
[SPARK-16589] [PYTHON] Chained cartesian produces incorrect number of records
## What changes were proposed in this pull request?

Fixes a bug in the Python implementation of RDD cartesian product related to batching that showed up in repeated cartesian products with seemingly random results. The root cause was multiple iterators pulling from the same stream in the wrong order because of logic that ignored batching.

`CartesianDeserializer` and `PairDeserializer` were changed to implement `_load_stream_without_unbatching` and borrow the one-line implementation of `load_stream` from `BatchedSerializer`. The default implementation of `_load_stream_without_unbatching` was changed to give consistent results (always an iterable) so that it can be used without additional checks.

`PairDeserializer` no longer extends `CartesianDeserializer`, as that inheritance was never really proper; if desired, a new common superclass could be added. Both `CartesianDeserializer` and `PairDeserializer` now extend only `Serializer` (which has no `dump_stream` implementation), since they are meant for *de*serialization only.

## How was this patch tested?

Additional unit tests (sourced from #14248), plus one testing a cartesian with a zip.

Author: Andrew Ray <[email protected]>

Closes #16121 from aray/fix-cartesian.

(cherry picked from commit 3c68944)
Signed-off-by: Davies Liu <[email protected]>
1 parent 726217e commit e0173f1
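For context, the symptom is easy to reproduce. Below is a minimal sketch adapted from the regression test added in this commit; it assumes a live `SparkContext` bound to `sc`:

```python
# Minimal reproduction sketch, assuming a live SparkContext bound to `sc`.
# With batching in play (10 elements, 2 partitions), a chained cartesian
# product must yield exactly 10 * 10 * 10 = 1000 records; before this fix,
# misordered reads from the shared stream made the result effectively random.
rdd = sc.parallelize(range(10), 2)
result = rdd.cartesian(rdd).cartesian(rdd).collect()
assert len(result) == 1000
assert set(result) == set(((x, y), z)
                          for x in range(10)
                          for y in range(10)
                          for z in range(10))
```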

File tree

2 files changed: +53 −23 lines


python/pyspark/serializers.py

Lines changed: 35 additions & 23 deletions
```diff
@@ -61,7 +61,7 @@
 if sys.version < '3':
     import cPickle as pickle
     protocol = 2
-    from itertools import izip as zip
+    from itertools import izip as zip, imap as map
 else:
     import pickle
     protocol = 3
@@ -96,7 +96,12 @@ def load_stream(self, stream):
         raise NotImplementedError
 
     def _load_stream_without_unbatching(self, stream):
-        return self.load_stream(stream)
+        """
+        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        if the serializer does not operate on batches the default implementation returns an
+        iterator of single element lists.
+        """
+        return map(lambda x: [x], self.load_stream(stream))
 
     # Note: our notion of "equality" is that output generated by
     # equal serializers can be deserialized using the same serializer.
@@ -278,50 +283,57 @@ def __repr__(self):
         return "AutoBatchedSerializer(%s)" % self.serializer
 
 
-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian,
+    we additionally need to do the cartesian within each pair of batches.
     """
 
     def __init__(self, key_ser, val_ser):
-        FramedSerializer.__init__(self)
         self.key_ser = key_ser
         self.val_ser = val_ser
 
-    def prepare_keys_values(self, stream):
-        key_stream = self.key_ser._load_stream_without_unbatching(stream)
-        val_stream = self.val_ser._load_stream_without_unbatching(stream)
-        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
-        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
-        for (keys, vals) in zip(key_stream, val_stream):
-            keys = keys if key_is_batched else [keys]
-            vals = vals if val_is_batched else [vals]
-            yield (keys, vals)
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield product(key_batch, val_batch)
 
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            for pair in product(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def __repr__(self):
         return "CartesianDeserializer(%s, %s)" % \
             (str(self.key_ser), str(self.val_ser))
 
 
-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD zip() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD zip,
+    we additionally need to do the zip within each pair of batches.
     """
 
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            if len(key_batch) != len(val_batch):
+                raise ValueError("Can not deserialize PairRDD with different number of items"
+                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield zip(key_batch, val_batch)
+
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            if len(keys) != len(vals):
-                raise ValueError("Can not deserialize RDD with different number of items"
-                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
-            for pair in zip(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def __repr__(self):
         return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
```

python/pyspark/tests.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -548,6 +548,24 @@ def test_cartesian_on_textfile(self):
         self.assertEqual(u"Hello World!", x.strip())
         self.assertEqual(u"Hello World!", y.strip())
 
+    def test_cartesian_chaining(self):
+        # Tests for SPARK-16589
+        rdd = self.sc.parallelize(range(10), 2)
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+            set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+            set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.zip(rdd)).collect()),
+            set([(x, (y, y)) for x in range(10) for y in range(10)])
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)
```
