@@ -49,7 +49,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
-import org.apache.spark.util.random.{PoissonBounds => PB}
+import org.apache.spark.util.random.{Stratum, Result, StratifiedSampler, PoissonBounds => PB}
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -210,177 +210,32 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
   /**
    * Return a subset of this RDD sampled by key (via stratified sampling).
-   * We guarantee a sample size = math.ceil(fraction * S_i), where S_i is the size of the ith
-   * stratum.
+   *
+   * If exact is set to true, we guarantee, with high probability, a sample size =
+   * math.ceil(fraction * S_i), where S_i is the size of the ith stratum (collection of entries
+   * that share the same key). When sampling without replacement, we need one additional pass over
+   * the RDD to guarantee sample size with 99.99% confidence; when sampling with replacement, we
+   * need two additional passes over the RDD to guarantee sample size with 99.99% confidence.
    *
    * @param withReplacement whether to sample with or without replacement
-   * @param fraction sampling rate
+   * @param fractionByKey function mapping key to sampling rate
    * @param seed seed for the random number generator
+   * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
    * @return RDD containing the sampled subset
    */
   def sampleByKey(withReplacement: Boolean,
-      fraction: Double,
-      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
-
-    class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
-      var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
-      var q1: Option[Double] = None
-      var q2: Option[Double] = None
-
-      def incrNumItems(by: Long = 1L) = numItems += by
-
-      def incrNumAccepted(by: Long = 1L) = numAccepted += by
-
-      def addToWaitList(elem: Double) = waitList += elem
-
-      def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
-
-      override def toString() = {
-        "numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
-          " waitListSize:" + waitList.size
-      }
-    }
-
-    class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
-      extends Serializable {
-      var rand: RandomDataGenerator = new RandomDataGenerator
-
-      def getEntry(key: K, numItems: Long = 0L): Stratum = {
-        if (resultMap.get(key).isEmpty) {
-          resultMap += (key -> new Stratum(numItems))
-        }
-        resultMap.get(key).get
-      }
-
-      def getRand(partitionId: Int): RandomDataGenerator = {
-        if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
-          cachedPartitionId = Some(partitionId)
-          rand.reSeed(seed + partitionId)
-        }
-        rand
-      }
-    }
-
-    // TODO implement the streaming version of sampling w/ replacement that doesn't require counts
-    // in order to save one pass over the RDD
-    val counts = if (withReplacement) Some(this.countByKey()) else None
-
-    val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
-      val delta = 5e-5
-      val result = U._2
-      val tc = U._1
-      val rng = result.getRand(tc.partitionId)
-      val stratum = result.getEntry(item._1)
-      if (withReplacement) {
-        // compute q1 and q2 only if they haven't been computed already
-        // since they don't change from iteration to iteration.
-        // TODO change this to the streaming version
-        if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
-          val n = counts.get(item._1)
-          val s = math.ceil(n * fraction).toLong
-          val lmbd1 = PB.getLambda1(s)
-          val minCount = PB.getMinCount(lmbd1)
-          val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
-          val q1 = lmbd1 / n
-          val q2 = lmbd2 / n
-          stratum.q1 = Some(q1)
-          stratum.q2 = Some(q2)
-        }
-        val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
-        if (x1 > 0) {
-          stratum.incrNumAccepted(x1)
-        }
-        val x2 = rng.nextPoisson(stratum.q2.get).toInt
-        if (x2 > 0) {
-          stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
-        }
-      } else {
-        val g1 = - math.log(delta) / stratum.numItems
-        val g2 = (2.0 / 3.0) * g1
-        val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
-        val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
-
-        val x = rng.nextUniform(0.0, 1.0)
-        if (x < q1) {
-          stratum.incrNumAccepted()
-        } else if (x < q2) {
-          stratum.addToWaitList(x)
-        }
-        stratum.q1 = Some(q1)
-        stratum.q2 = Some(q2)
-      }
-      stratum.incrNumItems()
-      result
-    }
-
-    val combOp = (r1: Result, r2: Result) => {
-      // take union of both key sets in case one partition doesn't contain all keys
-      val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
-
-      // Use r2 to keep the combined result since r1 is usually empty
-      for (key <- keyUnion) {
-        val entry1 = r1.resultMap.get(key)
-        val entry2 = r2.resultMap.get(key)
-        if (entry2.isEmpty && entry1.isDefined) {
-          r2.resultMap += (key -> entry1.get)
-        } else if (entry1.isDefined && entry2.isDefined) {
-          entry2.get.addToWaitList(entry1.get.waitList)
-          entry2.get.incrNumAccepted(entry1.get.numAccepted)
-          entry2.get.incrNumItems(entry1.get.numItems)
-        }
-      }
-      r2
-    }
-
-    val zeroU = new Result(Map[K, Stratum]())
-
-    // determine threshold for each stratum and resample
-    val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
-    val thresholdByKey = new mutable.HashMap[K, Double]()
-    for ((key, stratum) <- finalResult) {
-      val s = math.ceil(stratum.numItems * fraction).toLong
-      breakable {
-        if (stratum.numAccepted > s) {
-          logWarning("Pre-accepted too many")
-          thresholdByKey += (key -> stratum.q1.get)
-          break
-        }
-        val numWaitListAccepted = (s - stratum.numAccepted).toInt
-        if (numWaitListAccepted >= stratum.waitList.size) {
-          logWarning("WaitList too short")
-          thresholdByKey += (key -> stratum.q2.get)
-        } else {
-          thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
-        }
-      }
-    }
-
+      fractionByKey: K => Double,
+      seed: Long = Utils.random.nextLong,
+      exact: Boolean = true): RDD[(K, V)] = {
     if (withReplacement) {
-      // Poisson sampler
-      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
-        val random = new RandomDataGenerator()
-        random.reSeed(seed + idx)
-        iter.flatMap { t =>
-          val q1 = finalResult.get(t._1).get.q1.get
-          val q2 = finalResult.get(t._1).get.q2.get
-          val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
-          val x2 = random.nextPoisson(q2).toInt
-          val x = x1 + (0 until x2).filter(i => random.nextUniform(0.0, 1.0) <
-            thresholdByKey.get(t._1).get).size
-          if (x > 0) {
-            Iterator.fill(x.toInt)(t)
-          } else {
-            Iterator.empty
-          }
-        }
-      }, preservesPartitioning = true)
+      val counts = if (exact) Some(this.countByKey()) else None
+      val samplingFunc =
+        StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
+      self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     } else {
-      // Bernoulli sampler
-      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
-        val random = new RandomDataGenerator
-        random.reSeed(seed + idx)
-        iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
-      }, preservesPartitioning = true)
+      val samplingFunc =
+        StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
+      self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
     }
   }
 
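For context, here is a minimal usage sketch of the new signature. It assumes a running `SparkContext` named `sc` and that the implicit conversion to `PairRDDFunctions` is in scope; the keys, values, and per-key fractions are illustrative and not part of this commit.

```scala
import org.apache.spark.SparkContext._  // brings in the implicit conversion to PairRDDFunctions

// Hypothetical input data: a pair RDD with string keys.
val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("b", 5)))

// fractionByKey is any K => Double; a Map with a default value works as one.
val fractions = Map("a" -> 0.5, "b" -> 0.4).withDefaultValue(0.0)

// Without replacement and exact = true, per-stratum sample sizes are guaranteed
// (at the cost of one extra pass over the RDD).
val exactSample =
  pairs.sampleByKey(withReplacement = false, fractionByKey = fractions, exact = true)

// With replacement and exact = false, countByKey() is skipped, trading exactness for speed.
val approxSample =
  pairs.sampleByKey(withReplacement = true, fractionByKey = fractions, exact = false)
```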
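The removed inline implementation (whose logic presumably now lives in `StratifiedSampler`) achieves the exact sample size with an accept/waitlist scheme: each item draws a uniform value, draws below a lower bound q1 are accepted outright, draws below an upper bound q2 go on a per-key waitlist, and after the aggregation pass the final threshold is the (s - numAccepted)-th smallest waitlisted draw. The standalone sketch below replays that selection step for a single stratum, reusing the bounds from the removed without-replacement branch; the object name, data, fraction, and delta are illustrative assumptions.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

// Standalone, single-stratum sketch of the accept/waitlist selection used by the
// removed without-replacement branch. Data, fraction, and delta are illustrative.
object WaitListSketch {
  def main(args: Array[String]): Unit = {
    val rng = new Random(42L)
    val stratum = Seq.fill(10000)(rng.nextInt())  // all items share one key
    val fraction = 0.1
    val delta = 5e-5
    val n = stratum.size.toDouble

    // Bounds from the removed code: accept below q1, waitlist anything below q2.
    val g1 = -math.log(delta) / n
    val g2 = (2.0 / 3.0) * g1
    val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
    val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))

    // First pass: accept, waitlist, or drop each item, remembering its uniform draw
    // (the real code instead re-seeds the same RNG per partition on the second pass).
    var numAccepted = 0L
    val waitList = new ArrayBuffer[Double]()
    val drawn = stratum.map { v =>
      val x = rng.nextDouble()
      if (x < q1) {
        numAccepted += 1
      } else if (x < q2) {
        waitList += x
      }
      (x, v)
    }

    // Pick the threshold so that accepted + waitlisted-below-threshold == s.
    val s = math.ceil(n * fraction).toLong
    val threshold =
      if (numAccepted > s) q1  // pre-accepted too many (rare); fall back to q1
      else if (s - numAccepted >= waitList.size) q2  // waitlist too short; fall back to q2
      else waitList.sorted.apply((s - numAccepted).toInt)

    // Second pass: keep items whose draw fell below the threshold.
    val sample = drawn.collect { case (x, v) if x < threshold => v }
    println(s"target size = $s, actual sample size = ${sample.size}")
  }
}
```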