[SPARK-9478] [ml] Add class weights to Random Forest #9008
Conversation
Test build #43318 has finished for PR 9008 at commit
Should specify return type here.
Is the reason that you can't just modify buildMetadata to accept an RDD[WeightedLabeledPoint] because you are trying not to change the MLlib implementation?
Thank you very much for your comment.
- I will add the return type in my next push.
- Yes, you are right, I don't want to change the MLlib implementation yet. I will leave it as a TODO until we have a standard way to represent weighted labeled points.
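As an aside on the representation question above, a minimal sketch of what a weighted labeled point could look like (hypothetical names, with `Array[Double]` standing in for Spark's `Vector` type to keep it self-contained; Spark ML later standardized on a similar `Instance(label, weight, features)` case class):

```scala
// Hypothetical weighted training example: a labeled point plus a sample weight.
case class WeightedPoint(label: Double, weight: Double, features: Array[Double])

// A plain labeled point is just the weight == 1.0 special case.
def fromUnweighted(label: Double, features: Array[Double]): WeightedPoint =
  WeightedPoint(label, weight = 1.0, features)
```

Keeping the weight alongside the label rather than in a parallel RDD means it rides through shuffles and sampling with the rest of the instance.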
Test build #43460 has finished for PR 9008 at commit

Test build #43482 timed out for PR 9008 at commit

Test build #43503 has finished for PR 9008 at commit

Test build #43531 timed out for PR 9008 at commit

Test build #43553 has finished for PR 9008 at commit

Jenkins failed tests unrelated to this patch. Let's try again.

Test build #43572 has finished for PR 9008 at commit

Test build #43588 has finished for PR 9008 at commit

@sethah I have incorporated your comments in the latest patch. Thank you! @jkbradley Do you have any comments or suggestions? Much appreciated.

Test build #45206 has finished for PR 9008 at commit

retest this please

Test build #45315 has finished for PR 9008 at commit
  final class RandomForestClassifier(override val uid: String)
    extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
-   with RandomForestParams with TreeClassifierParams {
+   with RandomForestParams with TreeClassifierParams with HasWeightCol{
nit: Space after HasWeightCol
Thanks for working on this! Minor: PR title says […]
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

  val model1 = pipeline.fit(dataset)
  val model2 = pipeline.fit(dataset, rf.weightCol->"weight")
ditto: space around ->
  /**
   * Inject the sample weight to sub-sample weights of the baggedPoints
   */
  private[impl] def reweightSubSampleWeights(
There is a TODO in BaggedPoint.scala for accepting weighted instances. This might be a good time to address that. If not, we will have to implement this in this JIRA, fix Bagged Point in another JIRA, and then return to this, likely in a third JIRA. Thoughts?
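For context, the idea behind `reweightSubSampleWeights` can be sketched outside of Spark like this (a toy `Bagged` class, not Spark's actual `BaggedPoint`): the effective weight a tree sees for a point becomes its bootstrap count times the instance weight.

```scala
// Toy stand-in for Spark's BaggedPoint: one datum plus, per tree in the
// ensemble, how many times the bootstrap sample drew it.
case class Bagged[T](datum: T, subsampleWeights: Array[Double])

// Fold the per-instance sample weight into every tree's subsample weight.
def reweight[T](bp: Bagged[T], instanceWeight: Double): Bagged[T] =
  bp.copy(subsampleWeights = bp.subsampleWeights.map(_ * instanceWeight))
```

Doing this inside BaggedPoint itself (per the TODO) would avoid the extra pass over the data that a post-hoc reweighting step implies.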
@rotationsymmetry I made a pass on this, mostly minor comments. Thanks for working on this, it would be great to get it merged in!

@sethah Thank you very much for your review. I will incorporate the changes in the next few days. Regarding the TODO in BaggedPoint.scala, I want to look into the details to find out the scope of the change.
I noticed a problem with the current implementation regarding the minimum-instances-per-node check. I checked scikit-learn, and they track the actual raw sample counts (unweighted) as well as the sample weights. They use […]
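To make the raw-count vs. weight distinction concrete, here is a hedged sketch of checking a candidate leaf against both criteria (illustrative names, not Spark's API; this mirrors the split scikit-learn makes between `min_samples_leaf` on raw counts and `min_weight_fraction_leaf` on weights):

```scala
// Illustrative leaf validity check: the raw (unweighted) instance count and
// the leaf's share of total sample weight are tested independently, so a
// few heavily weighted points cannot satisfy a minimum-count requirement.
def validLeaf(rawCount: Long, leafWeight: Double, totalWeight: Double,
              minInstances: Int, minWeightFraction: Double): Boolean =
  rawCount >= minInstances && leafWeight >= minWeightFraction * totalWeight
```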
Another issue is that the information gain for candidate splits is not computed correctly with fractional samples. This is because the information gain calculation here uses the sample counts, which are converted to […]
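The fractional-sample issue above is why weighted class counts need to stay `Double` end to end. A self-contained sketch of Gini-based gain over weighted counts (not Spark's actual `ImpurityCalculator`):

```scala
// Gini impurity over weighted class counts, kept as Double so fractional
// sample weights are not truncated to integer counts.
def gini(weightedCounts: Array[Double]): Double = {
  val total = weightedCounts.sum
  if (total == 0.0) 0.0
  else 1.0 - weightedCounts.map(w => (w / total) * (w / total)).sum
}

// Information gain of a candidate split, weighting each child's impurity
// by its share of the parent's total weight.
def gain(left: Array[Double], right: Array[Double]): Double = {
  val parent = left.zip(right).map { case (l, r) => l + r }
  val n = parent.sum
  gini(parent) - (left.sum / n) * gini(left) - (right.sum / n) * gini(right)
}
```

With integer counts a weight of 0.5 would round away entirely, silently changing which split wins; with `Double` counts the gain is exact.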
@rotationsymmetry: Will you have time to work on this? I am more than happy to send a PR to your PR if you do not have time. @jkbradley @dbtsai Would you mind chiming in on the issue mentioned above about minimum instances per node?

cc @MLnick thoughts on the above comments?
@sethah So to avoid adding any overhead from computing stats for both these params, one option would be to selectively compute only the stats that are required (e.g. if they request […])
@holdenk Thanks for the feedback. Upon some further thought, I think that a.) we need to compute the statistics needed for both […]

I am going to have a PR for this ready soon, which will incorporate changes submitted in this PR. I created two JIRAs for issues that I encountered when preparing this PR and submitted patches for each. They are: […]
@rotationsymmetry Could you please close this?
Closes apache#11610 Closes apache#15411 Closes apache#15501 Closes apache#12613 Closes apache#12518 Closes apache#12026 Closes apache#15524 Closes apache#12693 Closes apache#12358 Closes apache#15588 Closes apache#15635 Closes apache#15678 Closes apache#14699 Closes apache#9008
This PR adds weight support to the following Predictors in ML:
- DecisionTreeClassifier
- DecisionTreeRegressor
- RandomForestClassifier
- RandomForestRegressor

cc @jkbradley