
Conversation

@sethah (Contributor) commented Jul 22, 2016

What changes were proposed in this pull request?

This patch adds support for stratified sampling in cross validation for ML pipelines. It does so by modifying methods in StratifiedSamplingUtils to support multiple splits rather than a single subsample of the data, and by adding a randomSplitByKey method to PairRDDFunctions. See the detailed explanation below.

How was this patch tested?

Unit tests were added to PairRDDFunctionsSuite, MLUtilsSuite, CrossValidatorSuite, and TrainValidationSuite.

Algorithm changes

Currently, Spark implements stratified sampling on pair RDDs via the methods sampleByKeyExact and sampleByKey, which call a stratified sampling routine implemented in StratifiedSamplingUtils. The underlying algorithm is described here in the paper by Xiangrui Meng. When exact stratified samples are required, the algorithm makes an extra pass through the data. Each item is mapped onto the interval [0, 1] (for sampling without replacement); for, say, a 50% sample, we expect to split the interval at 0.5 and accept the items that fell below that threshold. Items near 0 are highly likely to be accepted, items near 1 are highly unlikely to be accepted, and items near 0.5 are uncertain, so they are added to a waitlist on the first pass. The waitlist is then sorted and used to determine the exact split point that produces a 50/50 split.

[figure: items mapped onto [0, 1], with a waitlist band around the 0.5 threshold]
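To make the single-threshold mechanism concrete, here is a plain-Python sketch of exact-fraction sampling with a waitlist. This is illustrative only, not the Spark/Scala implementation; the function name, band width (`margin`), and seed are assumptions.

```python
import random

def exact_sample(items, fraction, seed=42, margin=0.1):
    # One uniform draw per item: well below the threshold -> accept,
    # well above -> reject, inside the uncertain band -> waitlist.
    rng = random.Random(seed)
    target = round(len(items) * fraction)
    accepted, waitlist = [], []
    for item in items:
        x = rng.random()
        if x < fraction - margin:
            accepted.append(item)
        elif x < fraction + margin:
            waitlist.append((x, item))
    # Only the waitlist (a small fraction of the data) is ever sorted;
    # taking its smallest entries tops the sample up to exactly `target`.
    waitlist.sort()
    need = max(0, target - len(accepted))
    accepted.extend(item for _, item in waitlist[:need])
    return accepted

print(len(exact_sample(list(range(1000)), 0.5)))  # 500
```

The point of the waitlist is that only the draws near the threshold need sorting, rather than the whole dataset.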

This patch modifies the routine to produce multiple splits by generating multiple waitlists on the first pass. Each waitlist is sorted to determine the exact split points and then we can sample as normal.

[figure: multiple exact split points, each with its own waitlist band]

One potential concern is that for a large number of splits, the waitlists get closer and closer together and the approach may degrade to the point where sorting the entire dataset would be quicker. It could also cause OOM errors on the driver if too many waitlists are collected. Still, before this patch there was no way to take a single exact split of the data, since sampleByKey does not return the complement of the sample; this patch fixes that as well.
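The multi-waitlist idea can be sketched the same way in plain Python (again illustrative, not the patch's Scala code): one uniform draw per item, one waitlist per split boundary, and every item lands in exactly one split, so the complement of any sample comes for free.

```python
import random

def exact_splits(items, weights, seed=7, margin=0.1):
    rng = random.Random(seed)
    draws = [(rng.random(), item) for item in items]
    n = len(items)
    # Provisional boundaries at cumulative weights, e.g. [0.25, 0.5, 0.75].
    bounds, acc = [], 0.0
    for w in weights[:-1]:
        acc += w
        bounds.append(acc)
    # One waitlist per boundary: sort only the draws that fell near it,
    # then pick the cutoff that makes the cumulative counts exact.
    cutoffs = []
    for b in bounds:
        waitlist = sorted(x for x, _ in draws if abs(x - b) < margin)
        below = sum(1 for x, _ in draws if x <= b - margin)
        k = round(b * n) - below  # waitlisted draws that must fall below the cut
        cutoffs.append(waitlist[k - 1] if k > 0 else b - margin)
    # Assign each item to the split its draw falls into.
    splits = [[] for _ in weights]
    for x, item in draws:
        splits[sum(1 for c in cutoffs if x > c)].append(item)
    return splits

parts = exact_splits(list(range(1000)), [0.25, 0.25, 0.25, 0.25])
print([len(p) for p in parts])  # [250, 250, 250, 250]
```

This also shows the degradation concern directly: as the number of splits grows, the bands around neighboring boundaries approach each other until nearly every draw is waitlisted.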

ML API

This patch also allows users to specify a stratification column in the CrossValidator and TrainValidationSplit estimators. This is done by converting the input DataFrame to a pair RDD and calling the randomSplitByKey method. It is exposed via a stratifiedCol param (set with setStratifiedCol) which, when set, causes exact stratified splits to be used for cross validation.

Future considerations

This could be implemented as a function on DataFrames in the future, if there is interest. It is somewhat inconvenient to convert the DataFrame to a pair RDD, perform the sampling, and then convert back to a DataFrame.

@sethah (Contributor, Author) commented Jul 22, 2016

cc @MLnick @hhbyyh @mengxr I believe there is still interest in stratified sampling methods. Could you provide feedback/review on this patch? Thanks!

@SparkQA commented Jul 22, 2016

Test build #62738 has finished for PR 14321 at commit 37be0b5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val keys = pairData.keys.distinct.collect()
val weights: Array[scala.collection.Map[Any, Double]] =
Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap)
val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, $(seed))
Review comment (Contributor):

Does it perhaps make sense to have a convenience version of randomSplitByKey that takes an Array[Double] for weights and applies the same sampling weight to each key? I would expect that in the vast majority of cases the use case is to split the dataset into folds with the same sampling ratio across keys.
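For illustration, the suggested convenience overload amounts to expanding a flat weight array into one per-key map per split. A plain-Python sketch (`broadcast_weights` is a hypothetical name, not part of any Spark API):

```python
def broadcast_weights(keys, weights):
    # Expand a flat list of split weights into one {key: weight} map per
    # split, so the same sampling ratio applies to every stratum.
    return [{k: w for k in keys} for w in weights]

print(broadcast_weights(["a", "b"], [0.8, 0.2]))
# [{'a': 0.8, 'b': 0.8}, {'a': 0.2, 'b': 0.2}]
```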

n: Int,
exact: Boolean): Unit = {

def countByKey[K, V](xs: TraversableOnce[(K, V)]): Map[K, Int] = {
Review comment — @MLnick (Contributor) commented Aug 11, 2016:

Do we need this? We should be able to use rdd.countByKey in L824 (totalCounts), L835 & L836 (sampleCounts and complementCounts) below? (You've basically done that in the test for kFoldStratified).
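In plain Python, what the comment suggests relying on instead — a per-key pair count, as Spark's `RDD.countByKey` provides — is a one-liner (`count_by_key` here is just an illustration, not the Spark method):

```python
from collections import Counter

def count_by_key(pairs):
    # Count how many (key, value) pairs share each key.
    return Counter(k for k, _ in pairs)

print(count_by_key([("a", 1), ("a", 2), ("b", 3)]))
# Counter({'a': 2, 'b': 1})
```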

@pramitchoudhary commented:
Has progress on this initiative stalled for any reason? Maybe I could be of help. @sethah

@HyukjinKwon (Member) commented:

I just happened to look at this PR. Is this still WIP, or waiting for more review comments? If the author is simply not currently able to take this further, then maybe it'd be better to close it for now.

@sethah sethah closed this Feb 13, 2017
@idlecool

Hi @sethah, any plans to work on it again?
