Skip to content

Commit 555a3f9

Browse files
committed
separate out sampleByKeyExact as its own API
1 parent 616e55c commit 555a3f9

File tree

4 files changed

+216
-126
lines changed

4 files changed

+216
-126
lines changed

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -133,68 +133,64 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
133133
* Return a subset of this RDD sampled by key (via stratified sampling).
134134
*
135135
* Create a sample of this RDD using variable sampling rates for different keys as specified by
136-
* `fractions`, a key to sampling rate map.
137-
*
138-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
139-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
140-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
141-
* the RDD to create a sample size that's exactly equal to the sum of
136+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
137+
* RDD, to produce a sample of size that's approximately equal to the sum of
142138
* math.ceil(numItems * samplingRate) over all key values.
143139
*/
144140
def sampleByKey(withReplacement: Boolean,
145141
fractions: JMap[K, Double],
146-
exact: Boolean,
147142
seed: Long): JavaPairRDD[K, V] =
148-
new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
143+
new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed))
149144

150145
/**
151146
* Return a subset of this RDD sampled by key (via stratified sampling).
152147
*
153148
* Create a sample of this RDD using variable sampling rates for different keys as specified by
154-
* `fractions`, a key to sampling rate map.
155-
*
156-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
157-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
158-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
159-
* the RDD to create a sample size that's exactly equal to the sum of
149+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
150+
* RDD, to produce a sample of size that's approximately equal to the sum of
160151
* math.ceil(numItems * samplingRate) over all key values.
161152
*
162-
* Use Utils.random.nextLong as the default seed for the random number generator
153+
* Use Utils.random.nextLong as the default seed for the random number generator.
163154
*/
164155
def sampleByKey(withReplacement: Boolean,
165-
fractions: JMap[K, Double],
166-
exact: Boolean): JavaPairRDD[K, V] =
167-
sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
156+
fractions: JMap[K, Double]): JavaPairRDD[K, V] =
157+
sampleByKey(withReplacement, fractions, Utils.random.nextLong)
168158

169159
/**
170-
* Return a subset of this RDD sampled by key (via stratified sampling).
160+
* ::Experimental::
171161
*
172-
* Create a sample of this RDD using variable sampling rates for different keys as specified by
173-
* `fractions`, a key to sampling rate map.
162+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
163+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
174164
*
175-
* Produce a sample of size that's approximately equal to the sum of
176-
* math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
177-
* simple random sampling.
165+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
166+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
167+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
168+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
169+
* two additional passes.
178170
*/
179-
def sampleByKey(withReplacement: Boolean,
171+
@Experimental
172+
def sampleByKeyExact(withReplacement: Boolean,
180173
fractions: JMap[K, Double],
181174
seed: Long): JavaPairRDD[K, V] =
182-
sampleByKey(withReplacement, fractions, false, seed)
175+
new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed))
183176

184177
/**
185-
* Return a subset of this RDD sampled by key (via stratified sampling).
178+
* ::Experimental::
186179
*
187-
* Create a sample of this RDD using variable sampling rates for different keys as specified by
188-
* `fractions`, a key to sampling rate map.
180+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
181+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
189182
*
190-
* Produce a sample of size that's approximately equal to the sum of
191-
* math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
192-
* simple random sampling.
183+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
184+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
185+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
186+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
187+
* two additional passes.
193188
*
194-
* Use Utils.random.nextLong as the default seed for the random number generator
189+
* Use Utils.random.nextLong as the default seed for the random number generator.
195190
*/
196-
def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
197-
sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
191+
@Experimental
192+
def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
193+
sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong)
198194

199195
/**
200196
* Return the union of this RDD and another one. Any identical elements will appear multiple

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,33 +197,57 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
197197
* Return a subset of this RDD sampled by key (via stratified sampling).
198198
*
199199
* Create a sample of this RDD using variable sampling rates for different keys as specified by
200-
* `fractions`, a key to sampling rate map.
201-
*
202-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
203-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
204-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use
205-
* additional passes over the RDD to create a sample size that's exactly equal to the sum of
206-
* math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
207-
* without replacement, we need one additional pass over the RDD to guarantee sample size;
208-
* when sampling with replacement, we need two additional passes.
200+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
201+
* RDD, to produce a sample of size that's approximately equal to the sum of
202+
* math.ceil(numItems * samplingRate) over all key values.
209203
*
210204
* @param withReplacement whether to sample with or without replacement
211205
* @param fractions map of specific keys to sampling rates
212206
* @param seed seed for the random number generator
213-
* @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
214207
* @return RDD containing the sampled subset
215208
*/
216209
def sampleByKey(withReplacement: Boolean,
217210
fractions: Map[K, Double],
218-
exact: Boolean = false,
219-
seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
211+
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
212+
213+
require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
214+
215+
val samplingFunc = if (withReplacement) {
216+
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
217+
} else {
218+
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
219+
}
220+
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
221+
}
222+
223+
/**
224+
* ::Experimental::
225+
*
226+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
227+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
228+
*
229+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
230+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
231+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
232+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
233+
* two additional passes.
234+
*
235+
* @param withReplacement whether to sample with or without replacement
236+
* @param fractions map of specific keys to sampling rates
237+
* @param seed seed for the random number generator
238+
* @return RDD containing the sampled subset
239+
*/
240+
@Experimental
241+
def sampleByKeyExact(withReplacement: Boolean,
242+
fractions: Map[K, Double],
243+
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
220244

221245
require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
222246

223247
val samplingFunc = if (withReplacement) {
224-
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
248+
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed)
225249
} else {
226-
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
250+
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
227251
}
228252
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
229253
}

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,12 +1239,28 @@ public Tuple2<Integer, Integer> call(Integer i) {
12391239
Assert.assertTrue(worCounts.size() == 2);
12401240
Assert.assertTrue(worCounts.get(0) > 0);
12411241
Assert.assertTrue(worCounts.get(1) > 0);
1242-
JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
1242+
}
1243+
1244+
@Test
1245+
@SuppressWarnings("unchecked")
1246+
public void sampleByKeyExact() {
1247+
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
1248+
JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
1249+
new PairFunction<Integer, Integer, Integer>() {
1250+
@Override
1251+
public Tuple2<Integer, Integer> call(Integer i) {
1252+
return new Tuple2<Integer, Integer>(i % 2, 1);
1253+
}
1254+
});
1255+
Map<Integer, Object> fractions = Maps.newHashMap();
1256+
fractions.put(0, 0.5);
1257+
fractions.put(1, 1.0);
1258+
JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKeyExact(true, fractions, 1L);
12431259
Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
12441260
Assert.assertTrue(wrExactCounts.size() == 2);
12451261
Assert.assertTrue(wrExactCounts.get(0) == 2);
12461262
Assert.assertTrue(wrExactCounts.get(1) == 4);
1247-
JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
1263+
JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKeyExact(false, fractions, 1L);
12481264
Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
12491265
Assert.assertTrue(worExactCounts.size() == 2);
12501266
Assert.assertTrue(worExactCounts.get(0) == 2);

0 commit comments

Comments
 (0)