@@ -27,8 +27,12 @@ import scala.collection.Map
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
+import scala.util.control.Breaks._
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
+import org.apache.commons.math3.random.RandomDataGenerator
+
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
@@ -46,7 +50,8 @@ import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.SerializableHyperLogLog
+import org.apache.spark.util.{Utils, SerializableHyperLogLog}
+import org.apache.spark.util.random.{PoissonBounds => PB}
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -155,6 +160,182 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     foldByKey(zeroValue, defaultPartitioner(self))(func)
   }
 
+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   * We guarantee a sample size = math.ceil(fraction * S_i), where S_i is the size of the ith
+   * stratum.
+   *
+   * @param withReplacement whether to sample with or without replacement
+   * @param fraction sampling rate
+   * @param seed seed for the random number generator
+   * @return RDD containing the sampled subset
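+   *
+   * For example (a hypothetical sketch; the data and fraction are made up for illustration):
+   * {{{
+   *   val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4)))
+   *   // each key has 2 values, so math.ceil(0.5 * 2) = 1 value is sampled per key
+   *   val sampled = pairs.sampleByKey(withReplacement = false, fraction = 0.5)
+   * }}}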
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fraction: Double,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
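+    // Per-stratum (per-key) bookkeeping: how many items have been seen, how many were
+    // accepted outright, the wait list of uniform scores for borderline items, and the
+    // acceptance bounds q1/q2 used while sampling.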
+    class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
+      var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
+      var q1: Option[Double] = None
+      var q2: Option[Double] = None
+
+      def incrNumItems(by: Long = 1L) = numItems += by
+
+      def incrNumAccepted(by: Long = 1L) = numAccepted += by
+
+      def addToWaitList(elem: Double) = waitList += elem
+
+      def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
+
+      override def toString() = {
+        "numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
+          " waitListSize:" + waitList.size
+      }
+    }
+
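+    // Aggregation state: the per-key strata, plus a single RNG that is re-seeded with
+    // (seed + partitionId) whenever a new partition is encountered, keeping the sample
+    // deterministic for a given seed.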
+    class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
+        extends Serializable {
+      var rand: RandomDataGenerator = new RandomDataGenerator
+
+      def getEntry(key: K, numItems: Long = 0L): Stratum = {
+        if (resultMap.get(key).isEmpty) {
+          resultMap += (key -> new Stratum(numItems))
+        }
+        resultMap.get(key).get
+      }
+
+      def getRand(partitionId: Int): RandomDataGenerator = {
+        if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
+          cachedPartitionId = Some(partitionId)
+          rand.reSeed(seed + partitionId)
+        }
+        rand
+      }
+    }
+
+    // TODO implement the streaming version of sampling w/ replacement that doesn't require counts
+    // in order to save one pass over the RDD
+    val counts = if (withReplacement) Some(this.countByKey()) else None
+
+    val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
+      val delta = 5e-5
+      val result = U._2
+      val tc = U._1
+      val rng = result.getRand(tc.partitionId)
+      val stratum = result.getEntry(item._1)
+      if (withReplacement) {
+        // compute q1 and q2 only if they haven't been computed already
+        // since they don't change from iteration to iteration.
+        // TODO change this to the streaming version
+        if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
+          val n = counts.get(item._1)
+          val s = math.ceil(n * fraction).toLong
+          val lmbd1 = PB.getLambda1(s)
+          val minCount = PB.getMinCount(lmbd1)
+          val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
+          val q1 = lmbd1 / n
+          val q2 = lmbd2 / n
+          stratum.q1 = Some(q1)
+          stratum.q2 = Some(q2)
+        }
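+        // Each occurrence of the key draws Poisson(q1) copies that are accepted outright
+        // and Poisson(q2) copies that are wait-listed with uniform scores, from which an
+        // exact per-stratum cutoff is chosen once all counts are known.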
+        val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
+        if (x1 > 0) {
+          stratum.incrNumAccepted(x1)
+        }
+        val x2 = rng.nextPoisson(stratum.q2.get).toInt
+        if (x2 > 0) {
+          stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
+        }
+      } else {
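+        // q1 and q2 bracket the exact acceptance threshold: delta bounds the chance that
+        // either bound fails, so items scoring below q1 are accepted outright and items in
+        // [q1, q2) are wait-listed until the exact cutoff is computed below.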
+        val g1 = -math.log(delta) / stratum.numItems
+        val g2 = (2.0 / 3.0) * g1
+        val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
+        val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
+
+        val x = rng.nextUniform(0.0, 1.0)
+        if (x < q1) {
+          stratum.incrNumAccepted()
+        } else if (x < q2) {
+          stratum.addToWaitList(x)
+        }
+        stratum.q1 = Some(q1)
+        stratum.q2 = Some(q2)
+      }
+      stratum.incrNumItems()
+      result
+    }
+
+    val combOp = (r1: Result, r2: Result) => {
+      // take the union of both key sets in case one partition doesn't contain all keys
+      val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
+
+      // Use r2 to keep the combined result since r1 is usually empty
+      for (key <- keyUnion) {
+        val entry1 = r1.resultMap.get(key)
+        val entry2 = r2.resultMap.get(key)
+        if (entry2.isEmpty && entry1.isDefined) {
+          r2.resultMap += (key -> entry1.get)
+        } else if (entry1.isDefined && entry2.isDefined) {
+          entry2.get.addToWaitList(entry1.get.waitList)
+          entry2.get.incrNumAccepted(entry1.get.numAccepted)
+          entry2.get.incrNumItems(entry1.get.numItems)
+        }
+      }
+      r2
+    }
+
+    val zeroU = new Result(Map[K, Stratum]())
+
+    // determine the threshold for each stratum and resample
+    val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+    val thresholdByKey = new mutable.HashMap[K, Double]()
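+    // Items pre-accepted during aggregation count toward the target size s; the shortfall
+    // is filled with the smallest wait-list scores, so the (s - numAccepted)-th smallest
+    // score becomes the stratum's final acceptance threshold. The warnings below flag the
+    // low-probability cases where the q1/q2 bounds failed.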
+    for ((key, stratum) <- finalResult) {
+      val s = math.ceil(stratum.numItems * fraction).toLong
+      breakable {
+        if (stratum.numAccepted > s) {
+          logWarning("Pre-accepted too many")
+          thresholdByKey += (key -> stratum.q1.get)
+          break
+        }
+        val numWaitListAccepted = (s - stratum.numAccepted).toInt
+        if (numWaitListAccepted >= stratum.waitList.size) {
+          logWarning("WaitList too short")
+          thresholdByKey += (key -> stratum.q2.get)
+        } else {
+          thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
+        }
+      }
+    }
+
+    if (withReplacement) {
+      // Poisson sampler
+      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
+        val random = new RandomDataGenerator()
+        random.reSeed(seed + idx)
+        iter.flatMap { t =>
+          val q1 = finalResult.get(t._1).get.q1.get
+          val q2 = finalResult.get(t._1).get.q2.get
+          val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
+          val x2 = random.nextPoisson(q2).toInt
+          val x = x1 + (0 until x2).filter(i => random.nextUniform(0.0, 1.0) <
+            thresholdByKey.get(t._1).get).size
+          if (x > 0) {
+            Iterator.fill(x.toInt)(t)
+          } else {
+            Iterator.empty
+          }
+        }
+      }, preservesPartitioning = true)
+    } else {
+      // Bernoulli sampler
+      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
+        val random = new RandomDataGenerator
+        random.reSeed(seed + idx)
+        iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
+      }, preservesPartitioning = true)
+    }
+  }
+
   /**
    * Merge the values for each key using an associative reduce function. This will also perform
    * the merging locally on each mapper before sending results to a reducer, similarly to a
@@ -442,6 +623,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
+   *
+   * Warning: this doesn't return a multimap (so if you have multiple values for the same key,
+   * only one value per key is preserved in the map returned).
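+   *
+   * For example (hypothetical data, for illustration):
+   * {{{
+   *   val m = sc.parallelize(Seq(("a", 1), ("a", 2))).collectAsMap()
+   *   // m contains a single entry for "a"; one of the two values is dropped
+   * }}}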
    */
   def collectAsMap(): Map[K, V] = {
     val data = self.collect()