@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
3232import org .apache .spark .annotation .Since
3333import org .apache .spark .ml ._
3434import org .apache .spark .ml .attribute ._
35- import org .apache .spark .ml .linalg .Vector
35+ import org .apache .spark .ml .linalg .{ Vector , Vectors }
3636import org .apache .spark .ml .param .{Param , ParamMap , ParamPair , Params }
3737import org .apache .spark .ml .param .shared .{HasParallelism , HasWeightCol }
3838import org .apache .spark .ml .util ._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
5555/**
5656 * Params for [[OneVsRest ]].
5757 */
58- private [ml] trait OneVsRestParams extends PredictorParams
58+ private [ml] trait OneVsRestParams extends ClassifierParams
5959 with ClassifierTypeTrait with HasWeightCol {
6060
6161 /**
@@ -138,6 +138,12 @@ final class OneVsRestModel private[ml] (
138138 @ Since (" 1.4.0" ) val models : Array [_ <: ClassificationModel [_, _]])
139139 extends Model [OneVsRestModel ] with OneVsRestParams with MLWritable {
140140
141+ @ Since (" 2.4.0" )
142+ val numClasses : Int = models.length
143+
144+ @ Since (" 2.4.0" )
145+ val numFeatures : Int = models.head.numFeatures
146+
141147 /** @group setParam */
142148 @ Since (" 2.1.0" )
143149 def setFeaturesCol (value : String ): this .type = set(featuresCol, value)
@@ -146,6 +152,10 @@ final class OneVsRestModel private[ml] (
146152 @ Since (" 2.1.0" )
147153 def setPredictionCol (value : String ): this .type = set(predictionCol, value)
148154
155+ /** @group setParam */
156+ @ Since (" 2.4.0" )
157+ def setRawPredictionCol (value : String ): this .type = set(rawPredictionCol, value)
158+
149159 @ Since (" 1.4.0" )
150160 override def transformSchema (schema : StructType ): StructType = {
151161 validateAndTransformSchema(schema, fitting = false , getClassifier.featuresDataType)
@@ -195,14 +205,18 @@ final class OneVsRestModel private[ml] (
195205 newDataset.unpersist()
196206 }
197207
198- // output the index of the classifier with highest confidence as prediction
199- val labelUDF = udf { (predictions : Map [Int , Double ]) =>
200- predictions.maxBy(_._2)._1.toDouble
208+ // output the RawPrediction as vector
209+ val rawPredictionUDF = udf { (predictions : Map [Int , Double ]) =>
210+ Vectors .sparse(numClasses, predictions.toList )
201211 }
202212
203- // output label and label metadata as prediction
213+ // output the index of the classifier with highest confidence as prediction
214+ val labelUDF = udf { (predictions : Vector ) => predictions.argmax.toDouble }
215+
216+ // output confidence as raw prediction, label and label metadata as prediction
204217 aggregatedDataset
205- .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
218+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
219+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
206220 .drop(accColName)
207221 }
208222
0 commit comments