diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index a89419bbd10e7..53da0fced0a10 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, List => JList, Iterator => JIterator}
 
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
@@ -72,11 +72,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
    * of the original partition.
    */
-  def mapPartitionsWithIndex[R: ClassTag](
-      f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]],
-      preservesPartitioning: Boolean = false): JavaRDD[R] =
-    new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
-      preservesPartitioning))
+  def mapPartitionsWithIndex[R](f: JFunction2[Integer, JIterator[T], JIterator[R]],
+      preservesPartitioning: Boolean = false): JavaRDD[R] = {
+    import scala.collection.JavaConverters._
+    // Adapt the Java Function2 to the (Int, Iterator[T]) => Iterator[R] that Scala expects.
+    def fn = (a: Int, b: Iterator[T]) => f.call(a, asJavaIterator(b)).asScala
+    val newRdd = rdd.mapPartitionsWithIndex(fn, preservesPartitioning)(fakeClassTag[R])
+    new JavaRDD(newRdd)(fakeClassTag)
+  }
 
   /**
    * Return a new RDD by applying a function to all elements of this RDD.
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 40e853c39ca99..9547c678b71c5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -414,6 +414,28 @@ public void javaDoubleRDDHistoGram() {
     Assert.assertArrayEquals(expected_counts, histogram);
   }
 
+  @Test
+  public void mapPartitionsWithIndex() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+    JavaRDD<Integer> rddByIndex = rdd.mapPartitionsWithIndex(
+      new Function2<Integer, java.util.Iterator<Integer>, java.util.Iterator<Integer>>() {
+        @Override
+        public java.util.Iterator<Integer> call(Integer start, java.util.Iterator<Integer> iter) {
+          // Multiply each element by a running position that starts at the partition index.
+          List<Integer> list = new ArrayList<Integer>();
+          int pos = start;
+          while (iter.hasNext()) {
+            list.add(iter.next() * pos);
+            pos += 1;
+          }
+          return list.iterator();
+        }
+      }, false);
+    Assert.assertEquals(0, rddByIndex.first().intValue());
+    Integer[] values = {0, 2, 6, 12, 20};
+    Assert.assertEquals(Arrays.asList(values), rddByIndex.collect());
+  }
+
   @Test
   public void map() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
index f67251217ed4a..6c925427c9a49 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
@@ -245,6 +245,23 @@ public void mapPartitions() {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void mapPartitionsWithIndex() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaRDD<Integer> rddByIndex = rdd.mapPartitionsWithIndex((start, iter) -> {
+      List<Integer> list = new ArrayList<>();
+      int pos = start;
+      while (iter.hasNext()) {
+        list.add(pos * iter.next());
+        pos += 1;
+      }
+      return list.iterator();
+    }, false);
+    Assert.assertEquals(0, rddByIndex.first().intValue());
+    Integer[] values = {0, 2, 3, 8};
+    Assert.assertEquals(Arrays.asList(values), rddByIndex.collect());
+  }
+
   @Test
   public void sequenceFile() {
     File tempDir = Files.createTempDir();
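
For reviewers: a minimal usage sketch of the new Java-friendly signature (illustration only, not part of the patch). It assumes an existing JavaSparkContext named sc and imports of JavaRDD, org.apache.spark.api.java.function.Function2, java.util.ArrayList, java.util.Arrays, and java.util.List.

    // Label each element with the index of the partition it came from.
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
    JavaRDD<String> labeled = rdd.mapPartitionsWithIndex(
      new Function2<Integer, java.util.Iterator<Integer>, java.util.Iterator<String>>() {
        @Override
        public java.util.Iterator<String> call(Integer index, java.util.Iterator<Integer> iter) {
          List<String> out = new ArrayList<String>();
          while (iter.hasNext()) {
            out.add("partition " + index + ": " + iter.next());
          }
          return out.iterator();
        }
      }, false);
    // labeled.collect() returns
    // [partition 0: 1, partition 0: 2, partition 1: 3, partition 1: 4]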