
Commit d79dd97

Jeroen Schot authored and srowen committed
[SPARK-3580][CORE] Add Consistent Method To Get Number of RDD Partitions Across Different Languages
I have tried to address all the comments in pull request #2447. Note that the second commit (using the new method in all internal code of all components) is quite intrusive and could be omitted.

Author: Jeroen Schot <[email protected]>

Closes #9767 from schot/master.

(cherry picked from commit 128c290)
Signed-off-by: Sean Owen <[email protected]>
1 parent c47a737 commit d79dd97
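
For readers skimming the diff: this patch gives the Scala and Java RDD APIs a uniformly named partition-count accessor, per the commit title. A minimal Scala sketch of what changes for user code, assuming a running SparkContext named sc (not part of this commit):

val rdd = sc.parallelize(1 to 100, 4)
// Before Spark 1.6, Scala code typically read the partitions array directly.
val oldStyle = rdd.partitions.length
// From 1.6.0 on, the dedicated accessor added here returns the same value.
val newStyle = rdd.getNumPartitions
assert(oldStyle == newStyle && newStyle == 4)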

5 files changed: +30 lines, -1 line

core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark._
+import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
@@ -62,6 +63,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   /** Set of partitions in this RDD. */
   def partitions: JList[Partition] = rdd.partitions.toSeq.asJava
 
+  /** Return the number of partitions in this RDD. */
+  @Since("1.6.0")
+  def getNumPartitions: Int = rdd.getNumPartitions
+
   /** The partitioner of this RDD. */
   def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner)
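
The getNumPartitions added to JavaRDDLike above is a thin delegate to the underlying Scala RDD. A small sketch of that delegation, written in Scala against the Java wrapper type; sc is an assumed SparkContext and the JavaRDD.fromRDD wrapping is only for illustration:

import org.apache.spark.api.java.JavaRDD

// Wrap a Scala RDD in the Java API type.
val javaRdd: JavaRDD[Int] = JavaRDD.fromRDD(sc.parallelize(1 to 8, 3))
// JavaRDDLike.getNumPartitions forwards to the wrapped rdd.getNumPartitions.
assert(javaRdd.getNumPartitions == 3)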

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 7 additions & 1 deletion
@@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
 
 import org.apache.spark._
 import org.apache.spark.Partitioner._
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{Since, DeveloperApi}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
@@ -242,6 +242,12 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * Returns the number of partitions of this RDD.
+   */
+  @Since("1.6.0")
+  final def getNumPartitions: Int = partitions.length
+
   /**
    * Get the preferred locations of a partition, taking into account whether the
    * RDD is checkpointed.
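
As the one-line body above shows, the new accessor is simply partitions.length behind a stable, @Since-annotated name. An illustrative Scala sketch of how it behaves across transformations, again assuming a live SparkContext named sc:

val nums = sc.parallelize(1 to 1000, 8)
assert(nums.getNumPartitions == 8)                // same value as nums.partitions.length
assert(nums.map(_ * 2).getNumPartitions == 8)     // narrow transformations keep the partition count
assert(nums.repartition(4).getNumPartitions == 4) // repartitioning changes it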

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 13 additions & 0 deletions
@@ -973,6 +973,19 @@ public Iterator<Integer> call(Integer index, Iterator<Integer> iter) {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void getNumPartitions(){
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
+    JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2);
+    JavaPairRDD<String, Integer> rdd3 = sc.parallelizePairs(Arrays.asList(
+        new Tuple2<>("a", 1),
+        new Tuple2<>("aa", 2),
+        new Tuple2<>("aaa", 3)
+    ), 2);
+    Assert.assertEquals(3, rdd1.getNumPartitions());
+    Assert.assertEquals(2, rdd2.getNumPartitions());
+    Assert.assertEquals(2, rdd3.getNumPartitions());
+  }
 
   @Test
   public void repartition() {

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
 
   test("basic operations") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    assert(nums.getNumPartitions === 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
     assert(nums.toLocalIterator.toList === List(1, 2, 3, 4))
     val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions
@@ -155,6 +155,10 @@ object MimaExcludes {
           "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"),
         ProblemFilters.exclude[MissingClassProblem](
           "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$")
+      ) ++ Seq(
+        // SPARK-3580 Add getNumPartitions method to JavaRDD
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.api.java.JavaRDDLike.getNumPartitions")
       )
     case v if v.startsWith("1.5") =>
       Seq(
