Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
Expand All @@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams
private[ml] trait OneVsRestParams extends ClassifierParams
with ClassifierTypeTrait with HasWeightCol {

/**
Expand Down Expand Up @@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a require() statement here which checks that models.nonEmpty is true (to throw an exception upon construction, rather than when numFeatures calls models.head below). Just to be safe...

require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")

@Since("2.4.0")
val numClasses: Int = models.length

@Since("2.4.0")
val numFeatures: Int = models.head.numFeatures

/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
Expand All @@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
@Since("2.1.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group setParam */
@Since("2.4.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to add this to the Estimator too.


@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
Expand Down Expand Up @@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}

model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
Expand All @@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}

// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
if (getRawPredictionCol != "") {
val numClass = models.length

// output label and label metadata as prediction
aggregatedDataset
.withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
// output the RawPrediction as vector
val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
val predArray = Array.fill[Double](numClass)(0.0)
predictions.foreach { case (idx, value) => predArray(idx) = value }
Vectors.dense(predArray)
}

// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }

// output confidence as raw prediction, label and label metadata as prediction
aggregatedDataset
.withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
.withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
.drop(accColName)
} else {
// output the index of the classifier with highest confidence as prediction
val labelUDF = udf { (predictions: Map[Int, Double]) =>
predictions.maxBy(_._2)._1.toDouble
}
// output label and label metadata as prediction
aggregatedDataset
.withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
.drop(accColName)
}
}

@Since("1.4.1")
Expand Down Expand Up @@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group setParam */
@Since("2.4.0")
def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)

/**
* The implementation of parallel one vs. rest runs the classification for
* each class in a separate threads.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
.setClassifier(new LogisticRegression)
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
assert(ova.getRawPredictionCol === "rawPrediction")
val ovaModel = ova.fit(dataset)

MLTestingUtils.checkCopyAndUids(ova, ovaModel)

assert(ovaModel.models.length === numClasses)
assert(ovaModel.numClasses === numClasses)
val transformedDataset = ovaModel.transform(dataset)

// check for label metadata in prediction col
Expand Down Expand Up @@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
ovaModel.setFeaturesCol("fea")
ovaModel.setPredictionCol("pred")
ovaModel.setRawPredictionCol("")
val transformedDataset = ovaModel.transform(dataset2)
val outputFields = transformedDataset.schema.fieldNames.toSet
assert(outputFields === Set("y", "fea", "pred"))
Expand All @@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
assert(output.schema.fieldNames.toSet
=== Set("label", "features", "prediction", "rawPrediction"))
}

test("SPARK-21306: OneVsRest should support setWeightCol") {
Expand Down