From 28dce3a72fe3d44664ebec360712a210fc2df192 Mon Sep 17 00:00:00 2001
From: schintap
Date: Thu, 21 May 2020 19:35:03 -0400
Subject: [PATCH 1/5] Fix UnionRDD of PairRDDs

---
 python/pyspark/context.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4f29f2f0be1e8..373b7b830e1e6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,8 +865,20 @@ def union(self, rdds):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
-        cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        gw = SparkContext._gateway
+        jvm = SparkContext._jvm
+        jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
+        pair_jrdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
+        double_jrdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
+        if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
+            cls = jrdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, pair_jrdd_cls):
+            cls = pair_jrdd_cls
+        elif is_instance_of(gw, rdds[0]._jrdd, double_jrdd_cls):
+            cls = double_jrdd_cls
+        else:
+            raise TypeError("Unsupported java rdd class %s" % rdds[0]._jrdd)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
             jrdds[i] = rdds[i]._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)

From 1d8d3082018288ed342743877728066789130aaa Mon Sep 17 00:00:00 2001
From: schintap
Date: Fri, 22 May 2020 12:04:49 -0400
Subject: [PATCH 2/5] Address review comments

---
 python/pyspark/context.py        | 21 ++++++++-------------
 python/pyspark/tests/test_rdd.py |  9 +++++++++
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 373b7b830e1e6..a4cb07c2a5d8e 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -866,21 +866,16 @@ def union(self, rdds):
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
         gw = SparkContext._gateway
-        jvm = SparkContext._jvm
-        jrdd_cls = jvm.org.apache.spark.api.java.JavaRDD
-        pair_jrdd_cls = jvm.org.apache.spark.api.java.JavaPairRDD
-        double_jrdd_cls = jvm.org.apache.spark.api.java.JavaDoubleRDD
-        if is_instance_of(gw, rdds[0]._jrdd, jrdd_cls):
-            cls = jrdd_cls
-        elif is_instance_of(gw, rdds[0]._jrdd, pair_jrdd_cls):
-            cls = pair_jrdd_cls
-        elif is_instance_of(gw, rdds[0]._jrdd, double_jrdd_cls):
-            cls = double_jrdd_cls
-        else:
-            raise TypeError("Unsupported java rdd class %s" % rdds[0]._jrdd)
+        cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
+        is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
         jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
-            jrdds[i] = rdds[i]._jrdd
+            if is_jrdd:
+                jrdds[i] = rdds[i]._jrdd
+            else:
+                # zip could return JavaPairRDD hence we ensure `_jrdd`
+                # to be `JavaRDD` by wrapping it in a `map`
+                rdds[i] = rdds[i].map(lambda x: x)._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 62ad4221d7078..61cb3d3194edd 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -168,6 +168,15 @@ def test_zip_chaining(self):
             set([(x, (x, x)) for x in 'abc'])
         )
 
+    def test_union_pair_rdd(self):
+         # Regression test for SPARK-31788
+        rdd1 = self.sc.parallelize([1, 2])
+        rdd2 = self.sc.parallelize([3, 4])
+        pair_rdd = rdd1.zip(rdd2)
+        expected = [(1, 3), (2, 4), (1, 3), (2, 4)]
+        actual = self.sc.union([pair_rdd, pair_rdd]).collect()
+        self.assertEqual(expected, actual)
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)

From 5da0ad64583665050a2af573fdb3ca4d62f931cc Mon Sep 17 00:00:00 2001
From: schintap
Date: Fri, 22 May 2020 12:54:39 -0400
Subject: [PATCH 3/5] Fix nit

---
 python/pyspark/context.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a4cb07c2a5d8e..b80149afa2af4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -875,7 +875,7 @@ def union(self, rdds):
             else:
                 # zip could return JavaPairRDD hence we ensure `_jrdd`
                 # to be `JavaRDD` by wrapping it in a `map`
-                rdds[i] = rdds[i].map(lambda x: x)._jrdd
+                jrdds[i] = rdds[i].map(lambda x: x)._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):

From 66fb35392a08fc05d3bc424ad3c65be8bc987151 Mon Sep 17 00:00:00 2001
From: schintap
Date: Fri, 22 May 2020 13:17:57 -0400
Subject: [PATCH 4/5] Fix indentation

---
 python/pyspark/tests/test_rdd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 61cb3d3194edd..c33e678613a76 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -169,7 +169,7 @@ def test_zip_chaining(self):
         )
 
     def test_union_pair_rdd(self):
-         # Regression test for SPARK-31788
+        # Regression test for SPARK-31788
         rdd1 = self.sc.parallelize([1, 2])
         rdd2 = self.sc.parallelize([3, 4])
         pair_rdd = rdd1.zip(rdd2)

From c65be0d6ac8e78f56ca8cd063d9db1012b65f0ab Mon Sep 17 00:00:00 2001
From: schintap
Date: Fri, 22 May 2020 16:07:19 -0400
Subject: [PATCH 5/5] Condense test

---
 python/pyspark/tests/test_rdd.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index c33e678613a76..04dfe68e57a3a 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -170,12 +170,12 @@ def test_zip_chaining(self):
 
     def test_union_pair_rdd(self):
         # Regression test for SPARK-31788
-        rdd1 = self.sc.parallelize([1, 2])
-        rdd2 = self.sc.parallelize([3, 4])
-        pair_rdd = rdd1.zip(rdd2)
-        expected = [(1, 3), (2, 4), (1, 3), (2, 4)]
-        actual = self.sc.union([pair_rdd, pair_rdd]).collect()
-        self.assertEqual(expected, actual)
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        self.assertEqual(
+            self.sc.union([pair_rdd, pair_rdd]).collect(),
+            [(1, 1), (2, 2), (1, 1), (2, 2)]
+        )
 
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
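
As a quick end-to-end illustration of the behavior this series fixes, a minimal sketch is below (the master URL and app name are placeholders, not part of the patch; everything else uses only public PySpark API):

from pyspark import SparkContext

sc = SparkContext("local[2]", "spark-31788-demo")

# zip() is backed by a JavaPairRDD rather than a JavaRDD, so before this
# series sc.union() failed when it tried to place the pair RDD's `_jrdd`
# into the JavaRDD[] array it allocates on the Java side.
rdd = sc.parallelize([1, 2])
pair_rdd = rdd.zip(rdd)  # [(1, 1), (2, 2)]

# With the series applied, non-JavaRDD inputs are rewrapped through an
# identity map(), so the union returns the concatenated pairs.
result = sc.union([pair_rdd, pair_rdd]).collect()
assert result == [(1, 1), (2, 2), (1, 1), (2, 2)]

sc.stop()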