diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4f29f2f0be1e8..b80149afa2af4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,10 +865,17 @@ def union(self, rdds):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
+        gw = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
-            jrdds[i] = rdds[i]._jrdd
+            if is_jrdd:
+                jrdds[i] = rdds[i]._jrdd
+            else:
+                # zip could return JavaPairRDD, hence we ensure `_jrdd`
+                # is a `JavaRDD` by wrapping it in a `map`
+                jrdds[i] = rdds[i].map(lambda x: x)._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 62ad4221d7078..04dfe68e57a3a 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -168,6 +168,15 @@ def test_zip_chaining(self):
             set([(x, (x, x)) for x in 'abc'])
         )
 
+    def test_union_pair_rdd(self):
+        # Regression test for SPARK-31788
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        self.assertEqual(
+            self.sc.union([pair_rdd, pair_rdd]).collect(),
+            [((1, 1), (2, 2)), ((1, 1), (2, 2))]
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)
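
Note: below is a minimal sketch (not part of the patch) of the py4j `is_instance_of` check the fix relies on, assuming a local SparkContext. `is_instance_of` asks the JVM whether a Java object is an instance of a given class, mirroring Java's `instanceof`. The premise of the bug is that `zip` produces an RDD backed by a `JavaPairRDD`, which is not a `JavaRDD` and therefore cannot be stored in the `JavaRDD[]` array that `union` builds:

    from py4j.java_gateway import is_instance_of
    from pyspark import SparkContext

    sc = SparkContext.getOrCreate()
    gw = SparkContext._gateway
    jrdd_cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD

    rdd = sc.parallelize([1, 2])
    pair_rdd = rdd.zip(rdd)

    # A plain RDD is backed by a JavaRDD on the JVM side.
    print(is_instance_of(gw, rdd._jrdd, jrdd_cls))       # True
    # A zipped RDD is backed by a JavaPairRDD, which is not a JavaRDD,
    # so assigning it into the JavaRDD[] array raised a Py4JError.
    print(is_instance_of(gw, pair_rdd._jrdd, jrdd_cls))  # False
    # An identity map rewraps it as a JavaRDD, which is what the fix does.
    print(is_instance_of(gw, pair_rdd.map(lambda x: x)._jrdd, jrdd_cls))  # True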