Skip to content

Commit a113216

Browse files
uncleGensrowen
authored andcommitted
[SPARK-12031][CORE][BUG] Integer overflow when do sampling
Author: uncleGen <[email protected]> Closes #10023 from uncleGen/1.6-bugfix.
1 parent f6883bb commit a113216

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ private[spark] object RangePartitioner {
253253
*/
254254
def sketch[K : ClassTag](
255255
rdd: RDD[K],
256-
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
256+
sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
257257
val shift = rdd.id
258258
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
259259
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
@@ -262,7 +262,7 @@ private[spark] object RangePartitioner {
262262
iter, sampleSizePerPartition, seed)
263263
Iterator((idx, n, sample))
264264
}.collect()
265-
val numItems = sketched.map(_._2.toLong).sum
265+
val numItems = sketched.map(_._2).sum
266266
(numItems, sketched)
267267
}
268268

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
3434
input: Iterator[T],
3535
k: Int,
3636
seed: Long = Random.nextLong())
37-
: (Array[T], Int) = {
37+
: (Array[T], Long) = {
3838
val reservoir = new Array[T](k)
3939
// Put the first k elements in the reservoir.
4040
var i = 0
@@ -52,16 +52,17 @@ private[spark] object SamplingUtils {
5252
(trimReservoir, i)
5353
} else {
5454
// If input size > k, continue the sampling process.
55+
var l = i.toLong
5556
val rand = new XORShiftRandom(seed)
5657
while (input.hasNext) {
5758
val item = input.next()
58-
val replacementIndex = rand.nextInt(i)
59+
val replacementIndex = (rand.nextDouble() * l).toLong
5960
if (replacementIndex < k) {
60-
reservoir(replacementIndex) = item
61+
reservoir(replacementIndex.toInt) = item
6162
}
62-
i += 1
63+
l += 1
6364
}
64-
(reservoir, i)
65+
(reservoir, l)
6566
}
6667
}
6768

0 commit comments

Comments
 (0)