diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ea757c5e40c76..0e0126c112ab6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -24,18 +24,19 @@ import scala.language.existentials
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams {
+private[ml] trait OneVsRestParams extends ClassifierParams {
 
   // scalastyle:off structural.type
   type ClassifierType = Classifier[F, E, M] forSome {
@@ -127,10 +128,16 @@ final class OneVsRestModel private[ml] (
       predictions.maxBy(_._2)._1.toDouble
     }
 
-    // output label and label metadata as prediction
+    // output the highest confidence as rawPredictionCol
+    val probabilityUDF = udf { (predictions: Map[Int, Double]) =>
+      predictions.maxBy(_._2)._2.toDouble
+    }
+
+    // output label, confidence factor and label metadata as prediction
     aggregatedDataset
       .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
-      .drop(accColName)
+      .withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol),
+        labelMetadata)).drop(accColName)
   }
 
   override def copy(extra: ParamMap): OneVsRestModel = {
@@ -220,4 +227,5 @@ final class OneVsRest(override val uid: String)
     }
     copied
   }
+
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 75cf5bd4ead4f..c416c9080f6f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -110,7 +110,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val ovr = new OneVsRest()
       .setClassifier(logReg)
     val output = ovr.fit(dataset).transform(dataset)
-    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction",
+      "rawPrediction"))
   }
 
   test("OneVsRest.copy and OneVsRestModel.copy") {
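A minimal usage sketch of the behavior this diff introduces, assuming a `dataset` DataFrame with "label" and "features" columns and an illustrative `LogisticRegression` base classifier (variable names here are hypothetical, not taken from the patch):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// Assumes `dataset` is a DataFrame with "label" and "features" columns.
val classifier = new LogisticRegression().setMaxIter(10)
val ovr = new OneVsRest().setClassifier(classifier)
val model = ovr.fit(dataset)

// With this change the transformed output carries the winning label in
// "prediction" and that binary classifier's confidence in "rawPrediction",
// as asserted in the updated OneVsRestSuite schema test.
model.transform(dataset).select("prediction", "rawPrediction").show()
```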