diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
index 42802f7113a19..b70ea0073c9a0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -54,5 +54,27 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
     )
   }
 
+  /**
+   * :: Experimental ::
+   * Returns a new RDD by applying a function to each partition of the wrapped RDD, while tracking
+   * the index of the original partition. All tasks are launched together in a barrier stage.
+   * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitionsWithIndex]].
+   * Please see the API doc there.
+   * @see [[org.apache.spark.BarrierTaskContext]]
+   */
+  @Experimental
+  @Since("3.0.0")
+  def mapPartitionsWithIndex[S: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[S],
+      preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+    val cleanedF = rdd.sparkContext.clean(f)
+    new MapPartitionsRDD(
+      rdd,
+      (_: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+      preservesPartitioning,
+      isFromBarrier = true
+    )
+  }
+
   // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout.
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
index 2f6c4d6a42ea3..f048f95430138 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala
@@ -29,6 +29,15 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
     assert(rdd2.isBarrier())
   }
 
+  test("RDDBarrier mapPartitionsWithIndex") {
+    val rdd = sc.parallelize(1 to 12, 4)
+    assert(rdd.isBarrier() === false)
+
+    val rdd2 = rdd.barrier().mapPartitionsWithIndex((index, iter) => Iterator(index))
+    assert(rdd2.isBarrier())
+    assert(rdd2.collect().toList === List(0, 1, 2, 3))
+  }
+
   test("create an RDDBarrier in the middle of a chain of RDDs") {
     val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
     val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index c7ea065b28ed8..1443584ccbcb8 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -329,6 +329,7 @@ def __hash__(self):
         "pyspark.tests.test_join",
         "pyspark.tests.test_profiler",
         "pyspark.tests.test_rdd",
+        "pyspark.tests.test_rddbarrier",
         "pyspark.tests.test_readwrite",
         "pyspark.tests.test_serializers",
         "pyspark.tests.test_shuffle",
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1edffaa4ca168..52ab86c0d88ee 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2535,6 +2535,20 @@
         def func(s, iterator):
             return f(iterator)
         return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
 
+    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+        """
+        .. note:: Experimental
+
+        Returns a new RDD by applying a function to each partition of the wrapped RDD, while
+        tracking the index of the original partition. All tasks are launched together
+        in a barrier stage.
+        The interface is the same as :func:`RDD.mapPartitionsWithIndex`.
+        Please see the API doc there.
+
+        .. versionadded:: 3.0.0
+        """
+        return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True)
+
 
 class PipelinedRDD(RDD):
diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py
new file mode 100644
index 0000000000000..8534fb4abb876
--- /dev/null
+++ b/python/pyspark/tests/test_rddbarrier.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class RDDBarrierTests(ReusedPySparkTestCase):
+    def test_map_partitions(self):
+        """Test RDDBarrier.mapPartitions"""
+        rdd = self.sc.parallelize(range(12), 4)
+        self.assertFalse(rdd._is_barrier())
+
+        rdd1 = rdd.barrier().mapPartitions(lambda it: it)
+        self.assertTrue(rdd1._is_barrier())
+
+    def test_map_partitions_with_index(self):
+        """Test RDDBarrier.mapPartitionsWithIndex"""
+        rdd = self.sc.parallelize(range(12), 4)
+        self.assertFalse(rdd._is_barrier())
+
+        def f(index, iterator):
+            yield index
+        rdd1 = rdd.barrier().mapPartitionsWithIndex(f)
+        self.assertTrue(rdd1._is_barrier())
+        self.assertEqual(rdd1.collect(), [0, 1, 2, 3])
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_rddbarrier import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
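For reference, a minimal usage sketch of the new Scala API (not part of the patch): the `local[4]` master, the app name, and the 12-element input are illustrative assumptions, chosen so the barrier stage has one task slot per partition.

```scala
import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

object RDDBarrierExample {
  def main(args: Array[String]): Unit = {
    // Assumption: local[4] gives one slot per partition, which a barrier
    // stage needs so that all four tasks can be launched together.
    val sc = new SparkContext(
      new SparkConf().setAppName("rdd-barrier-example").setMaster("local[4]"))

    val result = sc.parallelize(1 to 12, 4)
      .barrier() // wrap the RDD in an RDDBarrier
      .mapPartitionsWithIndex { (index, iter) =>
        // Every task blocks here until all tasks in the stage arrive --
        // the kind of global sync point barrier stages exist to support.
        BarrierTaskContext.get().barrier()
        Iterator((index, iter.size))
      }
      .collect()

    result.foreach(println) // prints (0,3), (1,3), (2,3), (3,3)
    sc.stop()
  }
}
```

The PySpark side mirrors this: `sc.parallelize(range(12), 4).barrier().mapPartitionsWithIndex(f)` with `f(index, iterator)` yields the same partition indices, as exercised in `test_rddbarrier.py` above.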