Skip to content

Commit 9e74ab5

Browse files
committed
Separated out most of the logic in sampleByKey
into StratifiedSampler in util.random
1 parent 7327611 commit 9e74ab5

File tree

3 files changed

+309
-190
lines changed

3 files changed

+309
-190
lines changed

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 19 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ import org.apache.spark.SparkContext._
4949
import org.apache.spark.partial.{BoundedDouble, PartialResult}
5050
import org.apache.spark.serializer.Serializer
5151
import org.apache.spark.util.Utils
52-
import org.apache.spark.util.random.{PoissonBounds => PB}
52+
import org.apache.spark.util.random.{Stratum, Result, StratifiedSampler, PoissonBounds => PB}
5353

5454
/**
5555
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -210,177 +210,32 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
210210

211211
/**
212212
* Return a subset of this RDD sampled by key (via stratified sampling).
213-
* We guarantee a sample size = math.ceil(fraction * S_i), where S_i is the size of the ith
214-
* stratum.
213+
*
214+
* If exact is set to true, we guarantee, with high probability, a sample size =
215+
* math.ceil(fraction * S_i), where S_i is the size of the ith stratum (collection of entries
216+
* that share the same key). When sampling without replacement, we need one additional pass over
217+
* the RDD to guarantee sample size with a 99.99% confidence; when sampling with replacement, we
218+
* need two additional passes over the RDD to guarantee sample size with a 99.99% confidence.
215219
*
216220
* @param withReplacement whether to sample with or without replacement
217-
* @param fraction sampling rate
221+
* @param fractionByKey function mapping key to sampling rate
218222
* @param seed seed for the random number generator
223+
* @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
219224
* @return RDD containing the sampled subset
220225
*/
221226
def sampleByKey(withReplacement: Boolean,
222-
fraction: Double,
223-
seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
224-
225-
class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
226-
var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
227-
var q1: Option[Double] = None
228-
var q2: Option[Double] = None
229-
230-
def incrNumItems(by: Long = 1L) = numItems += by
231-
232-
def incrNumAccepted(by: Long = 1L) = numAccepted += by
233-
234-
def addToWaitList(elem: Double) = waitList += elem
235-
236-
def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
237-
238-
override def toString() = {
239-
"numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
240-
" waitListSize:" + waitList.size
241-
}
242-
}
243-
244-
class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
245-
extends Serializable {
246-
var rand: RandomDataGenerator = new RandomDataGenerator
247-
248-
def getEntry(key: K, numItems: Long = 0L): Stratum = {
249-
if (resultMap.get(key).isEmpty) {
250-
resultMap += (key -> new Stratum(numItems))
251-
}
252-
resultMap.get(key).get
253-
}
254-
255-
def getRand(partitionId: Int): RandomDataGenerator = {
256-
if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
257-
cachedPartitionId = Some(partitionId)
258-
rand.reSeed(seed + partitionId)
259-
}
260-
rand
261-
}
262-
}
263-
264-
// TODO implement the streaming version of sampling w/ replacement that doesn't require counts
265-
// in order to save one pass over the RDD
266-
val counts = if (withReplacement) Some(this.countByKey()) else None
267-
268-
val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
269-
val delta = 5e-5
270-
val result = U._2
271-
val tc = U._1
272-
val rng = result.getRand(tc.partitionId)
273-
val stratum = result.getEntry(item._1)
274-
if (withReplacement) {
275-
// compute q1 and q2 only if they haven't been computed already
276-
// since they don't change from iteration to iteration.
277-
// TODO change this to the streaming version
278-
if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
279-
val n = counts.get(item._1)
280-
val s = math.ceil(n * fraction).toLong
281-
val lmbd1 = PB.getLambda1(s)
282-
val minCount = PB.getMinCount(lmbd1)
283-
val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
284-
val q1 = lmbd1 / n
285-
val q2 = lmbd2 / n
286-
stratum.q1 = Some(q1)
287-
stratum.q2 = Some(q2)
288-
}
289-
val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
290-
if (x1 > 0) {
291-
stratum.incrNumAccepted(x1)
292-
}
293-
val x2 = rng.nextPoisson(stratum.q2.get).toInt
294-
if (x2 > 0) {
295-
stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
296-
}
297-
} else {
298-
val g1 = - math.log(delta) / stratum.numItems
299-
val g2 = (2.0 / 3.0) * g1
300-
val q1 = math.max(0, fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
301-
val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
302-
303-
val x = rng.nextUniform(0.0, 1.0)
304-
if (x < q1) {
305-
stratum.incrNumAccepted()
306-
} else if ( x < q2) {
307-
stratum.addToWaitList(x)
308-
}
309-
stratum.q1 = Some(q1)
310-
stratum.q2 = Some(q2)
311-
}
312-
stratum.incrNumItems()
313-
result
314-
}
315-
316-
val combOp = (r1: Result, r2: Result) => {
317-
// take union of both key sets in case one partition doesn't contain all keys
318-
val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
319-
320-
// Use r2 to keep the combined result since r1 is usually empty
321-
for (key <- keyUnion) {
322-
val entry1 = r1.resultMap.get(key)
323-
val entry2 = r2.resultMap.get(key)
324-
if (entry2.isEmpty && entry1.isDefined) {
325-
r2.resultMap += (key -> entry1.get)
326-
} else if (entry1.isDefined && entry2.isDefined) {
327-
entry2.get.addToWaitList(entry1.get.waitList)
328-
entry2.get.incrNumAccepted(entry1.get.numAccepted)
329-
entry2.get.incrNumItems(entry1.get.numItems)
330-
}
331-
}
332-
r2
333-
}
334-
335-
val zeroU = new Result(Map[K, Stratum]())
336-
337-
// determine threshold for each stratum and resample
338-
val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
339-
val thresholdByKey = new mutable.HashMap[K, Double]()
340-
for ((key, stratum) <- finalResult) {
341-
val s = math.ceil(stratum.numItems * fraction).toLong
342-
breakable {
343-
if (stratum.numAccepted > s) {
344-
logWarning("Pre-accepted too many")
345-
thresholdByKey += (key -> stratum.q1.get)
346-
break
347-
}
348-
val numWaitListAccepted = (s - stratum.numAccepted).toInt
349-
if (numWaitListAccepted >= stratum.waitList.size) {
350-
logWarning("WaitList too short")
351-
thresholdByKey += (key -> stratum.q2.get)
352-
} else {
353-
thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
354-
}
355-
}
356-
}
357-
227+
fractionByKey: K => Double,
228+
seed: Long = Utils.random.nextLong,
229+
exact: Boolean = true): RDD[(K, V)]= {
358230
if (withReplacement) {
359-
// Poisson sampler
360-
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
361-
val random = new RandomDataGenerator()
362-
random.reSeed(seed + idx)
363-
iter.flatMap { t =>
364-
val q1 = finalResult.get(t._1).get.q1.get
365-
val q2 = finalResult.get(t._1).get.q2.get
366-
val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
367-
val x2 = random.nextPoisson(q2).toInt
368-
val x = x1 + (0 until x2).filter(i => random.nextUniform(0.0, 1.0) <
369-
thresholdByKey.get(t._1).get).size
370-
if (x > 0) {
371-
Iterator.fill(x.toInt)(t)
372-
} else {
373-
Iterator.empty
374-
}
375-
}
376-
}, preservesPartitioning = true)
231+
val counts = if (exact) Some(this.countByKey()) else None
232+
val samplingFunc =
233+
StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
234+
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
377235
} else {
378-
// Bernoulli sampler
379-
self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
380-
val random = new RandomDataGenerator
381-
random.reSeed(seed + idx)
382-
iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
383-
}, preservesPartitioning = true)
236+
val samplingFunc =
237+
StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
238+
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
384239
}
385240
}
386241

0 commit comments

Comments
 (0)