Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@ import scala.language.existentials
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel

/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams {
private[ml] trait OneVsRestParams extends ClassifierParams {

// scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
Expand Down Expand Up @@ -127,10 +128,16 @@ final class OneVsRestModel private[ml] (
predictions.maxBy(_._2)._1.toDouble
}

// output label and label metadata as prediction
// output the highest confidence as rawPredictionCol
val probabilityUDF = udf { (predictions: Map[Int, Double]) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is outputting a Double, but it should return a Vector with one value for each class label. Actually, the output value should be the accCol converted to a Vector.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delayed response. I've made some small changes as per your recommendations. With respect to this point, I'm not sure I follow. As far as I understand, I need to return the highest confidence factor (which will be a Double) corresponding to the model used with OVR.

predictions.maxBy(_._2)._2.toDouble
}

// output label, confidence factor and label metadata as prediction
aggregatedDataset
.withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
.drop(accColName)
.withColumn($(rawPredictionCol), probabilityUDF(col(accColName)).as($(rawPredictionCol),
labelMetadata)).drop(accColName)
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the max confidence factor as an output column.


override def copy(extra: ParamMap): OneVsRestModel = {
Expand Down Expand Up @@ -220,4 +227,5 @@ final class OneVsRest(override val uid: String)
}
copied
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction",
"rawPrediction"))
}

test("OneVsRest.copy and OneVsRestModel.copy") {
Expand Down