
Commit 0a70951

ConeyLiu authored and jiangxb1987 committed
[SPARK-29499][CORE][PYSPARK] Add mapPartitionsWithIndex for RDDBarrier
### What changes were proposed in this pull request?

Add `mapPartitionsWithIndex` for `RDDBarrier`.

### Why are the changes needed?

There is only one method in `RDDBarrier`. We often use the partition index as a label for the current partition, and previously the index had to be fetched from the `TaskContext` inside `mapPartitions`, which is not convenient.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

New unit tests.

Closes #26148 from ConeyLiu/barrier-index.

Authored-by: Xianyang Liu <[email protected]>
Signed-off-by: Xingbo Jiang <[email protected]>
1 parent 70dd9c0 commit 0a70951
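
To illustrate the motivation, here is a minimal sketch (not part of the patch; it assumes an existing `SparkContext` named `sc`, and the data and variable names are made up) contrasting the old approach, where the partition index must be read from the task context inside `mapPartitions`, with the new `mapPartitionsWithIndex`:

    import org.apache.spark.TaskContext

    val rdd = sc.parallelize(1 to 12, 4)

    // Before: the index has to be pulled from the task context inside the closure
    // (alternatively via BarrierTaskContext.get().partitionId()).
    val before = rdd.barrier().mapPartitions { iter =>
      val index = TaskContext.getPartitionId()
      iter.map(x => (index, x))
    }

    // After: the index is passed to the function directly, as in RDD.mapPartitionsWithIndex.
    val after = rdd.barrier().mapPartitionsWithIndex { (index, iter) =>
      iter.map(x => (index, x))
    }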

File tree: 5 files changed, +96 -0 lines changed

core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala

Lines changed: 22 additions & 0 deletions
@@ -54,5 +54,27 @@ class RDDBarrier[T: ClassTag] private[spark] (rdd: RDD[T]) {
     )
   }
 
+  /**
+   * :: Experimental ::
+   * Returns a new RDD by applying a function to each partition of the wrapped RDD, while tracking
+   * the index of the original partition. And all tasks are launched together in a barrier stage.
+   * The interface is the same as [[org.apache.spark.rdd.RDD#mapPartitionsWithIndex]].
+   * Please see the API doc there.
+   * @see [[org.apache.spark.BarrierTaskContext]]
+   */
+  @Experimental
+  @Since("3.0.0")
+  def mapPartitionsWithIndex[S: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[S],
+      preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope {
+    val cleanedF = rdd.sparkContext.clean(f)
+    new MapPartitionsRDD(
+      rdd,
+      (_: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+      preservesPartitioning,
+      isFromBarrier = true
+    )
+  }
+
   // TODO: [SPARK-25247] add extra conf to RDDBarrier, e.g., timeout.
 }
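
The new method mirrors `RDD.mapPartitionsWithIndex`, differing only in that the resulting `MapPartitionsRDD` is flagged with `isFromBarrier = true`. A usage sketch (not from the patch; the per-partition body and names are illustrative, and a barrier stage needs enough free slots to launch all its tasks at once):

    import org.apache.spark.BarrierTaskContext

    val counts = sc
      .parallelize(1 to 100, 4)
      .barrier()
      .mapPartitionsWithIndex { (index, iter) =>
        val ctx = BarrierTaskContext.get()
        ctx.barrier()  // all four barrier tasks must reach this point before any proceeds
        Iterator.single((index, iter.size))
      }
      .collect()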

core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,15 @@ class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext {
     assert(rdd2.isBarrier())
   }
 
+  test("RDDBarrier mapPartitionsWithIndex") {
+    val rdd = sc.parallelize(1 to 12, 4)
+    assert(rdd.isBarrier() === false)
+
+    val rdd2 = rdd.barrier().mapPartitionsWithIndex((index, iter) => Iterator(index))
+    assert(rdd2.isBarrier())
+    assert(rdd2.collect().toList === List(0, 1, 2, 3))
+  }
+
   test("create an RDDBarrier in the middle of a chain of RDDs") {
     val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2)
     val rdd2 = rdd.barrier().mapPartitions(iter => iter).map(x => (x, x + 1))

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def __hash__(self):
         "pyspark.tests.test_join",
         "pyspark.tests.test_profiler",
         "pyspark.tests.test_rdd",
+        "pyspark.tests.test_rddbarrier",
         "pyspark.tests.test_readwrite",
         "pyspark.tests.test_serializers",
         "pyspark.tests.test_shuffle",

python/pyspark/rdd.py

Lines changed: 14 additions & 0 deletions
@@ -2535,6 +2535,20 @@ def func(s, iterator):
             return f(iterator)
         return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
 
+    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+        """
+        .. note:: Experimental
+
+        Returns a new RDD by applying a function to each partition of the wrapped RDD, while
+        tracking the index of the original partition. And all tasks are launched together
+        in a barrier stage.
+        The interface is the same as :func:`RDD.mapPartitionsWithIndex`.
+        Please see the API doc there.
+
+        .. versionadded:: 3.0.0
+        """
+        return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True)
+
 
 class PipelinedRDD(RDD):

python/pyspark/tests/test_rddbarrier.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class RDDBarrierTests(ReusedPySparkTestCase):
+    def test_map_partitions(self):
+        """Test RDDBarrier.mapPartitions"""
+        rdd = self.sc.parallelize(range(12), 4)
+        self.assertFalse(rdd._is_barrier())
+
+        rdd1 = rdd.barrier().mapPartitions(lambda it: it)
+        self.assertTrue(rdd1._is_barrier())
+
+    def test_map_partitions_with_index(self):
+        """Test RDDBarrier.mapPartitionsWithIndex"""
+        rdd = self.sc.parallelize(range(12), 4)
+        self.assertFalse(rdd._is_barrier())
+
+        def f(index, iterator):
+            yield index
+        rdd1 = rdd.barrier().mapPartitionsWithIndex(f)
+        self.assertTrue(rdd1._is_barrier())
+        self.assertEqual(rdd1.collect(), [0, 1, 2, 3])
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_rddbarrier import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
