@lu-wang-dl
Contributor

add RawPrediction as output column
add numClasses and numFeatures to OneVsRestModel

What changes were proposed in this pull request?

  • Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future.

  • Add a rawPrediction output column in transform; the prediction label is calculated from rawPrediction, as in raw2prediction.

How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
Please review http://spark.apache.org/contributing.html before opening a pull request.

add numClasses and numFeatures to OneVsRestModel
@lu-wang-dl lu-wang-dl changed the title Add RawPrediction, numClasses, and numFeatures for OneVsRestModel [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel Apr 11, 2018
@SparkQA

SparkQA commented Apr 11, 2018

Test build #89216 has finished for PR 21044 at commit 0cfc20a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

Thanks for the PR! Quick high-level comment: We'll need to have rawPredictionCol be optional. If it's not set or is an empty string, then it should not be added to the output DataFrame.
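
As a rough illustration of the optional-column rule described above, the sketch below models only the decision of which columns get appended. It uses no Spark classes; `OptionalRawPredictionCol` and `appendedColumns` are hypothetical names, and an unset param is represented as `None`, which is an assumption about how one might model it.

```scala
object OptionalRawPredictionCol {
  // Hypothetical helper: lists the columns a transform() step would append.
  // `rawPredictionCol` is None when the param is unset (an assumption of this sketch).
  def appendedColumns(predictionCol: String, rawPredictionCol: Option[String]): Seq[String] = {
    // Treat both "unset" and "empty string" as "do not emit the raw prediction column".
    val rawCols = rawPredictionCol.filter(_.nonEmpty).toList
    rawCols :+ predictionCol
  }
}
```

In the actual transform, the same check would presumably guard the `withColumn` call that adds the raw prediction column.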

  predictions.maxBy(_._2)._1.toDouble
// output the RawPrediction as vector
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
  Vectors.sparse(numClasses, predictions.toList)
Member

Also, let's output a dense Vector since it will almost surely be dense.
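
A minimal sketch of the dense alternative, assuming the per-class confidences arrive as a `Map[Int, Double]` keyed by class index and that every index lies in `[0, numClasses)`. The object and method names are hypothetical, and a plain `Array[Double]` stands in for the Spark vector; in Spark the array would typically be wrapped with `Vectors.dense`.

```scala
object DenseRawPrediction {
  // Assumption: keys of `predictions` are class indices in [0, numClasses).
  def toDenseArray(numClasses: Int, predictions: Map[Int, Double]): Array[Double] = {
    // Start from all zeros so classes without an entry default to 0.0.
    val predArray = Array.fill[Double](numClasses)(0.0)
    predictions.foreach { case (idx, conf) => predArray(idx) = conf }
    predArray
  }
}
```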

@SparkQA

SparkQA commented Apr 13, 2018

Test build #89308 has finished for PR 21044 at commit 2a47e2b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 13, 2018

Test build #89309 has finished for PR 21044 at commit 0c32fca.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val labelUDF = udf { (predictions: Map[Int, Double]) =>
  predictions.maxBy(_._2)._1.toDouble
}
// output confidence as rwa prediction, label and label metadata as prediction
Contributor

rwa -> raw

}

// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble }
Contributor

==> udf { (rawPredictions: Vector) => ... }
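
For illustration, the label computation with the suggested parameter name can be sketched without Spark. Here a plain `Array[Double]` stands in for the raw prediction vector, `ArgmaxPrediction` and `predict` are hypothetical names, and the input is assumed to be non-empty.

```scala
object ArgmaxPrediction {
  // Assumption: `rawPredictions` is non-empty; maxBy throws on an empty collection.
  def predict(rawPredictions: Array[Double]): Double = {
    // Index of the classifier with the highest confidence, returned as a Double label.
    rawPredictions.indices.maxBy(i => rawPredictions(i)).toDouble
  }
}
```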

Member

@jkbradley jkbradley left a comment

Thanks for the PR! Done with a more in-depth review pass.

    private[ml] val labelMetadata: Metadata,
    @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
  extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

Member

Let's add a require() statement here which checks that models.nonEmpty is true (to throw an exception upon construction, rather than when numFeatures calls models.head below). Just to be safe...
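
A minimal sketch of the constructor check, using hypothetical stand-in classes rather than the real Spark types. Deriving `numClasses` from `models.length` is an assumption of this sketch; reading `numFeatures` from `models.head` follows the comment above.

```scala
object NonEmptyModelsCheck {
  // Hypothetical stand-in for a fitted binary classifier; only numFeatures matters here.
  final case class StubModel(numFeatures: Int)

  // Hypothetical stand-in for OneVsRestModel, reduced to the constructor check.
  class StubOneVsRestModel(val models: Array[StubModel]) {
    // Fail at construction time instead of later, when models.head is first used.
    require(models.nonEmpty, "OneVsRestModel requires at least one model")

    // Assumption: one binary model per class.
    val numClasses: Int = models.length
    val numFeatures: Int = models.head.numFeatures
  }
}
```

With this in place, an empty `models` array raises an `IllegalArgumentException` from `require` rather than a less descriptive error from `models.head`.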


/** @group setParam */
@Since("2.4.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
Member

You'll need to add this to the Estimator too.

// output the RawPrediction as vector
if (getRawPredictionCol != "") {
  val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
    val predArray = Array.fill[Double](numClasses)(0.0)
Member

This causes a subtle ContextCleaner bug: numClasses refers to a field of the class OneVsRestModel, so when Spark's closure capture serializes this UDF to send to executors, it will end up sending the entire OneVsRestModel object, rather than just the value for numClasses. Make a local copy of the value numClasses within the transform() method to avoid this issue.
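
The pattern can be sketched in plain Scala, without Spark. `StubModel` is a hypothetical stand-in for the model class, and the local name `numClass` mirrors the follow-up commit message. Both functions compute the same result; the difference is only in what the function literal captures.

```scala
object LocalCopyForClosure {
  // Hypothetical stand-in for a model class whose field is used inside a UDF-like closure.
  class StubModel(val numClasses: Int) {

    // Reads the field directly, so the function literal captures `this`
    // (the whole enclosing object) in order to reach numClasses.
    def fillViaField: Map[Int, Double] => Array[Double] = { predictions =>
      val predArray = Array.fill[Double](numClasses)(0.0)
      predictions.foreach { case (idx, conf) => predArray(idx) = conf }
      predArray
    }

    // Copies the field into a local val first, so the function literal
    // captures only that Int value rather than the enclosing object.
    def fillViaLocalCopy: Map[Int, Double] => Array[Double] = {
      val numClass = numClasses
      predictions => {
        val predArray = Array.fill[Double](numClass)(0.0)
        predictions.foreach { case (idx, conf) => predArray(idx) = conf }
        predArray
      }
    }
  }
}
```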

.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
.drop(accColName)
}
else {
Member

Scala style: This should go on the previous line: } else {

val labelUDF = udf { (predictions: Map[Int, Double]) =>
  predictions.maxBy(_._2)._1.toDouble
}
// output confidence as rwa prediction, label and label metadata as prediction
Member

This comment seems to be in the wrong part of the code. Also, there's a typo.

@SparkQA

SparkQA commented Apr 13, 2018

Test build #89353 has finished for PR 21044 at commit ebf4a6c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

add setRawPredictionCol in OneVsRest
create a local var numClass to resolve the issue
@SparkQA

SparkQA commented Apr 13, 2018

Test build #89361 has finished for PR 21044 at commit b3c7fec.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

LGTM
Merging with master
Thanks!!
