From f4e2089a0ac4a25d0690bef627dd1e00e0bcb959 Mon Sep 17 00:00:00 2001
From: "badriub@gmail.com" <badriub@gmail.com>
Date: Fri, 24 Jul 2015 16:39:00 -0400
Subject: [PATCH 1/7] Added predictive probability to OneVsRestModel and
 LogisticRegressionModel

---
 .../spark/ml/classification/OneVsRest.scala        | 10 ++++-
 .../classification/LogisticRegression.scala        | 45 +++++++++++++++++++
 .../spark/mllib/classification/SVM.scala           |  7 +++
 .../GeneralizedLinearAlgorithm.scala               | 19 ++++++++
 .../apache/spark/mllib/regression/Lasso.scala      | 17 +++++++
 .../mllib/regression/LinearRegression.scala        |  7 +++
 .../mllib/regression/RidgeRegression.scala         |  7 +++
 .../ml/classification/OneVsRestSuite.scala         |  3 +-
 8 files changed, 113 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ea757c5e40c76..ccfe513e57a7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
 
 import java.util.UUID
 
+import org.apache.spark.ml.param.shared.HasRawPredictionCol
+
 import scala.language.existentials
 
 import org.apache.spark.annotation.Experimental
@@ -35,7 +37,7 @@ import org.apache.spark.storage.StorageLevel
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams {
+private[ml] trait OneVsRestParams extends PredictorParams with HasRawPredictionCol {
 
   // scalastyle:off structural.type
   type ClassifierType = Classifier[F, E, M] forSome {
@@ -127,9 +129,15 @@ final class OneVsRestModel private[ml] (
       predictions.maxBy(_._2)._1.toDouble
     }
 
+    // output the index of the classifier with highest confidence as prediction
+    val probabilityUDF = udf { (predictions: Map[Int, Double]) =>
+      predictions.maxBy(_._2)._2.toDouble
+    }
+
     // output label and label metadata as prediction
     aggregatedDataset
       .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+      .withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol), labelMetadata))
       .drop(accColName)
   }
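For context between the hunks: both UDFs reduce the same accumulated Map[Int, Double] of per-classifier raw scores, keyed by class index. A minimal standalone sketch of what they compute, with hypothetical scores (not code from this patch):

    // Hypothetical accumulated scores from three one-vs-rest classifiers.
    val predictions = Map(0 -> 0.12, 1 -> 0.83, 2 -> 0.41)
    // labelUDF: the index of the classifier with the highest score.
    val label = predictions.maxBy(_._2)._1.toDouble      // 1.0
    // probabilityUDF: that classifier's winning score, i.e. the confidence.
    val confidence = predictions.maxBy(_._2)._2          // 0.83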
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 268642ac6a2f6..ede11f76d8199 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -114,6 +114,51 @@ class LogisticRegressionModel (
     this
   }
 
+  override protected def predictPointWithProbability(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
+      intercept: Double) = {
+    require(dataMatrix.size == numFeatures)
+
+    // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
+    if (numClasses == 2) {
+      val margin = dot(weightMatrix, dataMatrix) + intercept
+      val score = 1.0 / (1.0 + math.exp(-margin))
+      (threshold match {
+        case Some(t) => if (score > t) 1.0 else 0.0
+        case None => score
+      }, score)
+    } else {
+      /**
+       * Compute and find the one with maximum margins. If the maxMargin is negative, then the
+       * prediction result will be the first class.
+       *
+       * PS, if you want to compute the probabilities for each outcome instead of the outcome
+       * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
+       * is positive to prevent overflow.
+       */
+      var bestClass = 0
+      var maxMargin = 0.0
+      val withBias = dataMatrix.size + 1 == dataWithBiasSize
+      (0 until numClasses - 1).foreach { i =>
+        var margin = 0.0
+        dataMatrix.foreachActive { (index, value) =>
+          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
+        }
+        // Intercept is required to be added into margin.
+        if (withBias) {
+          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
+        }
+        if (margin > maxMargin) {
+          maxMargin = margin
+          bestClass = i + 1
+        }
+      }
+      val score = 1.0 / (1.0 + math.exp(-maxMargin))
+      (bestClass.toDouble, score)
+    }
+  }
+
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
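The doc comment in the multinomial branch above warns about overflow when turning margins into probabilities. A minimal sketch of that max-margin subtraction, with hypothetical margins (the pivot class 0 has margin 0.0 by construction); this is an illustration, not code from the patch:

    // Margins for a hypothetical 4-class model; math.exp(800.0) alone
    // would overflow a Double to Infinity.
    val margins = Array(0.0, 2.0, 800.0, 5.0)
    val maxMargin = margins.max
    // Subtracting a positive maxMargin keeps every exponent <= 0.
    val shifted = if (maxMargin > 0) margins.map(_ - maxMargin) else margins
    val exps = shifted.map(math.exp)
    val probabilities = exps.map(_ / exps.sum)  // sums to 1.0, no overflow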
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 5b54feeb10467..85cc3f0c01aec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -73,6 +73,13 @@ class SVMModel (
     this
   }
 
+  override protected def predictPointWithProbability(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
+      intercept: Double) = {
+    throw new Exception("Not implemented for SVMModel")
+  }
+
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 6709bd79bc820..ed6b9061cefb0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -48,6 +48,15 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
    */
   protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
 
+  /**
+   * Predict the result along with probability given a data point and the weights learned.
+   *
+   * @param dataMatrix Row vector containing the features for this data point
+   * @param weightMatrix Column vector containing the weights of the model
+   * @param intercept Intercept of the model.
+   */
+  protected def predictPointWithProbability(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): (Double, Double)
+
   /**
    * Predict values for the given data set using the model trained.
    *
@@ -76,6 +85,16 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
     predictPoint(testData, weights, intercept)
   }
 
+  /**
+   * Predict values for a single data point using the model trained.
+   *
+   * @param testData array representing a single data point
+   * @return Double prediction from the trained model
+   */
+  def predictWithProbability(testData: Vector): (Double, Double) = {
+    predictPointWithProbability(testData, weights, intercept)
+  }
+
   /**
    * Print a summary of the model.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 4f482384f0f38..75b590d5f16e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.mllib.regression
 
+import java.lang
+
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.pmml.PMMLExportable
@@ -37,6 +40,13 @@ class LassoModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
+  override protected def predictPointWithProbability(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
+      intercept: Double) = {
+    throw new Exception("Not implemented for LassoModel")
+  }
+
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
@@ -49,6 +59,13 @@ class LassoModel (
   }
 
   override protected def formatVersion: String = "1.0"
+
+  /**
+   * Predict values for examples stored in a JavaRDD.
+   * @param testData JavaRDD representing data points to be predicted
+   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+   */
+  override def predict(testData: JavaRDD[Vector]): JavaRDD[lang.Double] = super.predict(testData)
 }
 
 object LassoModel extends Loader[LassoModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 9453c4f66c216..fb546ede6a736 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -37,6 +37,13 @@ class LinearRegressionModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
+  override protected def predictPointWithProbability(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
+      intercept: Double) = {
+    throw new Exception("Not implemented for LinearRegressionModel")
+  }
+
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 7d28ffad45c92..a8438abd61c80 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -38,6 +38,13 @@ class RidgeRegressionModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
+  override protected def predictPointWithProbability(
+      dataMatrix: Vector,
+      weightMatrix: Vector,
+      intercept: Double) = {
+    throw new Exception("Not implemented for RidgeRegressionModel")
+  }
+
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 75cf5bd4ead4f..0eb85cc481786 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.Metadata
+import org.apache.spark.sql.Row
 
 class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -110,7 +111,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val ovr = new OneVsRest()
       .setClassifier(logReg)
     val output = ovr.fit(dataset).transform(dataset)
-    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction", "rawPrediction"))
   }
 
   test("OneVsRest.copy and OneVsRestModel.copy") {
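End to end, patch 1 makes the confidence visible after transform. A hedged usage sketch mirroring the test suite's setup (`dataset` with "label" and "features" columns is an assumption, not code from this PR):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

    val ovr = new OneVsRest().setClassifier(new LogisticRegression())
    val output = ovr.fit(dataset).transform(dataset)
    // With the patch applied, the schema gains "rawPrediction" alongside
    // "prediction", holding the winning classifier's score.
    output.select("prediction", "rawPrediction").show(5)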
From 8bfd54c31651174aab774259c9a293ecdc5b55c3 Mon Sep 17 00:00:00 2001
From: "badriub@gmail.com" <badriub@gmail.com>
Date: Wed, 29 Jul 2015 11:10:49 -0400
Subject: [PATCH 2/7] SPARK-9312: Adding confidence factor to OneVsRestModel

---
 .../spark/ml/classification/OneVsRest.scala        | 21 +++++++++----------
 .../ml/classification/OneVsRestSuite.scala         |  4 ++--
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ccfe513e57a7f..58f5c74fab6a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -19,21 +19,20 @@ package org.apache.spark.ml.classification
 
 import java.util.UUID
 
-import org.apache.spark.ml.param.shared.HasRawPredictionCol
-
-import scala.language.existentials
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 
+import scala.language.existentials
+
 /**
  * Params for [[OneVsRest]].
  */
@@ -72,9 +71,9 @@ private[ml] trait OneVsRestParams extends PredictorParams with HasRawPredictionC
  */
 @Experimental
 final class OneVsRestModel private[ml] (
-    override val uid: String,
-    labelMetadata: Metadata,
-    val models: Array[_ <: ClassificationModel[_, _]])
+  override val uid: String,
+  labelMetadata: Metadata,
+  val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams {
 
   override def transformSchema(schema: StructType): StructType = {
@@ -137,8 +136,8 @@ final class OneVsRestModel private[ml] (
     // output label and label metadata as prediction
     aggregatedDataset
       .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
-      .withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol), labelMetadata))
-      .drop(accColName)
+      .withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol),
+        labelMetadata)).drop(accColName)
   }
 
   override def copy(extra: ParamMap): OneVsRestModel = {
@@ -228,4 +227,4 @@ final class OneVsRest(override val uid: String)
     }
     copied
   }
-}
+}
\ No newline at end of file
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 0eb85cc481786..c416c9080f6f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.Metadata
-import org.apache.spark.sql.Row
 
 class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -111,7 +110,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val ovr = new OneVsRest()
       .setClassifier(logReg)
     val output = ovr.fit(dataset).transform(dataset)
-    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction", "rawPrediction"))
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction",
+      "rawPrediction"))
   }
 
   test("OneVsRest.copy and OneVsRestModel.copy") {
From c4a3c29b55b02247772babdeba6140e45893b3bc Mon Sep 17 00:00:00 2001
From: "badriub@gmail.com" <badriub@gmail.com>
Date: Wed, 29 Jul 2015 11:19:07 -0400
Subject: [PATCH 3/7] SPARK-9312: Undoing the predictive-probability changes to
 LogisticRegression and the related models, since they were not needed.

---
 .../classification/LogisticRegression.scala        | 45 -------------------
 .../spark/mllib/classification/SVM.scala           |  7 ---
 .../GeneralizedLinearAlgorithm.scala               | 19 --------
 .../apache/spark/mllib/regression/Lasso.scala      | 17 -------
 .../mllib/regression/LinearRegression.scala        |  7 ---
 .../mllib/regression/RidgeRegression.scala         |  7 ---
 6 files changed, 102 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index ede11f76d8199..268642ac6a2f6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -114,51 +114,6 @@ class LogisticRegressionModel (
     this
   }
 
-  override protected def predictPointWithProbability(
-      dataMatrix: Vector,
-      weightMatrix: Vector,
-      intercept: Double) = {
-    require(dataMatrix.size == numFeatures)
-
-    // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
-    if (numClasses == 2) {
-      val margin = dot(weightMatrix, dataMatrix) + intercept
-      val score = 1.0 / (1.0 + math.exp(-margin))
-      (threshold match {
-        case Some(t) => if (score > t) 1.0 else 0.0
-        case None => score
-      }, score)
-    } else {
-      /**
-       * Compute and find the one with maximum margins. If the maxMargin is negative, then the
-       * prediction result will be the first class.
-       *
-       * PS, if you want to compute the probabilities for each outcome instead of the outcome
-       * with maximum probability, remember to subtract the maxMargin from margins if maxMargin
-       * is positive to prevent overflow.
-       */
-      var bestClass = 0
-      var maxMargin = 0.0
-      val withBias = dataMatrix.size + 1 == dataWithBiasSize
-      (0 until numClasses - 1).foreach { i =>
-        var margin = 0.0
-        dataMatrix.foreachActive { (index, value) =>
-          if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
-        }
-        // Intercept is required to be added into margin.
-        if (withBias) {
-          margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
-        }
-        if (margin > maxMargin) {
-          maxMargin = margin
-          bestClass = i + 1
-        }
-      }
-      val score = 1.0 / (1.0 + math.exp(-maxMargin))
-      (bestClass.toDouble, score)
-    }
-  }
-
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 85cc3f0c01aec..5b54feeb10467 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -73,13 +73,6 @@ class SVMModel (
     this
   }
 
-  override protected def predictPointWithProbability(
-      dataMatrix: Vector,
-      weightMatrix: Vector,
-      intercept: Double) = {
-    throw new Exception("Not implemented for SVMModel")
-  }
-
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index ed6b9061cefb0..6709bd79bc820 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -48,15 +48,6 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
    */
   protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
 
-  /**
-   * Predict the result along with probability given a data point and the weights learned.
-   *
-   * @param dataMatrix Row vector containing the features for this data point
-   * @param weightMatrix Column vector containing the weights of the model
-   * @param intercept Intercept of the model.
-   */
-  protected def predictPointWithProbability(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): (Double, Double)
-
   /**
    * Predict values for the given data set using the model trained.
    *
@@ -85,16 +76,6 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
     predictPoint(testData, weights, intercept)
   }
 
-  /**
-   * Predict values for a single data point using the model trained.
-   *
-   * @param testData array representing a single data point
-   * @return Double prediction from the trained model
-   */
-  def predictWithProbability(testData: Vector): (Double, Double) = {
-    predictPointWithProbability(testData, weights, intercept)
-  }
-
   /**
    * Print a summary of the model.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 75b590d5f16e2..4f482384f0f38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,10 +17,7 @@
 
 package org.apache.spark.mllib.regression
 
-import java.lang
-
 import org.apache.spark.SparkContext
-import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.pmml.PMMLExportable
@@ -40,13 +37,6 @@ class LassoModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
-  override protected def predictPointWithProbability(
-      dataMatrix: Vector,
-      weightMatrix: Vector,
-      intercept: Double) = {
-    throw new Exception("Not implemented for LassoModel")
-  }
-
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
@@ -59,13 +49,6 @@ class LassoModel (
   }
 
   override protected def formatVersion: String = "1.0"
-
-  /**
-   * Predict values for examples stored in a JavaRDD.
-   * @param testData JavaRDD representing data points to be predicted
-   * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
-   */
-  override def predict(testData: JavaRDD[Vector]): JavaRDD[lang.Double] = super.predict(testData)
 }
 
 object LassoModel extends Loader[LassoModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index fb546ede6a736..9453c4f66c216 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -37,13 +37,6 @@ class LinearRegressionModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
-  override protected def predictPointWithProbability(
-      dataMatrix: Vector,
-      weightMatrix: Vector,
-      intercept: Double) = {
-    throw new Exception("Not implemented for LinearRegressionModel")
-  }
-
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index a8438abd61c80..7d28ffad45c92 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -38,13 +38,6 @@ class RidgeRegressionModel (
   extends GeneralizedLinearModel(weights, intercept) with RegressionModel
   with Serializable with Saveable with PMMLExportable {
 
-  override protected def predictPointWithProbability(
-      dataMatrix: Vector,
-      weightMatrix: Vector,
-      intercept: Double) = {
-    throw new Exception("Not implemented for RidgeRegressionModel")
-  }
-
   override protected def predictPoint(
       dataMatrix: Vector,
       weightMatrix: Vector,
From d758972130b5eac97b53eb75ad81606a737bda82 Mon Sep 17 00:00:00 2001
From: "badriub@gmail.com" <badriub@gmail.com>
Date: Wed, 29 Jul 2015 11:26:38 -0400
Subject: [PATCH 4/7] Correcting indentation

---
 .../org/apache/spark/ml/classification/OneVsRest.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 58f5c74fab6a0..3ffa7cdd92741 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -71,9 +71,9 @@ private[ml] trait OneVsRestParams extends PredictorParams with HasRawPredictionC
  */
 @Experimental
 final class OneVsRestModel private[ml] (
-  override val uid: String,
-  labelMetadata: Metadata,
-  val models: Array[_ <: ClassificationModel[_, _]])
+    override val uid: String,
+    labelMetadata: Metadata,
+    val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams {
 
   override def transformSchema(schema: StructType): StructType = {

From 095774fe80856cf77a3da2fe6ce4efec49070dfb Mon Sep 17 00:00:00 2001
From: "badriub@gmail.com" <badriub@gmail.com>
Date: Wed, 29 Jul 2015 11:26:38 -0400
Subject: [PATCH 5/7] Implementing review comments

---
 .../apache/spark/ml/classification/OneVsRest.scala | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 58f5c74fab6a0..f1a93cede543a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -36,7 +36,7 @@ import scala.language.existentials
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams with HasRawPredictionCol {
+private[ml] trait OneVsRestParams extends ClassifierParams {
 
   // scalastyle:off structural.type
   type ClassifierType = Classifier[F, E, M] forSome {
@@ -71,9 +71,9 @@ private[ml] trait OneVsRestParams extends PredictorParams with HasRawPredictionC
  */
 @Experimental
 final class OneVsRestModel private[ml] (
-  override val uid: String,
-  labelMetadata: Metadata,
-  val models: Array[_ <: ClassificationModel[_, _]])
+    override val uid: String,
+    labelMetadata: Metadata,
+    val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams {
 
   override def transformSchema(schema: StructType): StructType = {
@@ -133,7 +133,7 @@ final class OneVsRestModel private[ml] (
       predictions.maxBy(_._2)._2.toDouble
     }
 
-    // output label and label metadata as prediction
+    // output label, confidence factor and label metadata as prediction
     aggregatedDataset
       .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
       .withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol),
@@ -227,4 +227,5 @@ final class OneVsRest(override val uid: String)
     }
     copied
   }
+
 }
\ No newline at end of file
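Patch 5 swaps the hand-wired HasRawPredictionCol mixin for ClassifierParams. As I read the Spark 1.5-era sources, ClassifierParams already extends PredictorParams with HasRawPredictionCol, so $(rawPredictionCol) stays available; a simplified sketch of the trait relationship (the real traits also carry param definitions and schema validation):

    trait HasRawPredictionCol { /* defines the rawPredictionCol param */ }
    trait PredictorParams     { /* label, features, prediction col params */ }
    trait ClassifierParams extends PredictorParams with HasRawPredictionCol

    // Hence the review-requested form still exposes the raw-prediction column:
    trait OneVsRestParams extends ClassifierParams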
From a18dab61d94e90e1a19f309e2a6f3e4f40a834ad Mon Sep 17 00:00:00 2001
From: Badari Madhav <badriub@gmail.com>
Date: Wed, 12 Aug 2015 17:23:32 -0400
Subject: [PATCH 6/7] Import order corrected

---
 .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 75b1b174b4e2b..0eabd6168b18d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.classification
 
 import java.util.UUID
 
+import scala.language.existentials
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
@@ -31,8 +33,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 
-import scala.language.existentials
-
 /**
  * Params for [[OneVsRest]].
 */

From 3950a5b69805f11c7277b536f848c61405219c35 Mon Sep 17 00:00:00 2001
From: Badari Madhav <badriub@gmail.com>
Date: Wed, 12 Aug 2015 17:27:11 -0400
Subject: [PATCH 7/7] Missed fixing a doc earlier

---
 .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 0eabd6168b18d..0e0126c112ab6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -128,7 +128,7 @@ final class OneVsRestModel private[ml] (
       predictions.maxBy(_._2)._1.toDouble
     }
 
-    // output the index of the classifier with highest confidence as prediction
+    // output the highest confidence as rawPredictionCol
     val probabilityUDF = udf { (predictions: Map[Int, Double]) =>
       predictions.maxBy(_._2)._2.toDouble
     }
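After patch 7 the series is complete: rawPrediction here carries a single Double (the winning classifier's confidence), not the raw-score vector most ml classifiers emit, so it can be consumed directly. Continuing the earlier usage sketch, with an illustrative threshold value:

    import org.apache.spark.sql.functions.col

    // Keep only rows where the winning one-vs-rest score clears a threshold.
    val confident = output.filter(col("rawPrediction") > 0.8)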