|
# Pick the pickle implementation and protocol for the running Python.
# Python 2 needs the C implementation (cPickle) and the iterator versions
# of zip/map imported explicitly; on Python 3 the stdlib pickle is already
# the accelerated implementation and zip/map are lazy builtins.
# NOTE: compare sys.version_info (a tuple), not the string sys.version —
# lexicographic string comparison breaks for two-digit major versions.
if sys.version_info < (3,):
    import cPickle as pickle
    protocol = 2
    from itertools import izip as zip, imap as map
else:
    import pickle
    # Protocol 3 (bytes support) is available on every Python 3.x.
    protocol = 3
@@ -96,7 +96,12 @@ def load_stream(self, stream): |
96 | 96 | raise NotImplementedError |
97 | 97 |
|
98 | 98 | def _load_stream_without_unbatching(self, stream): |
99 | | - return self.load_stream(stream) |
| 99 | + """ |
| 100 | + Return an iterator of deserialized batches (lists) of objects from the input stream. |
| 101 | + if the serializer does not operate on batches the default implementation returns an |
| 102 | + iterator of single element lists. |
| 103 | + """ |
| 104 | + return map(lambda x: [x], self.load_stream(stream)) |
100 | 105 |
|
101 | 106 | # Note: our notion of "equality" is that output generated by |
102 | 107 | # equal serializers can be deserialized using the same serializer. |
@@ -278,50 +283,57 @@ def __repr__(self): |
278 | 283 | return "AutoBatchedSerializer(%s)" % self.serializer |
279 | 284 |
|
280 | 285 |
|
class CartesianDeserializer(Serializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian,
    we additionally need to do the cartesian within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        """Yield one batch per pair of underlying batches: their cartesian product."""
        keys = self.key_ser._load_stream_without_unbatching(stream)
        vals = self.val_ser._load_stream_without_unbatching(stream)
        for key_batch, val_batch in zip(keys, vals):
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield product(key_batch, val_batch)

    def load_stream(self, stream):
        """Flatten the per-batch cartesian products into a stream of pairs."""
        for batch in self._load_stream_without_unbatching(stream):
            for pair in batch:
                yield pair

    def __repr__(self):
        return "CartesianDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
310 | 311 |
|
311 | 312 |
|
class PairDeserializer(Serializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
    Due to pyspark batching we cannot simply use the result of the Java RDD zip,
    we additionally need to do the zip within each pair of batches.
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def _load_stream_without_unbatching(self, stream):
        """
        Yield one batch per pair of underlying batches: the elementwise zip of the
        two batches. Raises ValueError if a pair of batches differs in length.
        """
        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
            # For double-zipped RDDs the underlying batches can themselves be
            # iterators (product/zip objects from a nested Cartesian/Pair
            # deserializer) which have no __len__ — materialize those first so
            # the length check below cannot raise TypeError.
            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
            if len(key_batch) != len(val_batch):
                raise ValueError("Can not deserialize PairRDD with different number of items"
                                 " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
            # for correctness with repeated cartesian/zip this must be returned as one batch
            yield zip(key_batch, val_batch)

    def load_stream(self, stream):
        """Flatten the per-batch zips into a stream of (key, value) pairs."""
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def __repr__(self):
        return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
|
0 commit comments