Skip to content

Commit bd9dc6e

Browse files
committed
unit bug and style violation fixed
1 parent 1fe1cff commit bd9dc6e

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,8 +227,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
227227
fractionByKey: Map[K, Double],
228228
seed: Long = Utils.random.nextLong,
229229
exact: Boolean = true): RDD[(K, V)]= {
230+
require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
230231
if (withReplacement) {
231-
require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
232232
val counts = if (exact) Some(this.countByKey()) else None
233233
val samplingFunc =
234234
StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,8 @@ private[spark] object SamplingUtils {
5454
} else {
5555
val delta = 1e-4
5656
val gamma = - math.log(delta) / total
57-
math.min(1, math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
57+
math.min(1,
58+
math.max(1e-10, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)))
5859
}
5960
}
6061
}

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
106106
n: Long) = {
107107
val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
108108
math.ceil(count * samplingRate).toInt)
109-
val fractionByKey = (_:String) => samplingRate
109+
val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
110110
val sample = stratifiedData.sampleByKey(false, fractionByKey, seed, exact)
111111
val sampleCounts = sample.countByKey()
112112
val takeSample = sample.collect()
@@ -124,7 +124,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
124124
n: Long) = {
125125
val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
126126
math.ceil(count * samplingRate).toInt)
127-
val fractionByKey = (_:String) => samplingRate
127+
val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
128128
val sample = stratifiedData.sampleByKey(true, fractionByKey, seed, exact)
129129
val sampleCounts = sample.countByKey()
130130
val takeSample = sample.collect()

0 commit comments

Comments (0)