Skip to content

Commit b715aa0

Browse files
dorxmengxr
authored andcommitted
[SPARK-2937] Separate out samplyByKeyExact as its own API in PairRDDFunction
To enable Python consistency and `Experimental` label of the `sampleByKeyExact` API. Author: Doris Xin <[email protected]> Author: Xiangrui Meng <[email protected]> Closes apache#1866 from dorx/stratified and squashes the following commits: 0ad97b2 [Doris Xin] reviewer comments. 2948aae [Doris Xin] remove unrelated changes e990325 [Doris Xin] Merge branch 'master' into stratified 555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API 616e55c [Doris Xin] merge master 245439e [Doris Xin] moved minSamplingRate to getUpperBound eaf5771 [Doris Xin] bug fixes. 17a381b [Doris Xin] fixed a merge issue and a failed unit ea7d27f [Doris Xin] merge master b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for waitlisting add unit tests for Java b3013a4 [Xiangrui Meng] move math3 back to test scope eecee5f [Doris Xin] Merge branch 'master' into stratified f4c21f3 [Doris Xin] Reviewer comments a10e68d [Doris Xin] style fix a2bf756 [Doris Xin] Merge branch 'master' into stratified 680b677 [Doris Xin] use mapPartitionWithIndex instead 9884a9f [Doris Xin] style fix bbfb8c9 [Doris Xin] Merge branch 'master' into stratified ee9d260 [Doris Xin] addressed reviewer comments 6b5b10b [Doris Xin] Merge branch 'master' into stratified 254e03c [Doris Xin] minor fixes and Java API. 4ad516b [Doris Xin] remove unused imports from PairRDDFunctions bd9dc6e [Doris Xin] unit bug and style violation fixed 1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check 944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate 0214a76 [Doris Xin] cleanUp 90d94c0 [Doris Xin] merge master 9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey 7327611 [Doris Xin] merge master 50581fc [Doris Xin] added a TODO for logging in python 46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being passed into the aggregate function 7e1a481 [Doris Xin] changed the permission on SamplingUtil 1d413ce [Doris Xin] fixed checkstyle issues 9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that guarantees exact sample size e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample 7cab53a [Doris Xin] fixed import bug in rdd.py ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD 1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
1 parent 28dcbb5 commit b715aa0

File tree

4 files changed

+216
-128
lines changed

4 files changed

+216
-128
lines changed

core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
133133
* Return a subset of this RDD sampled by key (via stratified sampling).
134134
*
135135
* Create a sample of this RDD using variable sampling rates for different keys as specified by
136-
* `fractions`, a key to sampling rate map.
137-
*
138-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
139-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
140-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
141-
* the RDD to create a sample size that's exactly equal to the sum of
136+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
137+
* RDD, to produce a sample of size that's approximately equal to the sum of
142138
* math.ceil(numItems * samplingRate) over all key values.
143139
*/
144140
def sampleByKey(withReplacement: Boolean,
145141
fractions: JMap[K, Double],
146-
exact: Boolean,
147142
seed: Long): JavaPairRDD[K, V] =
148-
new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))
143+
new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed))
149144

150145
/**
151146
* Return a subset of this RDD sampled by key (via stratified sampling).
152147
*
153148
* Create a sample of this RDD using variable sampling rates for different keys as specified by
154-
* `fractions`, a key to sampling rate map.
155-
*
156-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
157-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
158-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over
159-
* the RDD to create a sample size that's exactly equal to the sum of
149+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
150+
* RDD, to produce a sample of size that's approximately equal to the sum of
160151
* math.ceil(numItems * samplingRate) over all key values.
161152
*
162-
* Use Utils.random.nextLong as the default seed for the random number generator
153+
* Use Utils.random.nextLong as the default seed for the random number generator.
163154
*/
164155
def sampleByKey(withReplacement: Boolean,
165-
fractions: JMap[K, Double],
166-
exact: Boolean): JavaPairRDD[K, V] =
167-
sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
156+
fractions: JMap[K, Double]): JavaPairRDD[K, V] =
157+
sampleByKey(withReplacement, fractions, Utils.random.nextLong)
168158

169159
/**
170-
* Return a subset of this RDD sampled by key (via stratified sampling).
171-
*
172-
* Create a sample of this RDD using variable sampling rates for different keys as specified by
173-
* `fractions`, a key to sampling rate map.
160+
* ::Experimental::
161+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
162+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
174163
*
175-
* Produce a sample of size that's approximately equal to the sum of
176-
* math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
177-
* simple random sampling.
164+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
165+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
166+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
167+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
168+
* two additional passes.
178169
*/
179-
def sampleByKey(withReplacement: Boolean,
170+
@Experimental
171+
def sampleByKeyExact(withReplacement: Boolean,
180172
fractions: JMap[K, Double],
181173
seed: Long): JavaPairRDD[K, V] =
182-
sampleByKey(withReplacement, fractions, false, seed)
174+
new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed))
183175

184176
/**
185-
* Return a subset of this RDD sampled by key (via stratified sampling).
177+
* ::Experimental::
178+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
179+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
186180
*
187-
* Create a sample of this RDD using variable sampling rates for different keys as specified by
188-
* `fractions`, a key to sampling rate map.
189-
*
190-
* Produce a sample of size that's approximately equal to the sum of
191-
* math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via
192-
* simple random sampling.
181+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
182+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
183+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
184+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
185+
* two additional passes.
193186
*
194-
* Use Utils.random.nextLong as the default seed for the random number generator
187+
* Use Utils.random.nextLong as the default seed for the random number generator.
195188
*/
196-
def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
197-
sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
189+
@Experimental
190+
def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
191+
sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong)
198192

199193
/**
200194
* Return the union of this RDD and another one. Any identical elements will appear multiple

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
197197
* Return a subset of this RDD sampled by key (via stratified sampling).
198198
*
199199
* Create a sample of this RDD using variable sampling rates for different keys as specified by
200-
* `fractions`, a key to sampling rate map.
201-
*
202-
* If `exact` is set to false, create the sample via simple random sampling, with one pass
203-
* over the RDD, to produce a sample of size that's approximately equal to the sum of
204-
* math.ceil(numItems * samplingRate) over all key values; otherwise, use
205-
* additional passes over the RDD to create a sample size that's exactly equal to the sum of
206-
* math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
207-
* without replacement, we need one additional pass over the RDD to guarantee sample size;
208-
* when sampling with replacement, we need two additional passes.
200+
* `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
201+
* RDD, to produce a sample of size that's approximately equal to the sum of
202+
* math.ceil(numItems * samplingRate) over all key values.
209203
*
210204
* @param withReplacement whether to sample with or without replacement
211205
* @param fractions map of specific keys to sampling rates
212206
* @param seed seed for the random number generator
213-
* @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
214207
* @return RDD containing the sampled subset
215208
*/
216209
def sampleByKey(withReplacement: Boolean,
217210
fractions: Map[K, Double],
218-
exact: Boolean = false,
219-
seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
211+
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
212+
213+
require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
214+
215+
val samplingFunc = if (withReplacement) {
216+
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
217+
} else {
218+
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
219+
}
220+
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
221+
}
222+
223+
/**
224+
* ::Experimental::
225+
* Return a subset of this RDD sampled by key (via stratified sampling) containing exactly
226+
* math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key).
227+
*
228+
* This method differs from [[sampleByKey]] in that we make additional passes over the RDD to
229+
* create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate)
230+
* over all key values with a 99.99% confidence. When sampling without replacement, we need one
231+
* additional pass over the RDD to guarantee sample size; when sampling with replacement, we need
232+
* two additional passes.
233+
*
234+
* @param withReplacement whether to sample with or without replacement
235+
* @param fractions map of specific keys to sampling rates
236+
* @param seed seed for the random number generator
237+
* @return RDD containing the sampled subset
238+
*/
239+
@Experimental
240+
def sampleByKeyExact(withReplacement: Boolean,
241+
fractions: Map[K, Double],
242+
seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
220243

221244
require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
222245

223246
val samplingFunc = if (withReplacement) {
224-
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed)
247+
StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed)
225248
} else {
226-
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed)
249+
StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed)
227250
}
228251
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
229252
}

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,12 +1239,28 @@ public Tuple2<Integer, Integer> call(Integer i) {
12391239
Assert.assertTrue(worCounts.size() == 2);
12401240
Assert.assertTrue(worCounts.get(0) > 0);
12411241
Assert.assertTrue(worCounts.get(1) > 0);
1242-
JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, true, 1L);
1242+
}
1243+
1244+
@Test
1245+
@SuppressWarnings("unchecked")
1246+
public void sampleByKeyExact() {
1247+
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3);
1248+
JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
1249+
new PairFunction<Integer, Integer, Integer>() {
1250+
@Override
1251+
public Tuple2<Integer, Integer> call(Integer i) {
1252+
return new Tuple2<Integer, Integer>(i % 2, 1);
1253+
}
1254+
});
1255+
Map<Integer, Object> fractions = Maps.newHashMap();
1256+
fractions.put(0, 0.5);
1257+
fractions.put(1, 1.0);
1258+
JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKeyExact(true, fractions, 1L);
12431259
Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) wrExact.countByKey();
12441260
Assert.assertTrue(wrExactCounts.size() == 2);
12451261
Assert.assertTrue(wrExactCounts.get(0) == 2);
12461262
Assert.assertTrue(wrExactCounts.get(1) == 4);
1247-
JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, fractions, true, 1L);
1263+
JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKeyExact(false, fractions, 1L);
12481264
Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) worExact.countByKey();
12491265
Assert.assertTrue(worExactCounts.size() == 2);
12501266
Assert.assertTrue(worExactCounts.get(0) == 2);

0 commit comments

Comments
 (0)