[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel #21044
add numClasses and numFeatures to OneVsRestModel
Test build #89216 has finished for PR 21044 at commit
Thanks for the PR! Quick high-level comment: We'll need to have rawPredictionCol be optional. If it's not set or is an empty string, then it should not be added to the output DataFrame.
      predictions.maxBy(_._2)._1.toDouble
    // output the RawPrediction as vector
    val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
      Vectors.sparse(numClasses, predictions.toList )
Also, let's output a dense Vector since it will almost surely be dense.
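A minimal sketch of what the dense output could look like, written in plain Scala so it runs without Spark. The object `DenseRawPrediction` and the helper `toDenseRaw` are hypothetical names introduced for illustration; they are not part of the patch. In the real transform the resulting array would presumably be wrapped with `Vectors.dense`.

```scala
object DenseRawPrediction {
  // Hypothetical helper: turn per-class confidences into a dense array.
  // Classes with no entry in the map default to 0.0.
  def toDenseRaw(predictions: Map[Int, Double], numClasses: Int): Array[Double] = {
    val predArray = Array.fill[Double](numClasses)(0.0)
    predictions.foreach { case (classIndex, confidence) =>
      predArray(classIndex) = confidence
    }
    // In Spark ML this array could be wrapped as Vectors.dense(predArray).
    predArray
  }
}
```

Since every one-vs-rest classifier contributes a confidence, nearly all slots are filled, which is why a dense representation is the natural fit here.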
Test build #89308 has finished for PR 21044 at commit
Test build #89309 has finished for PR 21044 at commit
    val labelUDF = udf { (predictions: Map[Int, Double]) =>
      predictions.maxBy(_._2)._1.toDouble
    }
    // output confidence as rwa prediction, label and label metadata as prediction
rwa -> raw
    }

    // output the index of the classifier with highest confidence as prediction
    val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
==> udf { (rawPredictions: Vector) => ... }
jkbradley left a comment
Thanks for the PR! Done with more in-depth review pass
    private[ml] val labelMetadata: Metadata,
    @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
Let's add a require() statement here which checks that models.nonEmpty is true (to throw an exception upon construction, rather than when numFeatures calls models.head below). Just to be safe...
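The suggestion can be sketched as follows, using a hypothetical stand-in class (`ModelHolder`) rather than the real `OneVsRestModel`. The `featureCount` parameter and the error message are assumptions made so the example runs without Spark, and deriving `numClasses` from `models.length` is likewise an assumption of this sketch.

```scala
// Hypothetical stand-in for OneVsRestModel, reduced to the part under discussion.
final class ModelHolder[M](val models: Array[M], featureCount: M => Int) {
  // Fail at construction time instead of later, when models.head is evaluated.
  require(models.nonEmpty, "OneVsRestModel requires at least one model.")

  // Assumption: one binary model per class.
  val numClasses: Int = models.length
  // Safe after the require() above; would otherwise throw NoSuchElementException.
  val numFeatures: Int = featureCount(models.head)
}
```

With the check in place, an empty `models` array is rejected with an `IllegalArgumentException` carrying a descriptive message, rather than a less informative failure from `head`.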
  /** @group setParam */
  @Since("2.4.0")
  def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
You'll need to add this to the Estimator too.
    // output the RawPrediction as vector
    if (getRawPredictionCol != "") {
      val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
        val predArray = Array.fill[Double](numClasses)(0.0)
This causes a subtle ContextCleaner bug: numClasses refers to a field of the class OneVsRestModel, so when Spark's closure capture serializes this UDF to send to executors, it will end up sending the entire OneVsRestModel object, rather than just the value for numClasses. Make a local copy of the value numClasses within the transform() method to avoid this issue.
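The capture problem described above can be illustrated outside Spark with plain Java serialization. This is only a sketch: `BigModel`, `capturingField`, `capturingLocal`, and `serializes` are hypothetical names, and the class is deliberately left non-serializable so that capturing `this` becomes visible as a failure. It assumes Scala 2.12 or later, where function literals are serializable.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object ClosureCaptureSketch {
  // Hypothetical stand-in for a large model object; deliberately not Serializable.
  class BigModel(val numClasses: Int) {
    // Reads the field inside the lambda, so the lambda captures `this`.
    def capturingField: Int => Array[Double] =
      (_: Int) => Array.fill[Double](numClasses)(0.0)

    // Copies the field to a local val first, so only an Int is captured.
    def capturingLocal: Int => Array[Double] = {
      val localNumClasses = numClasses
      (_: Int) => Array.fill[Double](localNumClasses)(0.0)
    }
  }

  // Returns true if plain Java serialization of the closure succeeds.
  def serializes(f: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(f)
      true
    } catch {
      case _: NotSerializableException => false
    }
}
```

In the real model the class is serializable, so nothing fails outright; the cost is that the whole model travels with the UDF. Copying the value to a local inside `transform()` keeps the shipped closure small.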
        .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
        .drop(accColName)
    }
    else {
Scala style: This should go on the previous line: } else {
      val labelUDF = udf { (predictions: Map[Int, Double]) =>
        predictions.maxBy(_._2)._1.toDouble
      }
      // output confidence as rwa prediction, label and label metadata as prediction
This comment seems to be in the wrong part of the code. Also there's a typo
Test build #89353 has finished for PR 21044 at commit
add setRawPredictionCol in OneVsRest; create a local var numClass to resolve the issue
Test build #89361 has finished for PR 21044 at commit
LGTM
add RawPrediction as output column
add numClasses and numFeatures to OneVsRestModel
What changes were proposed in this pull request?
Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future.
Add a rawPrediction output column in transform; the predicted label is calculated from rawPrediction, as in raw2prediction.
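As a rough illustration of the second change, the label can be derived from the raw prediction by taking the index of the largest score. The sketch below uses a plain `Array[Double]` in place of a Spark `Vector`, and the object name `RawToPrediction` is hypothetical; it is not the code from the patch.

```scala
object RawToPrediction {
  // Index of the largest raw score, returned as a Double label.
  // Ties resolve to the lowest index, since indexOf returns the first match.
  def raw2prediction(rawPrediction: Array[Double]): Double = {
    require(rawPrediction.nonEmpty, "rawPrediction must not be empty")
    rawPrediction.indexOf(rawPrediction.max).toDouble
  }
}
```

For example, `RawToPrediction.raw2prediction(Array(0.1, 0.7, 0.2))` selects class index 1, the classifier with the highest confidence.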
How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
Please review http://spark.apache.org/contributing.html before opening a pull request.