[SPARK-8971][ML] Add stratified sampling to ML CrossValidator and TrainValidationSplit #14321
Conversation
Test build #62738 has finished for PR 14321 at commit
val keys = pairData.keys.distinct.collect()
val weights: Array[scala.collection.Map[Any, Double]] =
  Array(keys.map((_, $(trainRatio))).toMap, keys.map((_, 1 - $(trainRatio))).toMap)
val splitsWithKeys = pairData.randomSplitByKey(weights, exact = true, $(seed))
Does it perhaps make sense to have a convenience version of randomSplitByKey that takes an Array[Double] for weights and applies the same sampling weight to each key? I would expect that, the vast majority of the time, the use case is to split the dataset into folds with the same sampling ratio across keys.
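A rough sketch of what such a convenience helper could look like. This is purely illustrative: the helper name and the return shape are hypothetical, and the per-key randomSplitByKey it delegates to is the method proposed in this PR, with its signature inferred from the diff excerpt above.

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Hypothetical convenience helper: one weight per split, applied uniformly to every key,
// delegating to the per-key variant proposed in this PR
// (assumed signature: randomSplitByKey(weights: Array[Map[K, Double]], exact: Boolean, seed: Long)).
def randomSplitByKeyUniform[K: ClassTag, V: ClassTag](
    rdd: RDD[(K, V)],
    weights: Array[Double],
    exact: Boolean,
    seed: Long): Array[RDD[(K, V)]] = {
  val keys = rdd.keys.distinct.collect()
  val perKeyWeights: Array[scala.collection.Map[K, Double]] =
    weights.map(w => keys.map(k => (k, w)).toMap: scala.collection.Map[K, Double])
  rdd.randomSplitByKey(perKeyWeights, exact, seed)
}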
    n: Int,
    exact: Boolean): Unit = {

def countByKey[K, V](xs: TraversableOnce[(K, V)]): Map[K, Int] = {
Do we need this? We should be able to use rdd.countByKey in L824 (totalCounts), L835 & L836 (sampleCounts and complementCounts) below? (You've basically done that in the test for kFoldStratified).
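For reference, a minimal sketch of the suggestion. Here pairData stands in for the keyed RDD used in this patch; the variable names are illustrative rather than the actual identifiers in StratifiedSamplingUtils.

// rdd.countByKey() already yields per-key counts as a Map[K, Long],
// so a local countByKey helper over a TraversableOnce may be unnecessary.
val totalCounts: scala.collection.Map[Any, Long] = pairData.countByKey()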
Has the progress on this initiative stalled for any reason? Maybe I could be of help. @sethah
I just happened to look at this PR. Is this still WIP, or waiting for more review comments? If it is simply that the author is not currently able to proceed further, then maybe it'd be better to close this for now.
Hi @sethah, any plans to work on it again?
What changes were proposed in this pull request?
This patch adds the ability to do stratified sampling in cross validation for ML pipelines. This is accomplished by modifying some of the methods in StratifiedSamplingUtils to support multiple splits instead of a single subsample of the data. A method is added to PairRDDFunctions to support randomSplitByKey. Please see the detailed explanation below.

How was this patch tested?
Unit tests were added to PairRDDFunctionsSuite, MLUtilsSuite, CrossValidatorSuite, and TrainValidationSuite.

Algorithm changes
Currently, Spark implements a stratified sampling function on pair RDDs via the methods sampleByKeyExact and sampleByKey. These methods call a stratified sampling routine that is implemented in StratifiedSamplingUtils. The underlying algorithm is described here in the paper by Xiangrui Meng. When exact stratified samples are required, the algorithm makes an extra pass through the data. Each item is mapped onto the interval [0, 1] (for sampling without replacement), and we expect that, say for a 50% sample, we will split the interval at 0.5 and accept the items that fall below that threshold. Items near 0 are highly likely to be accepted, while items near 1 are highly unlikely to be accepted. Items near 0.5 are uncertain and are added to a waitlist on the first pass. The items in the waitlist are sorted and used to determine the exact split point that produces a 50/50 split.

This patch modifies the routine to produce multiple splits by generating multiple waitlists on the first pass. Each waitlist is sorted to determine the exact split points, and then we can sample as normal.
One potential concern is that, if this is used for a large number of splits, it may degrade to the point where sorting the entire dataset would be quicker, since the waitlists get closer and closer together. It could also cause OOM errors on the driver if too many waitlist items are collected. Still, before this patch there was no way to actually take a single split of the data, since sampleByKey does not return the complement of the sample. This patch fixes that as well.
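For contrast, a brief sketch of the difference, reusing pairData, keys, and seed as in the diff excerpt earlier. sampleByKeyExact is the existing Spark API; randomSplitByKey and its return shape are assumptions about this PR's proposal.

// Existing API: only the sampled RDD comes back; its complement is not returned.
val fractions: scala.collection.Map[Any, Double] = keys.map(k => (k, 0.5)).toMap
val sampleOnly = pairData.sampleByKeyExact(false, fractions, seed) // false = without replacement

// Proposed API: all splits are returned, so a 50% sample and its complement arrive together
// (return shape assumed to be one keyed RDD per split).
val Array(sample, complement) =
  pairData.randomSplitByKey(Array(fractions, fractions), exact = true, seed)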
ML API

This patch also allows users to specify a stratified column in the CrossValidator and TrainValidationSplit estimators. This is done by converting the input dataframe to a pair RDD and calling the randomSplitByKey method. This is exposed via a setStratifiedCol parameter which, if set, will use exact stratified splits for cross validation.
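A hedged usage sketch of the new parameter on TrainValidationSplit. setStratifiedCol is the parameter added by this patch; everything else is standard spark.ml tuning code, and the dataset and column names are illustrative.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)
  .setStratifiedCol("label") // new in this PR: exact stratified train/validation split on "label"

// training is assumed to be a DataFrame with "features" and "label" columns.
val model = tvs.fit(training)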
Future considerations

This can be implemented as a function on dataframes in the future, if there is interest. It is somewhat inconvenient to convert the dataframe to a pair RDD, perform sampling, and then convert back to a dataframe.
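For illustration, this is roughly the round trip the paragraph above refers to, sketched as a helper a user would otherwise write by hand. randomSplitByKey and its return shape are assumptions about this PR; the rest is standard Spark SQL/RDD code.

import org.apache.spark.sql.{DataFrame, SparkSession}

def stratifiedTrainTest(spark: SparkSession, df: DataFrame, stratifiedCol: String,
                        trainRatio: Double, seed: Long): (DataFrame, DataFrame) = {
  // DataFrame -> pair RDD keyed by the stratification column.
  val pairData = df.rdd.map(row => (row.getAs[Any](stratifiedCol), row))
  val keys = pairData.keys.distinct.collect()
  val weights = Array(keys.map((_, trainRatio)).toMap, keys.map((_, 1 - trainRatio)).toMap)
  // Proposed PairRDDFunctions method from this PR (return shape assumed: one keyed RDD per split).
  val Array(train, test) = pairData.randomSplitByKey(weights, exact = true, seed)
  // ...and back to DataFrames, dropping the keys.
  (spark.createDataFrame(train.values, df.schema), spark.createDataFrame(test.values, df.schema))
}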