Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.log10
import scala.reflect.ClassTag
import scala.util.hashing.byteswap32

Expand All @@ -42,7 +43,9 @@ object Partitioner {
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
* If any of the RDDs already has a partitioner, choose that one.
* If any of the RDDs already has a partitioner, and the number of partitions of the
* partitioner is either greater than or is less than and within a single order of
* magnitude of the max number of upstream partitions, choose that one.
*
* Otherwise, we use a default HashPartitioner. For the number of partitions, if
* spark.default.parallelism is set, then we'll use the value from SparkContext
Expand All @@ -57,8 +60,15 @@ object Partitioner {
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val rdds = (Seq(rdd) ++ others)
val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
if (hasPartitioner.nonEmpty) {
hasPartitioner.maxBy(_.partitions.length).partitioner.get

val hasMaxPartitioner: Option[RDD[_]] = if (hasPartitioner.nonEmpty) {
Some(hasPartitioner.maxBy(_.partitions.length))
} else {
None
}

if (isEligiblePartitioner(hasMaxPartitioner, rdds)) {
hasMaxPartitioner.get.partitioner.get
} else {
if (rdd.context.conf.contains("spark.default.parallelism")) {
new HashPartitioner(rdd.context.defaultParallelism)
Expand All @@ -67,6 +77,22 @@ object Partitioner {
}
}
}

/**
* Returns true if the number of partitions of the RDD is either greater
* than or is less than and within a single order of magnitude of the
* max number of upstream partitions;
* otherwise, returns false
*/
private def isEligiblePartitioner(
hasMaxPartitioner: Option[RDD[_]],
rdds: Seq[RDD[_]]): Boolean = {
if (hasMaxPartitioner.isEmpty) {
return false
}
val maxPartitions = rdds.map(_.partitions.length).max
log10(maxPartitions) - log10(hasMaxPartitioner.get.getNumPartitions) < 1
}
}

/**
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,33 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva
val partitioner = new RangePartitioner(22, rdd)
assert(partitioner.numPartitions === 3)
}

test("defaultPartitioner") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150)
val rdd2 = sc
.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
.partitionBy(new HashPartitioner(10))
val rdd3 = sc
.parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
.partitionBy(new HashPartitioner(100))
val rdd4 = sc
.parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
.partitionBy(new HashPartitioner(9))
val rdd5 = sc.parallelize((1 to 10).map(x => (x, x)), 11)

val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2)
val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3)
val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1)
val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3)
val partitioner5 = Partitioner.defaultPartitioner(rdd4, rdd5)

assert(partitioner1.numPartitions == rdd1.getNumPartitions)
assert(partitioner2.numPartitions == rdd3.getNumPartitions)
assert(partitioner3.numPartitions == rdd3.getNumPartitions)
assert(partitioner4.numPartitions == rdd3.getNumPartitions)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a testcase such that numPartitions 9 vs 11 is not treated as an order of magnitude jump (to prevent future changes which end up breaking this).

assert(partitioner5.numPartitions == rdd4.getNumPartitions)

}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(joined.size > 0)
}

// See SPARK-22465
test("cogroup between multiple RDD " +
"with an order of magnitude difference in number of partitions") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000)
val rdd2 = sc
.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
.partitionBy(new HashPartitioner(10))
val joined = rdd1.cogroup(rdd2)
assert(joined.getNumPartitions == rdd1.getNumPartitions)
}

// See SPARK-22465
test("cogroup between multiple RDD" +
" with number of partitions similar in order of magnitude") {
val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
val rdd2 = sc
.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
.partitionBy(new HashPartitioner(10))
val joined = rdd1.cogroup(rdd2)
assert(joined.getNumPartitions == rdd2.getNumPartitions)
}

test("rightOuterJoin") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
Expand Down