diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6afe769662221..da3bcb7472843 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -684,10 +684,19 @@ def cartesian(self, other):
         >>> sorted(rdd.cartesian(rdd).collect())
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
+        def reserialize_if_cartesian(rdd):
+            if isinstance(rdd._jrdd_deserializer, CartesianDeserializer):
+                return rdd._reserialize(self.ctx.serializer)
+            else:
+                return rdd
+
+        this = reserialize_if_cartesian(self)
+        other = reserialize_if_cartesian(other)
+
         # Due to batching, we can't use the Java cartesian method.
-        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+        deserializer = CartesianDeserializer(this._jrdd_deserializer,
                                              other._jrdd_deserializer)
-        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
+        return RDD(this._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0a029b6e7441b..4b52cc241ba9f 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1036,6 +1036,19 @@ def test_pipe_functions(self):
         self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
         self.assertEqual([], rdd.pipe('grep 4').collect())
 
+    def test_cartesian_chaining(self):
+        # Tests for SPARK-16589
+        rdd = self.sc.parallelize(range(10), 2)
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
+            set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
+        self.assertSetEqual(
+            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
+            set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
+        )
+
 
 class ProfilerTests(PySparkTestCase):