Commit 1fe1cff
1 parent 944a10c

Changed fractionByKey to a map to enable arg check

File tree: 3 files changed, +9 -13 lines changed
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala (3 additions, 9 deletions)

@@ -215,10 +215,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * math.ceil(fraction * S_i), where S_i is the size of the ith stratum (collection of entries
    * that share the same key). When sampling without replacement, we need one additional pass over
    * the RDD to guarantee sample size with a 99.99% confidence; when sampling with replacement, we
-   * need two additional passes over the RDD to guarantee sample size with a 99.99% confidence.
-   *
-   * Note that if the sampling rate for any stratum is < 1e-10, we will throw an exception to
-   * avoid not being able to ever create the sample as an artifact of the RNG's quality.
+   * need two additional passes.
    *
    * @param withReplacement whether to sample with or without replacement
    * @param fractionByKey function mapping key to sampling rate
@@ -227,14 +224,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * @return RDD containing the sampled subset
    */
   def sampleByKey(withReplacement: Boolean,
-      fractionByKey: K => Double,
+      fractionByKey: Map[K, Double],
       seed: Long = Utils.random.nextLong,
       exact: Boolean = true): RDD[(K, V)]= {
-
-    require(fractionByKey.asInstanceOf[Map[K, Double]].forall({case(k, v) => v >= 1e-10}),
-      "Unable to support sampling rates < 1e-10.")
-
     if (withReplacement) {
+      require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
       val counts = if (exact) Some(this.countByKey()) else None
       val samplingFunc =
         StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
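
Turning fractionByKey into a Map[K, Double] is what makes the eager forall(...) check above possible; an opaque K => Double could not be validated without enumerating keys. A minimal usage sketch of the changed signature follows; the local-mode setup, object name, and example data are illustrative assumptions, not part of this commit:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // implicit conversion to PairRDDFunctions

object SampleByKeySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sampleByKey-sketch"))

    // Hypothetical pair RDD keyed by label.
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("c", 5)))

    // Per-key rates as a Map[K, Double]; a negative rate would now be rejected
    // by the require(...) argument check added in this commit.
    val fractions = Map("a" -> 0.5, "b" -> 1.0, "c" -> 0.1)

    // Poisson (with-replacement) stratified sampling via the changed signature.
    val sampled = pairs.sampleByKey(withReplacement = true, fractionByKey = fractions)
    sampled.collect().foreach(println)

    sc.stop()
  }
}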

core/src/main/scala/org/apache/spark/rdd/RDD.scala (1 addition, 2 deletions)

@@ -351,12 +351,11 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    *
-   * fraction < 1e-10 not supported.
    */
   def sample(withReplacement: Boolean,
       fraction: Double,
       seed: Long = Utils.random.nextLong): RDD[T] = {
-    require(fraction >= 1e-10, "Invalid fraction value: " + fraction)
+    require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
     } else {
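
With the check relaxed from 1e-10 to 0.0, a zero fraction is now a legal argument to RDD.sample; the 1e-10 floor is instead applied where a fraction is derived from a requested sample size (see SamplingUtils below). A small sketch of the relaxed check, assuming a local SparkContext that is not part of this diff:

import org.apache.spark.{SparkConf, SparkContext}

object SampleFractionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sample-sketch"))

    val nums = sc.parallelize(1 to 1000)

    // Before this commit, require(fraction >= 1e-10) rejected this call;
    // now a zero fraction passes the check and should yield an empty sample.
    val none = nums.sample(withReplacement = false, fraction = 0.0)
    println(none.count())

    // A negative fraction still fails the relaxed argument check:
    // nums.sample(withReplacement = false, fraction = -0.1)   // IllegalArgumentException

    sc.stop()
  }
}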

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala (5 additions, 2 deletions)

@@ -37,6 +37,9 @@ private[spark] object SamplingUtils {
    * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
    * rate, where success rate is defined the same as in sampling with replacement.
    *
+   * The smallest sampling rate supported is 1e-10 (in order to avoid running into the limit of the
+   * RNG's resolution).
+   *
    * @param sampleSizeLowerBound sample size
    * @param total size of RDD
    * @param withReplacement whether sampling with replacement
@@ -47,11 +50,11 @@ private[spark] object SamplingUtils {
     val fraction = sampleSizeLowerBound.toDouble / total
     if (withReplacement) {
       val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
-      fraction + numStDev * math.sqrt(fraction / total)
+      math.max(1e-10, fraction + numStDev * math.sqrt(fraction / total))
     } else {
       val delta = 1e-4
       val gamma = - math.log(delta) / total
-      math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+      math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
     }
   }
 }
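
The 1e-10 floor removed from the callers above now lives here, clamping the computed fraction in both branches so it stays above the RNG's resolution, as the added doc comment notes. The following standalone sketch mirrors the patched arithmetic so the bounds can be inspected outside Spark; the object and method names are illustrative, not Spark's:

object FractionBoundSketch {
  def computeFraction(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = {
    val fraction = sampleSizeLowerBound.toDouble / total
    if (withReplacement) {
      // With replacement: pad the naive fraction by a multiple of its Poisson
      // standard deviation, then apply the 1e-10 floor.
      val numStDev = if (sampleSizeLowerBound < 12) 9 else 5
      math.max(1e-10, fraction + numStDev * math.sqrt(fraction / total))
    } else {
      // Without replacement: inflate the fraction so a Binomial(total, fraction)
      // draw reaches the requested size with probability 1 - delta, floored at
      // 1e-10 and capped at 1.
      val delta = 1e-4
      val gamma = -math.log(delta) / total
      math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
    }
  }

  def main(args: Array[String]): Unit = {
    // e.g. asking for 100 of 1,000,000 items without replacement
    println(computeFraction(100, 1000000L, withReplacement = false))
  }
}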
