
Commit c23f5db

mengxr authored and pwendell committed
[SPARK-2251] fix concurrency issues in random sampler
The following code is very likely to throw an exception:

~~~
val rdd = sc.parallelize(0 until 111, 10).sample(false, 0.1)
rdd.zip(rdd).count()
~~~

because the same random number generator is used when computing partitions.

Author: Xiangrui Meng <[email protected]>

Closes #1229 from mengxr/fix-sample and squashes the following commits:

f1ee3d7 [Xiangrui Meng] fix concurrency issues in random sampler
1 parent d1636dd · commit c23f5db
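The root cause is easier to see outside Spark: two lazily evaluated samples that draw from one shared generator interfere with each other as soon as they are consumed in an interleaved fashion, which is exactly what happens when a partition of the sampled RDD is computed twice for the zip. A minimal plain-Scala sketch of that effect (none of these names are Spark classes):

~~~scala
import scala.util.Random

object SharedRngSketch extends App {
  // Consume two sampling iterators in lockstep, the way zip does, and report how many
  // elements each of them keeps.
  def interleavedSizes(rngA: Random, rngB: Random): (Int, Int) = {
    val a = (0 until 111).iterator.filter(_ => rngA.nextDouble() < 0.1)
    val b = (0 until 111).iterator.filter(_ => rngB.nextDouble() < 0.1)
    var nA = 0
    var nB = 0
    while (a.hasNext || b.hasNext) {
      if (a.hasNext) { a.next(); nA += 1 }
      if (b.hasNext) { b.next(); nB += 1 }
    }
    (nA, nB)
  }

  // Independent generators with the same seed: both samples are identical.
  println(interleavedSizes(new Random(42L), new Random(42L)))

  // One generator shared by both samples: the interleaved draws make the samples
  // diverge, so the sizes usually differ -- the mismatch that breaks rdd.zip(rdd).
  val shared = new Random(42L)
  println(interleavedSizes(shared, shared))
}
~~~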

File tree: 3 files changed, +38 / -22 lines

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

Lines changed: 12 additions & 12 deletions
@@ -54,25 +54,25 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
  */
 @DeveloperApi
 class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
-    (implicit random: Random = new XORShiftRandom)
   extends RandomSampler[T, T] {
 
-  def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
-    = this(0.0d, ratio)(random)
+  private[random] var rng: Random = new XORShiftRandom
 
-  override def setSeed(seed: Long) = random.setSeed(seed)
+  def this(ratio: Double) = this(0.0d, ratio)
+
+  override def setSeed(seed: Long) = rng.setSeed(seed)
 
   override def sample(items: Iterator[T]): Iterator[T] = {
     items.filter { item =>
-      val x = random.nextDouble()
+      val x = rng.nextDouble()
       (x >= lb && x < ub) ^ complement
     }
   }
 
   /**
    * Return a sampler that is the complement of the range specified of the current sampler.
    */
-  def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+  def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
 
   override def clone = new BernoulliSampler[T](lb, ub, complement)
 }

@@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
  * :: DeveloperApi ::
  * A sampler based on values drawn from Poisson distribution.
  *
- * @param poisson a Poisson random number generator
+ * @param mean Poisson mean
  * @tparam T item type
  */
 @DeveloperApi
-class PoissonSampler[T](mean: Double)
-    (implicit var poisson: Poisson = new Poisson(mean, new DRand))
-  extends RandomSampler[T, T] {
+class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
+
+  private[random] var rng = new Poisson(mean, new DRand)
 
   override def setSeed(seed: Long) {
-    poisson = new Poisson(mean, new DRand(seed.toInt))
+    rng = new Poisson(mean, new DRand(seed.toInt))
   }
 
   override def sample(items: Iterator[T]): Iterator[T] = {
     items.flatMap { item =>
-      val count = poisson.nextInt()
+      val count = rng.nextInt()
       if (count == 0) {
         Iterator.empty
       } else {
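The hunks above are the heart of the fix: the generator becomes a per-instance field (private[random] var rng) instead of an implicit constructor argument that clones could end up sharing, so clone always yields a sampler with its own generator. A standalone mirror of that pattern, with hypothetical names (TinyBernoulliSampler is not a Spark class):

~~~scala
import scala.util.Random

// Hypothetical stand-in for the pattern the patch adopts: the RNG is a per-instance
// field rather than a constructor-injected (and therefore shareable) dependency.
class TinyBernoulliSampler(lb: Double, ub: Double) extends Cloneable {
  private var rng: Random = new Random()          // every instance owns its generator

  def setSeed(seed: Long): Unit = rng.setSeed(seed)

  def sample(items: Iterator[Int]): Iterator[Int] =
    items.filter { _ => val x = rng.nextDouble(); x >= lb && x < ub }

  // clone() builds a brand-new instance with a brand-new generator, so two
  // computations of the same partition never touch the same RNG state.
  override def clone: TinyBernoulliSampler = new TinyBernoulliSampler(lb, ub)
}

object TinySamplerDemo extends App {
  val template = new TinyBernoulliSampler(0.0, 0.1)
  val s1 = template.clone; s1.setSeed(7L)
  val s2 = template.clone; s2.setSeed(7L)
  val a = s1.sample((0 until 111).iterator).toList
  val b = s2.sample((0 until 111).iterator).toList
  assert(a == b)  // same seed + independent generators: reproducible, equal samples
}
~~~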

core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala

Lines changed: 14 additions & 4 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.rdd
 import org.scalatest.FunSuite
 
 import org.apache.spark.SharedSparkContext
-import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}
 
 /** a sampler that outputs its seed */
 class MockSampler extends RandomSampler[Long, Long] {

@@ -32,19 +32,29 @@ class MockSampler extends RandomSampler[Long, Long] {
   }
 
   override def sample(items: Iterator[Long]): Iterator[Long] = {
-    return Iterator(s)
+    Iterator(s)
   }
 
   override def clone = new MockSampler
 }
 
 class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
 
-  test("seedDistribution") {
+  test("seed distribution") {
     val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
     val sampler = new MockSampler
     val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
-    assert(sample.distinct.count == 2, "Seeds must be different.")
+    assert(sample.distinct().count == 2, "Seeds must be different.")
+  }
+
+  test("concurrency") {
+    // SPARK-2251: zip with self computes each partition twice.
+    // We want to make sure there are no concurrency issues.
+    val rdd = sc.parallelize(0 until 111, 10)
+    for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
+      val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+      sampled.zip(sampled).count()
+    }
   }
 }
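The new "concurrency" test is the regression guard for the original report. To run the same check end to end outside the suite, a sketch against a throwaway local context looks roughly like this (the master setting and object name are assumptions, not part of the patch); after the fix the count completes instead of throwing:

~~~scala
import org.apache.spark.{SparkConf, SparkContext}

object ZipSelfCheck extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("zip-self-check"))
  val sampled = sc.parallelize(0 until 111, 10).sample(withReplacement = false, fraction = 0.1)
  // Zipping an uncached RDD with itself computes every partition twice; with
  // per-instance generators both computations produce the same sample, so the
  // per-partition element counts match and count() succeeds.
  println(sampled.zip(sampled).count())
  sc.stop()
}
~~~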

core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala

Lines changed: 12 additions & 6 deletions
@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
     }
   }

@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
     }
   }

@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.35)(random)
+      val sampler = new BernoulliSampler[Int](0.35)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
     }
   }

@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+      sampler.rng = random
       assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
     }
   }

@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       random.setSeed(10L)
     }
     whenExecuting(random) {
-      val sampler = new BernoulliSampler[Int](0.2)(random)
+      val sampler = new BernoulliSampler[Int](0.2)
+      sampler.rng = random
       sampler.setSeed(10L)
     }
   }

@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
     whenExecuting(poisson) {
-      val sampler = new PoissonSampler[Int](0.2)(poisson)
+      val sampler = new PoissonSampler[Int](0.2)
+      sampler.rng = poisson
       assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
     }
   }

0 commit comments
