
Commit 92cebad

Syed Hashmi authored and rxin committed
[SPARK-1784] Add a new partitioner to allow specifying # of keys per partition
This change adds a new partitioner which allows users to specify the number of keys per partition.

Author: Syed Hashmi <[email protected]>

Closes apache#721 from syedhashmi/master and squashes the following commits:

4ca94cc [Syed Hashmi] [SPARK-1784] Add a new partitioner
1 parent 4423386 commit 92cebad

2 files changed: 95 additions (+95), 0 deletions (−0)


core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 61 additions & 0 deletions
@@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     false
   }
 }
+
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions records into bounded groups
+ * of keys. The bound defaults to 1000 keys per partition. Once every partition
+ * holds `boundary` elements, the partitioner allocates one element per partition
+ * in turn, so the smaller partitions end up at most one key behind the larger ones.
+ */
+class BoundaryPartitioner[K : Ordering : ClassTag, V](
+    partitions: Int,
+    @transient rdd: RDD[_ <: Product2[K, V]],
+    private val boundary: Int = 1000)
+  extends Partitioner {
+
+  // This array tracks the keys assigned to each partition:
+  // counts(0) is the number of keys in partition 0, and so on.
+  private val counts: Array[Int] = new Array[Int](numPartitions)
+
+  def numPartitions = math.abs(partitions)
+
+  /*
+   * Ideally this would be derived from the number of partitions and the total key
+   * count, but we do not call count() on the RDD here, to avoid triggering an
+   * action. Users can call count() themselves and pass in any appropriate boundary.
+   */
+  def keysPerPartition = boundary
+
+  var currPartition = 0
+
+  /*
+   * Assign keys to the current partition until it hits the per-partition bound,
+   * then start allocating to the next partition.
+   *
+   * NOTE: given, say, 2000 keys, 3 partitions, and 500 passed in as the boundary,
+   * keys 1-500 go to partition 0, 501-1000 to partition 1, and 1001-1500 to
+   * partition 2. After that, keys go to one partition at a time: 1501 goes to
+   * partition 0, 1502 to partition 1, 1503 to partition 2, and so on.
+   */
+  def getPartition(key: Any): Int = {
+    val partition = currPartition
+    counts(partition) = counts(partition) + 1
+    /*
+     * Since we fill up one partition before moving to the next (this helps
+     * preserve key order), it is possible to end up with empty partitions in
+     * certain cases: with 3 partitions, 500 keys per partition, and an RDD of
+     * 700 keys, one partition stays entirely empty.
+     */
+    if (counts(currPartition) >= keysPerPartition)
+      currPartition = (currPartition + 1) % numPartitions
+    partition
+  }
+
+  override def equals(other: Any): Boolean = other match {
+    case r: BoundaryPartitioner[_, _] =>
+      (r.counts.sameElements(counts) && r.boundary == boundary
+        && r.currPartition == currPartition)
+    case _ =>
+      false
+  }
+}
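For orientation, here is a minimal driver-side sketch of how the new partitioner might be used. It is not part of the commit: the app name, the local[2] master, and the 4-partition / 500-key configuration are illustrative assumptions, while partitionBy and mapPartitionsWithIndex are standard Spark API.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // implicit pair-RDD conversions

// Hypothetical usage sketch, assuming BoundaryPartitioner from this commit is on
// the classpath; the numbers below mirror the unit test in this change.
object BoundaryPartitionerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("boundary-example").setMaster("local[2]"))
    val pairs = sc.parallelize(1 to 2000).map(x => (x, x))

    // Fill each partition with up to 500 keys before spilling to the next one.
    val partitioned = pairs.partitionBy(new BoundaryPartitioner(4, pairs, 500))

    // Each of the 4 partitions should end up with roughly 500 elements.
    partitioned.mapPartitionsWithIndex { (idx, iter) => Iterator((idx, iter.size)) }
      .collect()
      .foreach { case (idx, size) => println(s"partition $idx: $size elements") }

    sc.stop()
  }
}

One caveat worth noting: getPartition mutates per-instance state (counts, currPartition), so the deterministic placement described above assumes the keys pass through a single copy of the partitioner in order; the unit test below exercises getPartition directly on one instance for the same reason.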

core/src/test/scala/org/apache/spark/PartitioningSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -66,6 +66,40 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
     assert(descendingP4 != p4)
   }
 
+  test("BoundaryPartitioner equality") {
+    // BoundaryPartitioner does not sample the RDD, so equality depends only on
+    // the boundary and the partitioner's internal state.
+    val rdd = sc.parallelize(1.to(4000)).map(x => (x, x))
+
+    val p2 = new BoundaryPartitioner(2, rdd, 1000)
+    val p4 = new BoundaryPartitioner(4, rdd, 1000)
+    val anotherP4 = new BoundaryPartitioner(4, rdd)
+
+    assert(p2 === p2)
+    assert(p4 === p4)
+    assert(p2 != p4)
+    assert(p4 != p2)
+    assert(p4 === anotherP4)
+    assert(anotherP4 === p4)
+  }
+
+  test("BoundaryPartitioner getPartition") {
+    val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+    val partitioner = new BoundaryPartitioner(4, rdd, 500)
+    1.to(2000).foreach { element =>
+      val partition = partitioner.getPartition(element)
+      if (element <= 500) {
+        assert(partition === 0)
+      } else if (element <= 1000) {
+        assert(partition === 1)
+      } else if (element <= 1500) {
+        assert(partition === 2)
+      } else {
+        assert(partition === 3)
+      }
+    }
+  }
+
   test("RangePartitioner getPartition") {
     val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
     // We have different behaviour of getPartition for partitions with less than 1000 and more than
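To make the overflow behavior from the NOTE in getPartition concrete, here is a small driver-side sketch; the 3-partition, 500-key-boundary numbers come straight from that code comment, and sc is assumed to be an active SparkContext as in the suite above.

// Drive getPartition directly on one instance, as the unit test does.
val rdd = sc.parallelize(1 to 2000).map(x => (x, x))
val p = new BoundaryPartitioner(3, rdd, 500)

// assignments(k - 1) is the partition chosen for key k.
val assignments = (1 to 2000).map(k => p.getPartition(k))

assert(assignments(0) == 0)     // key 1:    first block fills partition 0
assert(assignments(500) == 1)   // key 501:  second block fills partition 1
assert(assignments(1000) == 2)  // key 1001: third block fills partition 2
assert(assignments(1500) == 0)  // key 1501: overflow round-robins back to 0
assert(assignments(1501) == 1)  // key 1502: then partition 1, and so on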

0 commit comments