2 changes: 1 addition & 1 deletion python/pyspark/streaming/context.py
@@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc):
         jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
-                                 lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+                                 lambda t, *rdds: transformFunc(rdds),
                                  *[d._jrdd_deserializer for d in dstreams])
         jfunc = self._jvm.TransformFunction(func)
         jdstream = self._jssc.transform(jdstreams, jfunc)
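For reference, a short usage sketch of the `StreamingContext.transform` API this hunk changes — not part of the patch, and it assumes an already-started StreamingContext named `ssc`. The user-supplied function receives the batch's RDDs as a single tuple, which is why the lambda above calls `transformFunc(rdds)`:

# Hedged sketch: transform over two DStreams at once; `ssc` is assumed to exist.
s1 = ssc.queueStream([[1, 2], [3]])
s2 = ssc.queueStream([[4, 5], [6]])
unioned = ssc.transform([s1, s2], lambda rdds: rdds[0].union(rdds[1]))
unioned.pprint()

Before this change, the extra `.map(lambda x: x)` forced an identity pass on every such transform; after it, the wrapping happens only when actually needed, in util.py below.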
6 changes: 6 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -772,6 +772,12 @@ def func(rdds):

         self.assertEqual([2, 3, 1], self._take(dstream, 3))
 
+    def test_transform_pairrdd(self):
+        # This regression test case is for SPARK-17756.
+        dstream = self.ssc.queueStream(
+            [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
+        self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
+
     def test_get_active(self):
         self.assertEqual(StreamingContext.getActive(), None)

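The choice of `cartesian` in the regression test is deliberate: the Python RDD it returns is backed by a `JavaPairRDD` on the JVM side, which is exactly the case SPARK-17756 describes. A quick way to observe this outside streaming — a sketch assuming an active SparkContext named `sc`:

# Hedged sketch: `pairs._jrdd` is a JavaPairRDD, not a JavaRDD.
rdd = sc.parallelize([1, 2, 3])
pairs = rdd.cartesian(rdd)
print(rdd._jrdd.getClass().getSimpleName())    # JavaRDD
print(pairs._jrdd.getClass().getSimpleName())  # JavaPairRDD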
11 changes: 10 additions & 1 deletion python/pyspark/streaming/util.py
@@ -19,6 +19,8 @@
 from datetime import datetime
 import traceback
 
+from py4j.java_gateway import is_instance_of
+
 from pyspark import SparkContext, RDD
 

@@ -64,7 +66,14 @@ def call(self, milliseconds, jrdds):
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
             if r:
-                return r._jrdd
+                # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`.
+                # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return
+                # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`.
+                # See SPARK-17756.
+                if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
+                    return r._jrdd
+                else:
+                    return r.map(lambda x: x)._jrdd
         except:
             self.failure = traceback.format_exc()

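A standalone sketch of the conditional wrapping added in this hunk, run against a plain SparkContext instead of inside `TransformFunction.call`. The gateway handle and the class name are the ones used in the patch; the local context and sample data are illustrative only:

from py4j.java_gateway import is_instance_of
from pyspark import SparkContext

sc = SparkContext("local[2]", "spark-17756-sketch")
r = sc.parallelize([1, 2, 3]).cartesian(sc.parallelize([4, 5]))

if is_instance_of(sc._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
    jrdd = r._jrdd                    # already a JavaRDD; safe to hand to the JVM as-is
else:
    jrdd = r.map(lambda x: x)._jrdd   # identity map re-wraps the JavaPairRDD as a JavaRDD

print(jrdd.getClass().getName())      # org.apache.spark.api.java.JavaRDD
sc.stop()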