[SPARK-14975][ML] Fixed GBTClassifier to predict probability per training instance and fixed interfaces #16441
Conversation
Test build #70759 has finished for PR 16441 at commit

Jenkins, retest this please

Test build #70760 has finished for PR 16441 at commit
Thanks for the PR; I do want to get this fixed. However, I don't think this is the right way to predict probabilities for GBTs. I believe it should depend on the loss used. E.g., check out page 8 of Friedman (1999), "Greedy Function Approximation: A Gradient Boosting Machine".
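For reference, under logistic loss Friedman's two-class formulation maps the ensemble margin F(x) to a probability via P(y = 1 | x) = 1 / (1 + exp(-2 F(x))). A minimal, self-contained sketch of that mapping (names like `margin` and `probabilityOfPositive` are illustrative, not Spark API):

```scala
// Illustrative sketch (not Spark API): probability of the positive class
// from a GBT margin under logistic loss, following Friedman (1999).
object GbtProbabilitySketch {
  // F(x): weighted sum of per-tree predictions (each in {-1.0, +1.0}).
  def margin(treePredictions: Array[Double], treeWeights: Array[Double]): Double = {
    var acc = 0.0
    var i = 0
    while (i < treePredictions.length) {
      acc += treePredictions(i) * treeWeights(i)
      i += 1
    }
    acc
  }

  // For log loss on {-1, +1} labels, the modeled probability is
  // P(y = 1 | x) = 1 / (1 + exp(-2 F(x))).
  def probabilityOfPositive(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}
```

The factor of 2 in the exponent comes from the {-1, +1} label encoding used by the loss.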
Test build #70935 has finished for PR 16441 at commit
Thanks, I've updated the PR based on your comment. The only disadvantage of the current code is that I do the probability computation within the classifier, but it seems like it should be moved into the LogLoss.scala class. However, it's not a problem right now because GBTClassifier only uses logistic loss, and other learners would probably have to be modified in a similar way as well.
Force-pushed from 2b842e5 to 9def0ca
Test build #70938 has finished for PR 16441 at commit

Test build #70939 has finished for PR 16441 at commit
@jkbradley I've updated based on your comments; please take another look, thanks!
sethah left a comment:

Thanks for the patch. I made a first pass.
put this back on one line
done
This is actually not correct since the constructor was private[ml] before. Since this has always been private, and we aren't actually using it anywhere, I think we can remove this constructor entirely.
removed
The @Since tag isn't needed since it's private.
Removed the @Since tag.
We should import org.apache.spark.ml.linalg.BLAS and call BLAS.dot here and in predict.
It looks like BLAS.dot is only defined for Vector, but these are both arrays. I'm worried that this may degrade performance. Is this specifically what you are looking for:

BLAS.dot(Vectors.dense(treePredictions), Vectors.dense(_treeWeights))

Is the extra dense vector allocation worth it?
Yeah, I see it's not quite the same as in other places. We can leave it
oh ok, thank you for confirming
My concern is that this is hard-coded to logistic loss. Maybe we can add a static method to GBTClassificationModel:

```scala
private def classProbability(label: Int, loss: String, rawPrediction: Double): Double = {
  loss match {
    case "logistic" => ...
    case _ => throw new Exception("Only logistic loss is supported ...")
  }
}
```
done
Just use defaults here. And I'm in favor of only setting parameters that matter for the given test, otherwise it may give the impression that the test depends on a certain, say checkpoint interval.
done
Could you take a look at this test, and make it line up here? Specifically:
- compute probabilities manually from rawPrediction and ensure that it matches the probabilities column
- make sure that probabilities.argmax and rawPrediction.argmax equal the prediction
- make sure probabilities sum to one
- check the different code paths by unsetting some of the output columns
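The invariants in that checklist can be sketched outside Spark with plain-Scala stand-ins for a single row's rawPrediction and probability columns (illustrative only; the raw-to-probability mapping assumes logistic loss):

```scala
// Sketch of the consistency checks listed above, using plain Scala
// stand-ins for one row's rawPrediction and probability vectors.
// All names are illustrative, not Spark test code.
object GbtOutputChecks {
  // raw = (-margin, +margin); probability of the positive class under
  // logistic loss is 1 / (1 + exp(-2 * margin)).
  def rawToProbability(raw: Array[Double]): Array[Double] = {
    val p1 = 1.0 / (1.0 + math.exp(-2.0 * raw(1)))
    Array(1.0 - p1, p1)
  }

  def argmax(v: Array[Double]): Int = v.indexOf(v.max)

  // Two of the invariants from the review comment, for a single row:
  // probabilities sum to one, and probability.argmax agrees with raw.argmax.
  def consistent(raw: Array[Double]): Boolean = {
    val prob = rawToProbability(raw)
    val sumsToOne = math.abs(prob.sum - 1.0) < 1e-8
    val argmaxAgrees = argmax(prob) == argmax(raw)
    sumsToOne && argmaxAgrees
  }
}
```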
done
In logistic regression we had previously overridden some of the methods in probabilistic classifier since we were only dealing with two classes, which makes those methods a bit faster (hard to say how much). We can do it here for now, but I'd be slightly in favor of not doing it since I'm not sure how much we gain from it and it makes the code harder to follow. Thoughts?
Sorry, I'm a bit confused: this classifier also only deals with two classes; it does not support multiclass data. Instead of overriding, what is the alternative? There is no default predictRaw or raw2probability implemented in ProbabilisticClassifier, and it seems that this is the minimum required for GBTClassifier to use ProbabilisticClassifier. Can you please give more information on this point?
I can see how my comment was confusing now :) Since GBT only supports two classes right now, we could override methods like probability2prediction, which by default call what is implemented in ProbabilisticClassifier. When thresholds are not defined, it calls probability.argmax, which for two classes we could simplify to:

```scala
if (probability(1) > probability(0)) 1 else 0
```

Looking now, logistic regression also had a getThreshold method which allowed it to avoid loops in some cases, but we don't have it here. Let's leave things how they are.
Sorry, I'm still a little confused: should I override probability2prediction and simplify, or should I keep the argmax as is? The argmax seems better because it is more general anyway, but please let me know if you would prefer that I make any changes here.
Let's not change anything for now, it's fine as it is. Sorry for the confusion.
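For completeness, the two prediction rules being compared here, the general argmax and the two-class shortcut, agree for binary probability vectors; a small illustrative sketch:

```scala
// The two equivalent prediction rules discussed above, for two classes.
// Illustrative only, not Spark code.

// General rule: pick the index of the largest probability.
def predictByArgmax(probability: Array[Double]): Double =
  probability.indexOf(probability.max).toDouble

// Two-class shortcut from the review comment.
def predictBinary(probability: Array[Double]): Double =
  if (probability(1) > probability(0)) 1.0 else 0.0
```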
You can use .foreach { case Row(raw: DenseVector, pred: Double, prob: DenseVector) => ... } here.
done
@sethah @jkbradley thank you for the review - could you please take another look since I've updated the code based on your comments?

Test build #70963 has finished for PR 16441 at commit
It looks like I am failing the binary compatibility tests despite this constructor being private: class GBTClassificationModel private[ml](. This is the same thing that happened in my original PR, when I had to add the additional this() overload to pass the tests. In the PR comments it was mentioned that I should be able to remove the unused constructor; does this mean that I need to change the binary compatibility test somehow as well? My guess is that the binary compat tests are Java-based rather than Scala-based, in which case private[ml] doesn't matter, so the solution would be to keep the extra constructor I had before (still private[ml]), only so I can pass the binary compat tests.

Test build #70982 has finished for PR 16441 at commit
Indeed, re-adding the constructor seems to make the binary compatibility tests pass (see the Spark QA build above). In favor of making the binary compat tests pass, I think we can keep the extra private constructor, even though for most people it won't do anything. Please let me know if there are any outstanding comments that still need to be addressed. Thank you!

I've removed the WIP from the title to reflect the status of the pull request.

ping @sethah @jkbradley could you please take another look since I've updated the code based on your comments? Thank you!
sethah left a comment:

Made another pass. Thanks for working on this, and my apologies for the delayed review.
DenseVector is unused
removed
predFromRaw?
Also, can we leave a comment regarding the fact that we'd want to check other loss types here for classification if they are ever added.
done and done
check that prob(0) + prob(1) ~== 1.0 absTol 1e-8
Good idea! Done. I added absEps for 1e-8 so that there won't be any magic constants floating around.
We can save ourselves some computation here:

```scala
case dv: DenseVector =>
  dv.values(0) = computeProb(dv.values(0))
  dv.values(1) = 1.0 - dv.values(0)
  dv
```
done
Actually, this would be better served embedded in the loss object. One solution would be to make a few changes to the loss:

```scala
trait ClassificationLoss extends Loss {
  private[spark] def computeProbability(prediction: Double): Double
}

object LogLoss extends ClassificationLoss
```

Then we could add a class member to the model, private val oldLoss: ClassificationLoss = getOldLossType, and just call oldLoss.computeProbability(pred) inside raw2ProbabilityInPlace. There might be a better solution too, but really I think it should be part of the loss.
Adding private val oldLoss: ClassificationLoss = getOldLossType won't work because getOldLossType returns a Loss, not a LogLoss, and Loss doesn't have computeProbability. However, I did add the ClassificationLoss trait, and in classProbability I just call LogLoss.computeProbability. I'm not sure if it will pass the binary compat checks though; let's see...
You can change getOldLossType to return a classification loss, can't you?
good point, will update
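Putting these comments together, the eventual shape might look roughly like the following self-contained sketch. The Loss trait here is a minimal stand-in for Spark's org.apache.spark.mllib.tree.loss.Loss, and visibility modifiers are omitted so it runs standalone; this is an illustration of the design, not the actual Spark source:

```scala
// Minimal stand-in for Spark's Loss trait so the sketch runs standalone.
trait Loss {
  def computeError(prediction: Double, label: Double): Double
}

// The ClassificationLoss trait proposed in the review: classification
// losses additionally know how to turn a raw margin into a probability.
trait ClassificationLoss extends Loss {
  def computeProbability(margin: Double): Double
}

object LogLoss extends ClassificationLoss {
  // 2 log(1 + exp(-2 y F)) with labels y in {-1, +1}.
  def computeError(prediction: Double, label: Double): Double =
    2.0 * math.log1p(math.exp(-2.0 * label * prediction))

  // For log loss, P(y = 1 | x) = 1 / (1 + exp(-2 F(x))).
  def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}

// With getOldLossType returning ClassificationLoss (as suggested above),
// the model can resolve the loss once instead of once per instance.
def getOldLossType: ClassificationLoss = LogLoss
val oldLoss: ClassificationLoss = getOldLossType
```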
The @Since tag isn't needed since it's private.
We should just use numClasses = 2 for now, since getNumClasses can make an extra pass over the data, and >2 classes are not supported anyway.
Hmm, logistic regression gets the number of classes and throws in the binomial case, and getNumClasses should ideally read the number of classes from the metadata, which shouldn't make an extra pass (ideally the label column is categorical). But I think it's ok for now to make it 2 until we make GBT support the multiclass case.
If getNumClasses doesn't find metadata, then it will make a pass over the data.
Right, I removed it for now, but ideally the user would preprocess the data and make the label column categorical. They would either do that through StringIndexer, or, if they know the labels ahead of time, add the metadata themselves (although unfortunately only advanced users can currently do this; there is no transform that allows them to pre-specify the labels).
It's still there...
oops, I thought I changed it, sorry
I prefer to leave the handling of thresholds for another JIRA, but technically users will be able to set it. We can either do it here in this PR, or throw an error until we get it implemented in a follow up. Thoughts @jkbradley?
it looks like decision tree classifier has the same problem with thresholds
Actually, it looks like both this classifier and the decision tree already handle thresholds in the probability2prediction method of ProbabilisticClassifier.scala. Can you give more information on why GBTClassifier is not handling thresholds correctly?
There is no setThresholds method, and there are no unit tests off the top of my head.
I do see a setThresholds method on both the classifier and the model. It comes from ProbabilisticClassifier:

```scala
abstract class ProbabilisticClassifier[
    FeaturesType,
    E <: ProbabilisticClassifier[FeaturesType, E, M],
    M <: ProbabilisticClassificationModel[FeaturesType, M]]
  extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {

  /** @group setParam */
  def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]

  /** @group setParam */
  def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
}
```
ah, ok good catch. We should handle thresholds in this PR then. Can you look at other test suites and add those tests?
Sure, I've added more tests in the latest commit. I've also fixed an issue where predict was not using thresholds; if they are defined, we now use them.
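ProbabilisticClassificationModel documents threshold-based prediction as choosing the class i with the largest probability(i) / threshold(i). A self-contained sketch of that rule (illustrative, not the actual Spark source):

```scala
// Sketch of threshold-based prediction as documented for
// ProbabilisticClassificationModel: predict the class i that maximizes
// probability(i) / threshold(i). Illustrative, not the Spark source.
def predictWithThresholds(probability: Array[Double],
                          thresholds: Array[Double]): Double = {
  require(probability.length == thresholds.length)
  var best = 0
  var bestScore = Double.NegativeInfinity
  var i = 0
  while (i < probability.length) {
    // Treat a zero threshold as "this class always wins" to avoid 0/0.
    val score =
      if (thresholds(i) == 0.0) Double.PositiveInfinity
      else probability(i) / thresholds(i)
    if (score > bestScore) { best = i; bestScore = score }
    i += 1
  }
  best.toDouble
}
```

With equal thresholds this reduces to plain argmax; a high threshold on one class makes it harder for that class to win.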
we can avoid duplicating this code. Maybe, as in LogisticRegression, we can create a private function called score or margin and then use that in predict and predictRaw
Good idea; refactored this into a private margin method.
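The refactor can be sketched as follows: a single private margin helper computes the weighted ensemble sum, and both predict and predictRaw reuse it (names are illustrative, not the actual Spark model, which works from the tree ensemble and feature vector directly):

```scala
// Sketch of the margin refactor: one private helper computes the weighted
// ensemble sum, and both predict and predictRaw reuse it.
// Names are illustrative, not actual Spark API.
class GbtModelSketch(treeWeights: Array[Double]) {
  // treePredictions stands in for per-tree predictions on one instance
  // (each in {-1.0, +1.0} space).
  private def margin(treePredictions: Array[Double]): Double = {
    var acc = 0.0
    var i = 0
    while (i < treePredictions.length) {
      acc += treePredictions(i) * treeWeights(i)
      i += 1
    }
    acc
  }

  // Raw prediction as (negative margin, positive margin).
  def predictRaw(treePredictions: Array[Double]): (Double, Double) = {
    val m = margin(treePredictions)
    (-m, m)
  }

  def predict(treePredictions: Array[Double]): Double =
    if (margin(treePredictions) > 0.0) 1.0 else 0.0
}
```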
Shall we add a "default params" test for parity with other suites like LogisticRegression?
good idea, added the extra test
Test build #71142 has finished for PR 16441 at commit

Test build #71144 has finished for PR 16441 at commit

Test build #71145 has finished for PR 16441 at commit

ping @sethah @jkbradley could you please take another look since I've updated the code based on your comments? Thank you!

Test build #71150 has finished for PR 16441 at commit
sethah left a comment:

Looking good! Thanks for all the updates.
style: put each arg on its own line, using 4-space indentation, as is done with the constructor
Done, thanks. I also updated the other constructor (my default IntelliJ settings don't seem to match the suggested ones).
nit: if (margin(features) > 0.0) 1.0 else 0.0
done
This comment should be removed since we made this function generic
Moved the comment to the LogLoss computeProbability method (kept for the positive result only).
Should we make a private class member private val loss = getOldLossType? Otherwise we call getOldLossType (which calls getLossType) for every single instance.
Hmm, this is a tricky point: in the future, if we have more than one loss, then when the user changes it the results should change as well. But since we only have one loss function, I guess it is ok... I'll make the update but add a warning comment.
You mean that if someone takes a model and changes the loss type via set(lossType, "other") that the probability function should change? I don't think it makes sense to change the probability function for a model, since the probability is chosen to be optimal for a specific loss, but it's a good point. What do you think?
It's still there...
this can be private[spark] I think
done
nit: prefer explicit doubles like 1.0 instead of 1
done
ping @sethah @jkbradley could you please take another look since I've updated the code based on your comments? Thank you!

Test build #71169 has finished for PR 16441 at commit

Test build #71170 has finished for PR 16441 at commit

Test build #71171 has finished for PR 16441 at commit
Force-pushed from 0def50c to 1abfee0
Test build #71616 has finished for PR 16441 at commit

Test build #71617 has finished for PR 16441 at commit
LGTM

@imatiach-msft thanks for this, really great to have GBT in the classification trait hierarchy, and now usable with binary evaluator metrics!
In which release is this fix going to be available? Thanks!

Should be in 2.2.0.
Great! Thanks Nick! - Yong
What changes were proposed in this pull request?
For all of the classifiers in MLlib we can predict probabilities, except for GBTClassifier.
Also, all classifiers inherit from ProbabilisticClassifier but GBTClassifier strangely inherits from Predictor, which is a bug.
This change corrects the interface and adds the ability for the classifier to give a probabilities vector.
How was this patch tested?
The basic ML tests were run after making the changes. I've marked this as WIP as I need to add more tests.