@@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V](
156156 false
157157 }
158158}
159+
/**
 * A [[org.apache.spark.Partitioner]] that fills each partition up to a fixed
 * key count (`boundary`, default 1000) before moving on to the next one. Once
 * every partition has received `boundary` keys, subsequent keys are dealt out
 * one per partition round-robin, so eventually the smaller partitions are at
 * most one key smaller than the larger ones.
 *
 * NOTE(review): `getPartition` mutates per-instance state (`counts`,
 * `currPartition`), so its result depends on call order and the instance is
 * not thread-safe; each deserialized copy tracks its own counts — confirm this
 * matches how the partitioner is used across tasks.
 */
class BoundaryPartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    @transient rdd: RDD[_ <: Product2[K, V]],
    private val boundary: Int = 1000)
  extends Partitioner {

  // counts(i) tracks how many keys this instance has assigned to partition i.
  private val counts: Array[Int] = new Array[Int](numPartitions)

  def numPartitions = math.abs(partitions)

  /*
   * Ideally this would be derived from the partition count and the total key
   * count, but counting the RDD here would trigger an action. Callers may run
   * count() themselves and pass in an appropriate boundary.
   */
  def keysPerPartition = boundary

  // Index of the partition currently being filled.
  var currPartition = 0

  /*
   * Assign keys to the current partition until it holds `keysPerPartition`
   * keys, then advance (cyclically) to the next partition.
   *
   * Example: 2000 keys, 3 partitions, boundary 500 — keys 1-500 go to P0,
   * 501-1000 to P1, 1001-1500 to P2; after that keys are dealt one at a time,
   * so 1501 goes to P0, 1502 to P1, 1503 to P2, and so on.
   *
   * Because a partition is filled completely before moving on (which helps
   * preserve key order), partitions can end up empty: with 3 partitions,
   * boundary 500 and only 700 keys, one partition receives nothing.
   */
  def getPartition(key: Any): Int = {
    val partition = currPartition
    counts(partition) += 1
    if (counts(currPartition) >= keysPerPartition) {
      currPartition = (currPartition + 1) % numPartitions
    }
    partition
  }

  override def equals(other: Any): Boolean = other match {
    case r: BoundaryPartitioner[_, _] =>
      r.counts.sameElements(counts) && r.boundary == boundary &&
        r.currPartition == currPartition
    case _ =>
      false
  }

  // equals was overridden without hashCode, violating the equals/hashCode
  // contract. counts and currPartition are mutable, so hash only on the
  // immutable boundary — equal instances always share the same boundary,
  // keeping the contract (equal objects => equal hash codes).
  override def hashCode: Int = boundary.hashCode
}
0 commit comments