diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8a1c54c..dd924ef89868e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5b86c1cb2c390..71963b3c9d825 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -772,6 +772,12 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394f..f34eea7b5c18f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -19,6 +19,8 @@ from datetime import datetime import traceback +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -64,7 +66,14 @@ def call(self, milliseconds, jrdds): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. 
+ # org.apache.spark.streaming.api.python.PythonTransformFunction must return a + # `JavaRDD`; however, `_jrdd` can be a `JavaPairRDD` after some APIs, e.g. `zip`. + # See SPARK-17756. + if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc()