@@ -26,6 +26,9 @@ import scala.Some
2626import org .apache .spark .rdd .RDD
2727
2828private [spark] object StratifiedSampler extends Logging {
29+ /**
30+ * Returns the function used by aggregate to collect sampling statistics for each partition.
31+ */
2932 def getSeqOp [K , V ](withReplacement : Boolean ,
3033 fractionByKey : (K => Double ),
3134 counts : Option [Map [K , Long ]]): ((TaskContext , Result [K ]),(K , V )) => Result [K ] = {
@@ -43,9 +46,9 @@ private[spark] object StratifiedSampler extends Logging {
4346 if (stratum.q1.isEmpty || stratum.q2.isEmpty) {
4447 val n = counts.get(item._1)
4548 val s = math.ceil(n * fraction).toLong
46- val lmbd1 = PB .getLambda1 (s)
49+ val lmbd1 = PB .getLowerBound (s)
4750 val minCount = PB .getMinCount(lmbd1)
48- val lmbd2 = if (lmbd1 == 0 ) PB .getLambda2 (s) else PB .getLambda2 (s - minCount)
51+ val lmbd2 = if (lmbd1 == 0 ) PB .getUpperBound (s) else PB .getUpperBound (s - minCount)
4952 val q1 = lmbd1 / n
5053 val q2 = lmbd2 / n
5154 stratum.q1 = Some (q1)
@@ -60,6 +63,8 @@ private[spark] object StratifiedSampler extends Logging {
6063 stratum.addToWaitList(ArrayBuffer .fill(x2)(rng.nextUniform(0.0 , 1.0 )))
6164 }
6265 } else {
66+ // We use the streaming version of the algorithm for sampling without replacement.
67+ // Hence, q1 and q2 change on every iteration.
6368 val g1 = - math.log(delta) / stratum.numItems
6469 val g2 = (2.0 / 3.0 ) * g1
6570 val q1 = math.max(0 , fraction + g2 - math.sqrt((g2 * g2 + 3 * g2 * fraction)))
@@ -79,7 +84,11 @@ private[spark] object StratifiedSampler extends Logging {
7984 }
8085 }
8186
82- def getCombOp [K ](): (Result [K ], Result [K ]) => Result [K ] = {
87+ /**
88+ * Returns the function used by aggregate to combine results from different partitions, as
89+ * returned by seqOp.
90+ */
91+ def getCombOp [K ](): (Result [K ], Result [K ]) => Result [K ] = {
8392 (r1 : Result [K ], r2 : Result [K ]) => {
8493 // take union of both key sets in case one partition doesn't contain all keys
8594 val keyUnion = r1.resultMap.keys.toSet.union(r2.resultMap.keys.toSet)
@@ -100,6 +109,10 @@ private[spark] object StratifiedSampler extends Logging {
100109 }
101110 }
102111
112+ /**
113+ * Given the result returned by the aggregate function, determine the acceptance threshold
114+ * for each key so that the sample matches the exact requested size.
115+ */
103116 def computeThresholdByKey [K ](finalResult : Map [K , Stratum ], fractionByKey : (K => Double )):
104117 (K => Double ) = {
105118 val thresholdByKey = new mutable.HashMap [K , Double ]()
@@ -122,11 +135,15 @@ private[spark] object StratifiedSampler extends Logging {
122135 thresholdByKey
123136 }
124137
125- def computeThresholdByKey [K ](finalResult : Map [K , String ]): (K => String ) = {
126- finalResult
127- }
128-
129- def getBernoulliSamplingFunction [K , V ](rdd: RDD [(K , V )],
138+ /**
139+ * Return the per partition sampling function used for sampling without replacement.
140+ *
141+ * When exact sample size is required, we make an additional pass over the RDD to determine the
142+ * exact sampling rate that guarantees sample size with high confidence.
143+ *
144+ * The sampling function has a unique seed per partition.
145+ */
146+ def getBernoulliSamplingFunction [K , V ](rdd : RDD [(K , V )],
130147 fractionByKey : K => Double ,
131148 exact : Boolean ,
132149 seed : Long ): (Int , Iterator [(K , V )]) => Iterator [(K , V )] = {
@@ -146,6 +163,16 @@ private[spark] object StratifiedSampler extends Logging {
146163 }
147164 }
148165
166+ /**
167+ * Return the per partition sampling function used for sampling with replacement.
168+ *
169+ * When exact sample size is required, we make two additional passes over the RDD to determine
170+ * the exact sampling rate that guarantees sample size with high confidence. The first pass
171+ * counts the number of items in each stratum (group of items with the same key) in the RDD, and
172+ * the second pass uses the counts to determine exact sampling rates.
173+ *
174+ * The sampling function has a unique seed per partition.
175+ */
149176 def getPoissonSamplingFunction [K , V ](rdd: RDD [(K , V )],
150177 fractionByKey : K => Double ,
151178 exact : Boolean ,
@@ -191,6 +218,10 @@ private[spark] object StratifiedSampler extends Logging {
191218 }
192219}
193220
221+ /**
222+ * Object used by seqOp to keep track of the number of items accepted and items waitlisted per
223+ * stratum, as well as the bounds for accepting and waitlisting items.
224+ */
194225private [random] class Stratum (var numItems : Long = 0L , var numAccepted : Long = 0L )
195226 extends Serializable {
196227
@@ -205,13 +236,14 @@ private[random] class Stratum(var numItems: Long = 0L, var numAccepted: Long = 0
205236 def addToWaitList (elem : Double ) = waitList += elem
206237
207238 def addToWaitList (elems : ArrayBuffer [Double ]) = waitList ++= elems
208-
209- override def toString () = {
210- " numItems: " + numItems + " numAccepted: " + numAccepted + " q1: " + q1 + " q2: " + q2 +
211- " waitListSize:" + waitList.size
212- }
213239}
214240
241+ /**
242+ * Object used by seqOp and combOp to keep track of the sampling statistics for all strata.
243+ *
244+ * When used by seqOp for each partition, we also keep track of the partition ID in this object
245+ * to make sure a single random number generator with a unique seed is used for each partition.
246+ */
215247private [random] class Result [K ](var resultMap : Map [K , Stratum ],
216248 var cachedPartitionId : Option [Int ] = None ,
217249 val seed : Long )
0 commit comments