Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,25 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
*/
@DeveloperApi
class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could dropping this implicit break source and binary compatibility? I think we'd like to avoid asking people to make code changes to upgrade to a bug-fix release, even if the APIs are marked as developer. Can you leave the existing argument in place and just ignore it?

(implicit random: Random = new XORShiftRandom)
extends RandomSampler[T, T] {

def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
= this(0.0d, ratio)(random)
private[random] var rng: Random = new XORShiftRandom

override def setSeed(seed: Long) = random.setSeed(seed)
def this(ratio: Double) = this(0.0d, ratio)

override def setSeed(seed: Long) = rng.setSeed(seed)

override def sample(items: Iterator[T]): Iterator[T] = {
items.filter { item =>
val x = random.nextDouble()
val x = rng.nextDouble()
(x >= lb && x < ub) ^ complement
}
}

/**
* Return a sampler that selects the complement of the range specified by the current sampler.
*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

override def clone = new BernoulliSampler[T](lb, ub, complement)
}
Expand All @@ -81,21 +81,21 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
* :: DeveloperApi ::
* A sampler based on values drawn from Poisson distribution.
*
* @param poisson a Poisson random number generator
* @param mean Poisson mean
* @tparam T item type
*/
@DeveloperApi
class PoissonSampler[T](mean: Double)
(implicit var poisson: Poisson = new Poisson(mean, new DRand))
extends RandomSampler[T, T] {
class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {

private[random] var rng = new Poisson(mean, new DRand)

override def setSeed(seed: Long) {
poisson = new Poisson(mean, new DRand(seed.toInt))
rng = new Poisson(mean, new DRand(seed.toInt))
}

override def sample(items: Iterator[T]): Iterator[T] = {
items.flatMap { item =>
val count = poisson.nextInt()
val count = rng.nextInt()
if (count == 0) {
Iterator.empty
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.rdd
import org.scalatest.FunSuite

import org.apache.spark.SharedSparkContext
import org.apache.spark.util.random.RandomSampler
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}

/** a sampler that outputs its seed */
class MockSampler extends RandomSampler[Long, Long] {
Expand All @@ -32,19 +32,29 @@ class MockSampler extends RandomSampler[Long, Long] {
}

override def sample(items: Iterator[Long]): Iterator[Long] = {
return Iterator(s)
Iterator(s)
}

override def clone = new MockSampler
}

class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {

test("seedDistribution") {
test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
assert(sample.distinct.count == 2, "Seeds must be different.")
assert(sample.distinct().count == 2, "Seeds must be different.")
}

test("concurrency") {
  // SPARK-2251: zipping the sampled RDD with itself computes each partition twice.
  // The test passes when the job completes without running into concurrency issues;
  // the sampled values themselves are not inspected.
  val rdd = sc.parallelize(0 until 111, 10)
  val samplers = Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))
  samplers.foreach { sampler =>
    val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
    sampled.zip(sampled).count()
  }
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
val sampler = new BernoulliSampler[Int](0.25, 0.55)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
}
}
Expand All @@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
sampler.rng = random
assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
}
}
Expand All @@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.35)(random)
val sampler = new BernoulliSampler[Int](0.35)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
}
}
Expand All @@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
}
}
Expand All @@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
random.setSeed(10L)
}
whenExecuting(random) {
val sampler = new BernoulliSampler[Int](0.2)(random)
val sampler = new BernoulliSampler[Int](0.2)
sampler.rng = random
sampler.setSeed(10L)
}
}
Expand All @@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(poisson) {
val sampler = new PoissonSampler[Int](0.2)(poisson)
val sampler = new PoissonSampler[Int](0.2)
sampler.rng = poisson
assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
}
}
Expand Down