From 0cfc20a3637c06071e6fe48ca5db4834b34c889e Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Wed, 11 Apr 2018 12:08:22 -0700 Subject: [PATCH 1/5] add rawPrediction as an output column; add numCLasses and numFeatures to OneVsRestModel --- .../spark/ml/classification/OneVsRest.scala | 28 ++++++++++++++----- .../ml/classification/OneVsRestSuite.scala | 9 ++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f04fde2cbbca1..fc0cc6612d102 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol} import org.apache.spark.ml.util._ @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams +private[ml] trait OneVsRestParams extends ClassifierParams with ClassifierTypeTrait with HasWeightCol { /** @@ -138,6 +138,12 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + @Since("2.4.0") + val numClasses: Int = models.length + + @Since("2.4.0") + val numFeatures: Int = models.head.numFeatures + /** @group setParam */ @Since("2.1.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) @@ -146,6 +152,10 @@ final class OneVsRestModel private[ml] ( @Since("2.1.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -195,14 +205,18 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Map[Int, Double]) => - predictions.maxBy(_._2)._1.toDouble + // output the RawPrediction as vector + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + Vectors.sparse(numClasses, predictions.toList ) } - // output label and label metadata as prediction + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble } + + // output confidence as rwa prediction, label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) .drop(accColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 11e88367108b4..088492e3bb74b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") + assert(ova.getRawPredictionCol === "rawPrediction") val ovaModel = ova.fit(dataset) MLTestingUtils.checkCopyAndUids(ova, ovaModel) - assert(ovaModel.models.length === numClasses) + assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -179,9 +180,10 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") + ovaModel.setRawPredictionCol("rawpred") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet - assert(outputFields === Set("y", "fea", "pred")) + assert(outputFields === Set("y", "fea", "pred", "rawpred")) } test("SPARK-8049: OneVsRest shouldn't output temp columns") { @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val ovr = new OneVsRest() .setClassifier(logReg) val output = ovr.fit(dataset).transform(dataset) - assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + assert(output.schema.fieldNames.toSet + === Set("label", "features", "prediction", "rawPrediction")) } test("SPARK-21306: OneVsRest should support setWeightCol") { From 2a47e2be30d52e3fbea7e1eeeaa5048a6ac97116 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Thu, 12 Apr 2018 17:13:12 -0700 Subject: [PATCH 2/5] make rawPrediction optionall --- .../spark/ml/classification/OneVsRest.scala | 34 +++++++++++++------ .../ml/classification/OneVsRestSuite.scala | 4 +-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index fc0cc6612d102..042ec356e65fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -191,6 +191,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset @@ -206,18 +207,31 @@ final class OneVsRestModel private[ml] ( } // output the RawPrediction as vector - val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => - Vectors.sparse(numClasses, predictions.toList ) - } + if (getRawPredictionCol != "") { + val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => + val myArray = Array.fill[Double](numClasses)(0.0) + predictions.foreach { case (idx, value) => myArray(idx) = value } + Vectors.dense(myArray) + } - // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble } + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble } - // output confidence as rwa prediction, label and label metadata as prediction - aggregatedDataset - .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) - .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) - .drop(accColName) + aggregatedDataset + .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) + .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) + .drop(accColName) + } + else { + // output the index of the classifier with highest confidence as prediction + val labelUDF = udf { (predictions: Map[Int, Double]) => + predictions.maxBy(_._2)._1.toDouble + } + // output confidence as rwa prediction, label and label metadata as prediction + aggregatedDataset + .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) + .drop(accColName) + } } @Since("1.4.1") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 088492e3bb74b..2c3417c7e4028 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -180,10 +180,10 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) ovaModel.setFeaturesCol("fea") ovaModel.setPredictionCol("pred") - ovaModel.setRawPredictionCol("rawpred") + ovaModel.setRawPredictionCol("") val transformedDataset = ovaModel.transform(dataset2) val outputFields = transformedDataset.schema.fieldNames.toSet - assert(outputFields === Set("y", "fea", "pred", "rawpred")) + assert(outputFields === Set("y", "fea", "pred")) } test("SPARK-8049: OneVsRest shouldn't output temp columns") { From 0c32fcaaf87f1922170e4ce7e60381ccd23ab6e8 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Thu, 12 Apr 2018 17:19:19 -0700 Subject: [PATCH 3/5] change the tmp array name --- .../org/apache/spark/ml/classification/OneVsRest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 042ec356e65fe..171522901d047 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -209,9 +209,9 @@ final class OneVsRestModel private[ml] ( // output the RawPrediction as vector if (getRawPredictionCol != "") { val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => - val myArray = Array.fill[Double](numClasses)(0.0) - predictions.foreach { case (idx, value) => myArray(idx) = value } - Vectors.dense(myArray) + val predArray = Array.fill[Double](numClasses)(0.0) + predictions.foreach { case (idx, value) => predArray(idx) = value } + Vectors.dense(predArray) } // output the index of the classifier with highest confidence as prediction From ebf4a6c155be6a13fb41f492eb2777465d163478 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Fri, 13 Apr 2018 09:54:40 -0700 Subject: [PATCH 4/5] fix typos and comments --- .../scala/org/apache/spark/ml/classification/OneVsRest.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 171522901d047..ddc6a749835dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -215,8 +215,9 @@ final class OneVsRestModel private[ml] ( } // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (predictions: Vector) => predictions.argmax.toDouble } + val labelUDF = udf { (rawpredictions: Vector) => rawpredictions.argmax.toDouble } + // output confidence as raw prediction, label and label metadata as prediction aggregatedDataset .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) @@ -227,7 +228,7 @@ final class OneVsRestModel private[ml] ( val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } - // output confidence as rwa prediction, label and label metadata as prediction + // output label and label metadata as prediction aggregatedDataset .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata) .drop(accColName) From b3c7fec0fda9056b832d1d35e829e9946218e504 Mon Sep 17 00:00:00 2001 From: Lu WANG Date: Fri, 13 Apr 2018 13:30:05 -0700 Subject: [PATCH 5/5] add require in OneVsRestModel add setRawPredictionCol in OneVsRest create a local var numClass to resolve the issue --- .../spark/ml/classification/OneVsRest.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index ddc6a749835dd..5348d882cfd67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -138,6 +138,8 @@ final class OneVsRestModel private[ml] ( @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { + require(models.nonEmpty, "OneVsRestModel requires at least one model for one class") + @Since("2.4.0") val numClasses: Int = models.length @@ -206,24 +208,25 @@ final class OneVsRestModel private[ml] ( newDataset.unpersist() } - // output the RawPrediction as vector if (getRawPredictionCol != "") { + val numClass = models.length + + // output the RawPrediction as vector val rawPredictionUDF = udf { (predictions: Map[Int, Double]) => - val predArray = Array.fill[Double](numClasses)(0.0) + val predArray = Array.fill[Double](numClass)(0.0) predictions.foreach { case (idx, value) => predArray(idx) = value } Vectors.dense(predArray) } // output the index of the classifier with highest confidence as prediction - val labelUDF = udf { (rawpredictions: Vector) => rawpredictions.argmax.toDouble } + val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble } // output confidence as raw prediction, label and label metadata as prediction aggregatedDataset .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName))) .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata) .drop(accColName) - } - else { + } else { // output the index of the classifier with highest confidence as prediction val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble @@ -326,6 +329,10 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("2.4.0") + def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value) + /** * The implementation of parallel one vs. rest runs the classification for * each class in a separate threads.