
Commit 9ee94ee

[SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size
1 parent e3fd6a6 commit 9ee94ee

File tree

7 files changed: +379 -4 lines


core/pom.xml

Lines changed: 0 additions & 1 deletion

(Dropping <scope>test</scope> promotes commons-math3 to a compile dependency, since the sampler below now uses it in main code.)

@@ -70,7 +70,6 @@
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-math3</artifactId>
-      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 185 additions & 1 deletion
@@ -27,8 +27,12 @@ import scala.collection.Map
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
+import scala.util.control.Breaks._

 import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
+import org.apache.commons.math3.random.RandomDataGenerator
+
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType

@@ -46,7 +50,8 @@ import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.SparkContext._
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.SerializableHyperLogLog
+import org.apache.spark.util.{Utils, SerializableHyperLogLog}
+import org.apache.spark.util.random.{PoissonBounds => PB}

 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.

@@ -155,6 +160,182 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     foldByKey(zeroValue, defaultPartitioner(self))(func)
   }

+  /**
+   * Return a subset of this RDD sampled by key (via stratified sampling).
+   * We guarantee a sample size = math.ceil(fraction * S_i), where S_i is the size of the ith
+   * stratum.
+   *
+   * @param withReplacement whether to sample with or without replacement
+   * @param fraction sampling rate
+   * @param seed seed for the random number generator
+   * @return RDD containing the sampled subset
+   */
+  def sampleByKey(withReplacement: Boolean,
+      fraction: Double,
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+    class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0L) extends Serializable {
+      var waitList: ArrayBuffer[Double] = new ArrayBuffer[Double]
+      var q1: Option[Double] = None
+      var q2: Option[Double] = None
+
+      def incrNumItems(by: Long = 1L) = numItems += by
+
+      def incrNumAccepted(by: Long = 1L) = numAccepted += by
+
+      def addToWaitList(elem: Double) = waitList += elem
+
+      def addToWaitList(elems: ArrayBuffer[Double]) = waitList ++= elems
+
+      override def toString() = {
+        "numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
+          " waitListSize:" + waitList.size
+      }
+    }
+
+    class Result(var resultMap: Map[K, Stratum], var cachedPartitionId: Option[Int] = None)
+        extends Serializable {
+      var rand: RandomDataGenerator = new RandomDataGenerator
+
+      def getEntry(key: K, numItems: Long = 0L): Stratum = {
+        if (resultMap.get(key).isEmpty) {
+          resultMap += (key -> new Stratum(numItems))
+        }
+        resultMap.get(key).get
+      }
+
+      def getRand(partitionId: Int): RandomDataGenerator = {
+        if (cachedPartitionId.isEmpty || cachedPartitionId.get != partitionId) {
+          cachedPartitionId = Some(partitionId)
+          rand.reSeed(seed + partitionId)
+        }
+        rand
+      }
+    }
+
+    // TODO implement the streaming version of sampling w/ replacement that doesn't require counts
+    // in order to save one pass over the RDD
+    val counts = if (withReplacement) Some(this.countByKey()) else None
+
+    val seqOp = (U: (TaskContext, Result), item: (K, V)) => {
+      val delta = 5e-5
+      val result = U._2
+      val tc = U._1
+      val rng = result.getRand(tc.partitionId)
+      val stratum = result.getEntry(item._1)
+      if (withReplacement) {
+        // compute q1 and q2 only if they haven't been computed already,
+        // since they don't change from iteration to iteration
+        // TODO change this to the streaming version
+        if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
+          val n = counts.get(item._1)
+          val s = math.ceil(n * fraction).toLong
+          val lmbd1 = PB.getLambda1(s)
+          val minCount = PB.getMinCount(lmbd1)
+          val lmbd2 = if (lmbd1 == 0) PB.getLambda2(s) else PB.getLambda2(s - minCount)
+          stratum.q1 = Some(lmbd1 / n)
+          stratum.q2 = Some(lmbd2 / n)
+        }
+        val x1 = if (stratum.q1.get == 0) 0L else rng.nextPoisson(stratum.q1.get)
+        if (x1 > 0) {
+          stratum.incrNumAccepted(x1)
+        }
+        val x2 = rng.nextPoisson(stratum.q2.get).toInt
+        if (x2 > 0) {
+          stratum.addToWaitList(ArrayBuffer.fill(x2)(rng.nextUniform(0.0, 1.0)))
+        }
+      } else {
+        val g1 = - math.log(delta) / stratum.numItems
+        val g2 = (2.0 / 3.0) * g1
+        val q1 = math.max(0, fraction + g2 - math.sqrt(g2 * g2 + 3 * g2 * fraction))
+        val q2 = math.min(1, fraction + g1 + math.sqrt(g1 * g1 + 2 * g1 * fraction))
+
+        val x = rng.nextUniform(0.0, 1.0)
+        if (x < q1) {
+          stratum.incrNumAccepted()
+        } else if (x < q2) {
+          stratum.addToWaitList(x)
+        }
+        stratum.q1 = Some(q1)
+        stratum.q2 = Some(q2)
+      }
+      stratum.incrNumItems()
+      result
+    }
+
+    val combOp = (r1: Result, r2: Result) => {
+      // take the union of both key sets in case one partition doesn't contain all keys
+      val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
+
+      // use r2 to keep the combined result since r1 is usually empty
+      for (key <- keyUnion) {
+        val entry1 = r1.resultMap.get(key)
+        val entry2 = r2.resultMap.get(key)
+        if (entry2.isEmpty && entry1.isDefined) {
+          r2.resultMap += (key -> entry1.get)
+        } else if (entry1.isDefined && entry2.isDefined) {
+          entry2.get.addToWaitList(entry1.get.waitList)
+          entry2.get.incrNumAccepted(entry1.get.numAccepted)
+          entry2.get.incrNumItems(entry1.get.numItems)
+        }
+      }
+      r2
+    }
+
+    val zeroU = new Result(Map[K, Stratum]())
+
+    // determine the threshold for each stratum and resample
+    val finalResult = self.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
+    val thresholdByKey = new mutable.HashMap[K, Double]()
+    for ((key, stratum) <- finalResult) {
+      val s = math.ceil(stratum.numItems * fraction).toLong
+      breakable {
+        if (stratum.numAccepted > s) {
+          logWarning("Pre-accepted too many")
+          thresholdByKey += (key -> stratum.q1.get)
+          break
+        }
+        val numWaitListAccepted = (s - stratum.numAccepted).toInt
+        if (numWaitListAccepted >= stratum.waitList.size) {
+          logWarning("WaitList too short")
+          thresholdByKey += (key -> stratum.q2.get)
+        } else {
+          thresholdByKey += (key -> stratum.waitList.sorted.apply(numWaitListAccepted))
+        }
+      }
+    }
+
+    if (withReplacement) {
+      // Poisson sampler
+      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
+        val random = new RandomDataGenerator()
+        random.reSeed(seed + idx)
+        iter.flatMap { t =>
+          val q1 = finalResult.get(t._1).get.q1.get
+          val q2 = finalResult.get(t._1).get.q2.get
+          val x1 = if (q1 == 0) 0L else random.nextPoisson(q1)
+          val x2 = random.nextPoisson(q2).toInt
+          val x = x1 + (0 until x2).count(i => random.nextUniform(0.0, 1.0) <
+            thresholdByKey.get(t._1).get)
+          if (x > 0) {
+            Iterator.fill(x.toInt)(t)
+          } else {
+            Iterator.empty
+          }
+        }
+      }, preservesPartitioning = true)
+    } else {
+      // Bernoulli sampler
+      self.mapPartitionsWithIndex((idx: Int, iter: Iterator[(K, V)]) => {
+        val random = new RandomDataGenerator
+        random.reSeed(seed + idx)
+        iter.filter(t => random.nextUniform(0.0, 1.0) < thresholdByKey.get(t._1).get)
+      }, preservesPartitioning = true)
+    }
+  }
+
 /**
  * Merge the values for each key using an associative reduce function. This will also perform
  * the merging locally on each mapper before sending results to a reducer, similarly to a
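
Before the next hunk, a usage sketch of the new API. This is a hypothetical example, not part of the commit: the data, fraction, and seed are made up, and a live SparkContext named sc is assumed.

  import org.apache.spark.SparkContext._  // implicit conversion to PairRDDFunctions

  // Two strata: 1000 records keyed "a" and 100 keyed "b".
  val pairs = sc.parallelize(Seq.fill(1000)(("a", 1)) ++ Seq.fill(100)(("b", 2)))

  // Exact-size stratified sample without replacement:
  // math.ceil(0.1 * 1000) = 100 records for "a" and math.ceil(0.1 * 100) = 10 for "b".
  val sampled = pairs.sampleByKey(withReplacement = false, fraction = 0.1, seed = 42L)
  sampled.countByKey()  // Map(a -> 100, b -> 10)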
@@ -442,6 +623,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])

   /**
    * Return the key-value pairs in this RDD to the master as a Map.
+   *
+   * Warning: this doesn't return a multimap (so if you have multiple values for the same key,
+   * only one value per key is preserved in the map returned).
    */
   def collectAsMap(): Map[K, V] = {
     val data = self.collect()
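
The new collectAsMap() warning in action (a sketch; which duplicate survives depends on the collect order, so the exact value kept for key 1 may vary):

  val dup = sc.parallelize(Seq((1, "x"), (1, "y"), (2, "z")))
  dup.collectAsMap()          // e.g. Map(1 -> y, 2 -> z): one value per key survives
  dup.groupByKey().collect()  // keeps all values for each key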

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 18 additions & 0 deletions
@@ -880,6 +880,24 @@ abstract class RDD[T: ClassTag](
     jobResult
   }

+  /**
+   * A version of {@link #aggregate()} that passes the TaskContext to the function that does
+   * aggregation for each partition.
+   */
+  def aggregateWithContext[U: ClassTag](zeroValue: U)(
+      seqOp: ((TaskContext, U), T) => U,
+      combOp: (U, U) => U): U = {
+    // Clone the zero value since we will also be serializing it as part of tasks
+    var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+    // pad seqOp and combOp with the TaskContext to conform to aggregate's signature in
+    // TraversableOnce
+    val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
+    val paddedCombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
+      (arg1._1, combOp(arg1._2, arg2._2))
+    val cleanSeqOp = sc.clean(paddedSeqOp)
+    val cleanCombOp = sc.clean(paddedCombOp)
+    val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
+      it.aggregate((tc, zeroValue))(cleanSeqOp, cleanCombOp)._2
+    val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+    sc.runJob(this, aggregatePartition, mergeResult)
+    jobResult
+  }
+
   /**
    * Return the number of elements in the RDD.
    */
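
A sketch of what aggregateWithContext enables beyond plain aggregate (hypothetical example: per-partition record counts, keyed by TaskContext.partitionId; the RDD and slice count are made up):

  // Count records per partition; here U = Map[Int, Long] and T = Int.
  val nums = sc.parallelize(1 to 1000, 4)
  val countsByPartition = nums.aggregateWithContext(Map.empty[Int, Long])(
    // seqOp receives (TaskContext, accumulator) and folds in one record
    (ctxAndAcc, _) => {
      val (tc, acc) = ctxAndAcc
      acc.updated(tc.partitionId, acc.getOrElse(tc.partitionId, 0L) + 1L)
    },
    // combOp merges two accumulators key by key
    (a, b) => (a.keySet ++ b.keySet).map(k => k -> (a.getOrElse(k, 0L) + b.getOrElse(k, 0L))).toMap
  )
  // e.g. Map(0 -> 250, 1 -> 250, 2 -> 250, 3 -> 250)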
core/src/main/scala/org/apache/spark/util/random/PoissonBounds.scala

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.apache.commons.math3.distribution.{PoissonDistribution, NormalDistribution}
+
+private[random] object PoissonBounds {
+
+  val delta = 1e-4 / 3.0
+  val phi = new NormalDistribution().cumulativeProbability(1.0 - delta)
+
+  def getLambda1(s: Double): Double = {
+    var lb = math.max(0.0, s - math.sqrt(s / delta)) // Chebyshev's inequality
+    var ub = s
+    while (lb < ub - 1.0) {
+      val m = (lb + ub) / 2.0
+      val poisson = new PoissonDistribution(m, 1e-15)
+      val y = poisson.inverseCumulativeProbability(1 - delta)
+      if (y > s) ub = m else lb = m
+    }
+    lb
+  }
+
+  def getMinCount(lmbd: Double): Double = {
+    if (lmbd == 0) return 0
+    val poisson = new PoissonDistribution(lmbd, 1e-15)
+    poisson.inverseCumulativeProbability(delta)
+  }
+
+  def getLambda2(s: Double): Double = {
+    var lb = s
+    var ub = s + math.sqrt(s / delta) // Chebyshev's inequality
+    while (lb < ub - 1.0) {
+      val m = (lb + ub) / 2.0
+      val poisson = new PoissonDistribution(m, 1e-15)
+      val y = poisson.inverseCumulativeProbability(delta)
+      if (y >= s) ub = m else lb = m
+    }
+    ub
+  }
+}
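
In words: getLambda1(s) binary-searches for the largest Poisson rate whose (1 - delta)-quantile stays at or below s, so the outright-accepted count almost never overshoots the target; getMinCount gives a high-probability lower bound on that count; getLambda2 finds the smallest rate whose delta-quantile reaches its argument, so accepted plus waitlisted items almost never fall short. A hypothetical driver for intuition (PoissonBounds is private[random], so a sketch like this would have to live in org.apache.spark.util.random; s = 100 is made up):

  // Hypothetical walkthrough of the bounds sampleByKey uses for one stratum.
  val s = 100.0
  val lmbd1 = PoissonBounds.getLambda1(s)             // P(Poisson(lmbd1) > s) <= delta
  val minCount = PoissonBounds.getMinCount(lmbd1)     // w.h.p. at least this many pre-accepted
  val lmbd2 = PoissonBounds.getLambda2(s - minCount)  // waitlist covers the shortfall w.h.p.
  // sampleByKey then divides by the stratum size n: q1 = lmbd1 / n, q2 = lmbd2 / n.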
