diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -36,11 +36,17 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.2.0")
 @Experimental
 class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasRawPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("binEval"))
 
+  /**
+   * Default number of bins to use for binary classification evaluation.
+   */
+  val defaultNumberOfBins = 1000
+
   /**
    * param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`)
    * @group param
@@ -68,6 +74,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.2.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "areaUnderROC")
 
   @Since("2.0.0")
@@ -77,12 +87,16 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-    val scoreAndLabels =
-      dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
-        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
-        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+    val scoreAndLabelsWithWeights =
+      dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+        .rdd.map {
+          case Row(rawPrediction: Vector, label: Double, weight: Double) =>
+            (rawPrediction(1), (label, weight))
+          case Row(rawPrediction: Double, label: Double, weight: Double) =>
+            (rawPrediction, (label, weight))
       }
-    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+    val metrics = new BinaryClassificationMetrics(defaultNumberOfBins, scoreAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "areaUnderROC" => metrics.areaUnderROC()
       case "areaUnderPR" => metrics.areaUnderPR()
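A minimal sketch of how the new `setWeightCol` on this evaluator can be exercised, assuming a live `SparkSession` named `spark`; the toy rows and the `weight` column name are made up, while `rawPrediction` and `label` are the evaluator's default column names:

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors

// Toy rows: (rawPrediction, label, weight). A row with weight 2.0 should
// contribute to the curve like the same row repeated twice with weight 1.0.
val scored = spark.createDataFrame(Seq(
  (Vectors.dense(0.1, 0.9), 1.0, 1.0),
  (Vectors.dense(0.8, 0.2), 0.0, 2.0),
  (Vectors.dense(0.45, 0.55), 1.0, 0.5)
)).toDF("rawPrediction", "label", "weight")

val auc = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .setWeightCol("weight") // new setter from this patch
  .evaluate(scored)
```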
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.5.0")
 @Experimental
 class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mcEval"))
@@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("1.5.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "f1")
 
   @Since("2.0.0")
@@ -75,11 +80,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val predictionAndLabels =
-      dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
-        case Row(prediction: Double, label: Double) => (prediction, label)
+    val predictionAndLabelsWithWeights =
+      dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+        .rdd.map {
+          case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
       }
-    val metrics = new MulticlassMetrics(predictionAndLabels)
+    dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
+      case Row(prediction: Double, label: Double) => (prediction, label)
+    }.values.countByValue()
+    val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "f1" => metrics.weightedFMeasure
       case "weightedPrecision" => metrics.weightedPrecision
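The weight column is optional here: when `weightCol` is unset or empty, the select falls back to `lit(1.0)`, so the unweighted behaviour is preserved. A sketch under the same assumptions (a `SparkSession` named `spark`, made-up rows):

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val predictions = spark.createDataFrame(Seq(
  (0.0, 0.0, 1.0),   // (prediction, label, weight)
  (1.0, 1.0, 0.5),
  (1.0, 2.0, 2.0)
)).toDF("prediction", "label", "weight")

val evaluator = new MulticlassClassificationEvaluator().setMetricName("f1")
val unweighted = evaluator.evaluate(predictions)  // no weightCol set: every row counts 1.0
val weighted = evaluator.setWeightCol("weight").evaluate(predictions)
```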
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
 @Since("1.4.0")
 @Experimental
 final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("regEval"))
@@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("1.4.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "rmse")
 
   @Since("2.0.0")
@@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
-    val predictionAndLabels = dataset
-      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
+    val predictionAndLabelsWithWeights = dataset
+      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
       .rdd
-      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
-    val metrics = new RegressionMetrics(predictionAndLabels)
+      .map { case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight) }
+    val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "rmse" => metrics.rootMeanSquaredError
       case "mse" => metrics.meanSquaredError
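With weights flowing into `RegressionMetrics`, "rmse" effectively becomes a weighted root mean squared error, sqrt(Σ w·(prediction − label)² / Σ w). A sketch, again assuming a `SparkSession` named `spark` and invented rows:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator

val predictions = spark.createDataFrame(Seq(
  (2.5, 3.0, 1.0),   // (prediction, label, weight)
  (0.0, -0.5, 2.0),
  (2.0, 2.0, 0.1)
)).toDF("prediction", "label", "weight")

// The low-weight exact hit barely moves the metric; the weight-2.0 miss dominates.
val rmse = new RegressionEvaluator()
  .setMetricName("rmse")
  .setWeightCol("weight")
  .evaluate(predictions)
```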
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for binary classification.
  *
- * @param scoreAndLabels an RDD of (score, label) pairs.
+ * @param scoreAndLabelsWithWeights an RDD of (score, (label, weight)) pairs.
  * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
  *                will be down-sampled to this many "bins". If 0, no down-sampling will occur.
  *                This is useful because the curve contains a point for each distinct score
@@ -41,12 +41,26 @@ import org.apache.spark.sql.DataFrame
  *                partition boundaries.
  */
 @Since("1.0.0")
-class BinaryClassificationMetrics @Since("1.3.0") (
-    @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
-    @Since("1.3.0") val numBins: Int) extends Logging {
+class BinaryClassificationMetrics @Since("2.2.0") (
+    val numBins: Int,
+    @Since("2.2.0") val scoreAndLabelsWithWeights: RDD[(Double, (Double, Double))])
+  extends Logging {
 
   require(numBins >= 0, "numBins must be nonnegative")
 
+  /**
+   * Retrieves the score and labels (for binary compatibility).
+   * @return The score and labels.
+   */
+  @Since("1.0.0")
+  def scoreAndLabels: RDD[(Double, Double)] = {
+    scoreAndLabelsWithWeights.map(values => (values._1, values._2._1))
+  }
+
+  @Since("1.0.0")
+  def this(@Since("1.3.0") scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") numBins: Int) =
+    this(numBins, scoreAndLabels.map(scoreAndLabel => (scoreAndLabel._1, (scoreAndLabel._2, 1.0))))
+
   /**
    * Defaults `numBins` to 0.
    */
@@ -146,11 +160,13 @@ class BinaryClassificationMetrics @Since("1.3.0") (
   private lazy val (
     cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
     confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
-    // Create a bin for each distinct score value, count positives and negatives within each bin,
-    // and then sort by score values in descending order.
-    val counts = scoreAndLabels.combineByKey(
-      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
-      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
+    // Create a bin for each distinct score value, count weighted positives and
+    // negatives within each bin, and then sort by score values in descending order.
+    val counts = scoreAndLabelsWithWeights.combineByKey(
+      createCombiner = (labelAndWeight: (Double, Double)) =>
+        new BinaryLabelCounter(0L, 0L) += (labelAndWeight._1, labelAndWeight._2),
+      mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
+        c += (labelAndWeight._1, labelAndWeight._2),
       mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
     ).sortByKey(ascending = false)
 
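Since the primary constructor changed shape (`(numBins, weighted RDD)` instead of `(RDD, numBins)`, note the flipped argument order), callers pick between the two forms; the old one survives as the auxiliary constructor, which pads every record with weight 1.0. A sketch assuming a `SparkContext` named `sc` and toy data:

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// New primary constructor: numBins first, then (score, (label, weight)) tuples.
val weighted = new BinaryClassificationMetrics(1000, sc.parallelize(Seq(
  (0.9, (1.0, 1.0)),
  (0.6, (0.0, 2.0)),
  (0.4, (1.0, 0.5))
)))
println(weighted.areaUnderROC())

// Old call sites still compile through the auxiliary constructor (weight 1.0).
val legacy = new BinaryClassificationMetrics(
  sc.parallelize(Seq((0.9, 1.0), (0.6, 0.0))), 0)
legacy.scoreAndLabels.collect()  // unweighted view kept for binary compatibility
```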
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -27,10 +27,11 @@ import org.apache.spark.sql.DataFrame
 /**
  * Evaluator for multiclass classification.
  *
- * @param predictionAndLabels an RDD of (prediction, label) pairs.
+ * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
+ *                                   (prediction, label) pairs.
  */
 @Since("1.1.0")
-class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
+class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_]) {
 
   /**
    * An auxiliary constructor taking a DataFrame.
@@ -39,21 +40,33 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
   private[mllib] def this(predictionAndLabels: DataFrame) =
     this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
 
-  private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
-  private lazy val labelCount: Long = labelCountByClass.values.sum
-  private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
-    .map { case (prediction, label) =>
-      (label, if (label == prediction) 1 else 0)
+  private lazy val labelCountByClass: Map[Double, Double] =
+    predAndLabelsWithOptWeight.map {
+      case (prediction: Double, label: Double) =>
+        (label, 1.0)
+      case (prediction: Double, label: Double, weight: Double) =>
+        (label, weight)
+    }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap
+  private lazy val labelCount: Double = labelCountByClass.values.sum
+  private lazy val tpByClass: Map[Double, Double] = predAndLabelsWithOptWeight
+    .map { case (prediction: Double, label: Double) =>
+      (label, if (label == prediction) 1.0 else 0.0)
+    case (prediction: Double, label: Double, weight: Double) =>
+      (label, if (label == prediction) weight else 0.0)
     }.reduceByKey(_ + _)
     .collectAsMap()
-  private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
-    .map { case (prediction, label) =>
-      (prediction, if (prediction != label) 1 else 0)
+  private lazy val fpByClass: Map[Double, Double] = predAndLabelsWithOptWeight
+    .map { case (prediction: Double, label: Double) =>
+      (prediction, if (prediction != label) 1.0 else 0.0)
+    case (prediction: Double, label: Double, weight: Double) =>
+      (prediction, if (prediction != label) weight else 0.0)
     }.reduceByKey(_ + _)
    .collectAsMap()
-  private lazy val confusions = predictionAndLabels
-    .map { case (prediction, label) =>
-      ((label, prediction), 1)
+  private lazy val confusions = predAndLabelsWithOptWeight
+    .map { case (prediction: Double, label: Double) =>
+      ((label, prediction), 1.0)
+    case (prediction: Double, label: Double, weight: Double) =>
+      ((label, prediction), weight)
     }.reduceByKey(_ + _)
     .collectAsMap()
 
@@ -71,7 +84,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
     while (i < n) {
       var j = 0
       while (j < n) {
-        values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
+        values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0)
         j += 1
       }
       i += 1
@@ -92,8 +105,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
    */
   @Since("1.1.0")
   def falsePositiveRate(label: Double): Double = {
-    val fp = fpByClass.getOrElse(label, 0)
-    fp.toDouble / (labelCount - labelCountByClass(label))
+    val fp = fpByClass.getOrElse(label, 0.0)
+    fp / (labelCount - labelCountByClass(label))
   }
 
   /**
@@ -103,7 +116,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
   @Since("1.1.0")
   def precision(label: Double): Double = {
     val tp = tpByClass(label)
-    val fp = fpByClass.getOrElse(label, 0)
+    val fp = fpByClass.getOrElse(label, 0.0)
     if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
   }
 
@@ -112,7 +125,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
    * @param label the label.
    */
   @Since("1.1.0")
-  def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
+  def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label)
 
   /**
    * Returns f-measure for a given label (category)
@@ -165,7 +178,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
    * out of the total number of instances.)
    */
   @Since("2.0.0")
-  lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount
+  lazy val accuracy: Double = tpByClass.values.sum / labelCount
 
   /**
    * Returns weighted true positive rate
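Widening the parameter to `RDD[_]` means `MulticlassMetrics` now dispatches on the runtime shape of each record: legacy (prediction, label) pairs and the new (prediction, label, weight) triples flow through the same pattern matches. A sketch, again assuming a `SparkContext` named `sc` and made-up records:

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// Weighted triples: (prediction, label, weight).
val weighted = new MulticlassMetrics(sc.parallelize(Seq(
  (0.0, 0.0, 1.0),
  (1.0, 1.0, 2.0),
  (1.0, 0.0, 0.5)
)))
println(weighted.accuracy)  // weighted tp / total weight = (1.0 + 2.0) / 3.5

// Legacy pairs still work; each record is treated as weight 1.0.
val unweighted = new MulticlassMetrics(sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 1.0), (1.0, 0.0)
)))
println(unweighted.accuracy)  // 2.0 / 3.0
```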