python/pyspark/serializers.py (58 changes: 35 additions & 23 deletions)
@@ -61,7 +61,7 @@
 if sys.version < '3':
     import cPickle as pickle
     protocol = 2
-    from itertools import izip as zip
+    from itertools import izip as zip, imap as map
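
A note on the new `imap as map` import: on Python 2 the builtin `map` is eager and returns a list, while `itertools.imap` is lazy like Python 3's `map`. The default `_load_stream_without_unbatching` below maps over a stream of unknown (possibly unbounded) length, so it needs the lazy behaviour on both Python versions. A minimal sketch of the distinction (runnable on Python 3; the Python 2 behaviour is stated in the comments):

```python
from itertools import count

# A stand-in for an endless deserialized-object stream.
stream = count()

# map() must be lazy here: with Python 2's eager builtin map this line
# would never return on an endless stream. itertools.imap (aliased to
# map by this patch on Python 2) and Python 3's map both build the
# batches on demand.
batches = map(lambda x: [x], stream)

print(next(batches), next(batches))  # [0] [1]
```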
@@ -96,7 +96,12 @@ def load_stream(self, stream):
         raise NotImplementedError

     def _load_stream_without_unbatching(self, stream):

[Review comment from a Contributor] Even though this is internal, it might make sense to have a docstring for this, since we're changing its behaviour.

-        return self.load_stream(stream)
+        """
+        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        If the serializer does not operate on batches the default implementation returns an
+        iterator of single-element lists.
+        """
+        return map(lambda x: [x], self.load_stream(stream))

     # Note: our notion of "equality" is that output generated by
     # equal serializers can be deserialized using the same serializer.
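
To make the new default concrete: every serializer now exposes the stream as an iterator of batches, and a serializer that yields individual objects reports each object as a one-element batch. A minimal self-contained sketch of that contract (the toy serializer below is hypothetical, not part of the patch):

```python
class ToyUnbatchedSerializer(object):
    """Hypothetical serializer that deserializes to individual objects."""

    def load_stream(self, stream):
        # Pretend the stream deserializes to three separate objects.
        for obj in ["a", "b", "c"]:
            yield obj

    def _load_stream_without_unbatching(self, stream):
        # The patched default: wrap each object in a single-element batch.
        return map(lambda x: [x], self.load_stream(stream))


ser = ToyUnbatchedSerializer()
print(list(ser._load_stream_without_unbatching(None)))  # [['a'], ['b'], ['c']]
```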
@@ -278,50 +283,57 @@ def __repr__(self):
         return "AutoBatchedSerializer(%s)" % self.serializer


-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):

     """
     Deserializes the JavaRDD cartesian() of two PythonRDDs.

[Review comment from a Contributor] Maybe we should document this a bit, given that we had problems with the implementation (e.g. expand on the "Due to batching, we can't use the Java cartesian method." comment from rdd.py to explain how this is intended to function).

+    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian,
+    we additionally need to do the cartesian within each pair of batches.
     """

     def __init__(self, key_ser, val_ser):
-        FramedSerializer.__init__(self)
         self.key_ser = key_ser
         self.val_ser = val_ser

-    def prepare_keys_values(self, stream):
-        key_stream = self.key_ser._load_stream_without_unbatching(stream)
-        val_stream = self.val_ser._load_stream_without_unbatching(stream)
-        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
-        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
-        for (keys, vals) in zip(key_stream, val_stream):
-            keys = keys if key_is_batched else [keys]
-            vals = vals if val_is_batched else [vals]
-            yield (keys, vals)
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield product(key_batch, val_batch)

[Review comment from a Contributor] Maybe consider adding a comment here explaining the interaction of batching & product.

     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            for pair in product(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))

     def __repr__(self):
         return "CartesianDeserializer(%s, %s)" % \
             (str(self.key_ser), str(self.val_ser))
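
To illustrate why CartesianDeserializer pairs up batches and takes the product within each pair, here is a hedged standalone sketch with fabricated batch lists in place of real stream input. Yielding each product as a single batch keeps batch boundaries aligned, which is what lets a chained cartesian or zip line this stream up against another one batch-for-batch (the SPARK-16589 failure mode):

```python
from itertools import chain, product

# Fabricated stand-ins for the key/value batch streams of a cartesian RDD:
# batch i of keys lines up with batch i of values.
key_batches = [[1, 2], [3]]
val_batches = [["a"], ["b", "c"]]

def load_stream_without_unbatching(key_batches, val_batches):
    for key_batch, val_batch in zip(key_batches, val_batches):
        # One product per batch pair, returned as a single batch so that a
        # later cartesian/zip over this stream still sees aligned batches.
        yield product(key_batch, val_batch)

def load_stream(key_batches, val_batches):
    return chain.from_iterable(
        load_stream_without_unbatching(key_batches, val_batches))

print(list(load_stream(key_batches, val_batches)))
# [(1, 'a'), (2, 'a'), (3, 'b'), (3, 'c')]
```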


-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):

     """
     Deserializes the JavaRDD zip() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD zip,
+    we additionally need to do the zip within each pair of batches.
     """

+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            if len(key_batch) != len(val_batch):
+                raise ValueError("Can not deserialize PairRDD with different number of items"
+                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield zip(key_batch, val_batch)

     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            if len(keys) != len(vals):
-                raise ValueError("Can not deserialize RDD with different number of items"
-                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
-            for pair in zip(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))

     def __repr__(self):
         return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
python/pyspark/tests.py (18 changes: 18 additions & 0 deletions)
@@ -548,6 +548,24 @@ def test_cartesian_on_textfile(self):
         self.assertEqual(u"Hello World!", x.strip())
         self.assertEqual(u"Hello World!", y.strip())

+    def test_cartesian_chaining(self):
+        # Tests for SPARK-16589
+        rdd = self.sc.parallelize(range(10), 2)
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+            set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+            set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.zip(rdd)).collect()),
+            set([(x, (y, y)) for x in range(10) for y in range(10)])
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)