Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.2.0")
@Experimental
class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasRawPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.2.0")
def this() = this(Identifiable.randomUID("binEval"))
Expand Down Expand Up @@ -68,21 +69,34 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
@Since("1.2.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/**
 * Sets the name of the column that supplies the per-row instance weight.
 *
 * @group setParam
 */
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

setDefault(metricName -> "areaUnderROC")

@Since("2.0.0")
override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkNumericType(schema, $(labelCol))
if (isDefined(weightCol)) {
SchemaUtils.checkNumericType(schema, $(weightCol))
}

// TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
val scoreAndLabels =
dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
val scoreAndLabelsWithWeights =
dataset.select(
col($(rawPredictionCol)),
col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
else col($(weightCol)).cast(DoubleType)).rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), label, weight)
case Row(rawPrediction: Double, label: Double, weight: Double) =>
(rawPrediction, label, weight)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
val metric = $(metricName) match {
case "areaUnderROC" => metrics.areaUnderROC()
case "areaUnderPR" => metrics.areaUnderPR()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.evaluation.binary._
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}

/**
* Evaluator for binary classification.
*
* @param scoreAndLabels an RDD of (score, label) pairs.
* @param scoreAndLabels an RDD of (score, label) or (score, label, weight) tuples.
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
* This is useful because the curve contains a point for each distinct score
Expand All @@ -41,9 +41,19 @@ import org.apache.spark.sql.DataFrame
* partition boundaries.
*/
@Since("1.0.0")
class BinaryClassificationMetrics @Since("1.3.0") (
@Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
@Since("1.3.0") val numBins: Int) extends Logging {
class BinaryClassificationMetrics @Since("3.0.0") (
@Since("1.3.0") val scoreAndLabels: RDD[_ <: Product],
@Since("1.3.0") val numBins: Int = 1000)
extends Logging {
/**
 * The input normalized to (score, (label, weight)) records, keyed by score.
 * Two-element tuples are given a weight of 1.0; three-element tuples must carry a
 * non-negative weight. Anything else is rejected with an IllegalArgumentException.
 */
val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabels.map {
  // Unweighted input: every instance counts as 1.0.
  case (score: Double, label: Double) =>
    score -> (label -> 1.0)
  // Weighted input: negative weights are rejected before the record is emitted.
  case (score: Double, label: Double, weight: Double) =>
    require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
    score -> (label -> weight)
  case other =>
    throw new IllegalArgumentException(s"Expected tuples, got $other")
}

require(numBins >= 0, "numBins must be nonnegative")

Expand All @@ -58,7 +68,14 @@ class BinaryClassificationMetrics @Since("1.3.0") (
* @param scoreAndLabels a DataFrame with two double columns (score, label) or
*                       three double columns (score, label, weight)
*/
private[mllib] def this(scoreAndLabels: DataFrame) =
this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
this(scoreAndLabels.rdd.map {
case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
case Row(prediction: Double, label: Double) =>
(prediction, label, 1.0)
case other =>
throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
})

/**
* Unpersist intermediate RDDs used in the computation.
Expand Down Expand Up @@ -146,11 +163,13 @@ class BinaryClassificationMetrics @Since("1.3.0") (
private lazy val (
cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
// Create a bin for each distinct score value, count positives and negatives within each bin,
// and then sort by score values in descending order.
val counts = scoreAndLabels.combineByKey(
createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
// Create a bin for each distinct score value, count weighted positives and
// negatives within each bin, and then sort by score values in descending order.
val counts = scoreLabelsWeight.combineByKey(
createCombiner = (labelAndWeight: (Double, Double)) =>
new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2),
mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
c += (labelAndWeight._1, labelAndWeight._2),
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ import scala.collection.Map
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}

/**
* Evaluator for multiclass classification.
*
* @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
* (prediction, label) pairs.
* @param predictionAndLabels an RDD of (prediction, label, weight) or
* (prediction, label) tuples.
*/
@Since("1.1.0")
class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product]) {
val predLabelsWeight: RDD[(Double, Double, Double)] = predictionAndLabels.map {
case (prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
case (prediction: Double, label: Double) =>
Expand All @@ -46,7 +46,14 @@ class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Pr
* @param predictionAndLabels a DataFrame with two double columns (prediction, label) or
*                            three double columns (prediction, label, weight)
*/
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
this(predictionAndLabels.rdd.map {
case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight)
case Row(prediction: Double, label: Double) =>
(prediction, label, 1.0)
case other =>
throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
})

private lazy val labelCountByClass: Map[Double, Double] =
predLabelsWeight.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,33 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
/** Precision. Defined as 1.0 when the total weight of predicted positives is zero. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
val totalPositives = c.numTruePositives + c.numFalsePositives
if (totalPositives == 0) {
val totalPositives = c.weightedTruePositives + c.weightedFalsePositives
if (totalPositives == 0.0) {
1.0
} else {
c.numTruePositives.toDouble / totalPositives
c.weightedTruePositives / totalPositives
}
}
}

/** False positive rate. Defined as 0.0 when the total weight of negative examples is zero. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numNegatives == 0) {
if (c.weightedNegatives == 0.0) {
0.0
} else {
c.numFalsePositives.toDouble / c.numNegatives
c.weightedFalsePositives / c.weightedNegatives
}
}
}

/** Recall. Defined as 0.0 when the total weight of positive examples is zero. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
override def apply(c: BinaryConfusionMatrix): Double = {
if (c.numPositives == 0) {
if (c.weightedPositives == 0.0) {
0.0
} else {
c.numTruePositives.toDouble / c.numPositives
c.weightedTruePositives / c.weightedPositives
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary
* Trait for a binary confusion matrix.
*/
private[evaluation] trait BinaryConfusionMatrix {
/** number of true positives */
def numTruePositives: Long
/** weighted number of true positives */
def weightedTruePositives: Double

/** number of false positives */
def numFalsePositives: Long
/** weighted number of false positives */
def weightedFalsePositives: Double

/** number of false negatives */
def numFalseNegatives: Long
/** weighted number of false negatives */
def weightedFalseNegatives: Double

/** number of true negatives */
def numTrueNegatives: Long
/** weighted number of true negatives */
def weightedTrueNegatives: Double

/** number of positives */
def numPositives: Long = numTruePositives + numFalseNegatives
/** weighted number of positives */
def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives

/** number of negatives */
def numNegatives: Long = numFalsePositives + numTrueNegatives
/** weighted number of negatives */
def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives
}

/**
Expand All @@ -51,20 +51,22 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {

/** weighted number of true positives */
override def numTruePositives: Long = count.numPositives
override def weightedTruePositives: Double = count.weightedNumPositives

/** weighted number of false positives */
override def numFalsePositives: Long = count.numNegatives
override def weightedFalsePositives: Double = count.weightedNumNegatives

/** weighted number of false negatives */
override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
override def weightedFalseNegatives: Double =
totalCount.weightedNumPositives - count.weightedNumPositives

/** weighted number of true negatives */
override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
override def weightedTrueNegatives: Double =
totalCount.weightedNumNegatives - count.weightedNumNegatives

/** weighted number of positives */
override def numPositives: Long = totalCount.numPositives
override def weightedPositives: Double = totalCount.weightedNumPositives

/** weighted number of negatives */
override def numNegatives: Long = totalCount.numNegatives
override def weightedNegatives: Double = totalCount.weightedNumNegatives
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,39 @@ package org.apache.spark.mllib.evaluation.binary
/**
* A counter for positives and negatives.
*
* @param numPositives number of positive labels
* @param numNegatives number of negative labels
* @param weightedNumPositives weighted number of positive labels
* @param weightedNumNegatives weighted number of negative labels
*/
private[evaluation] class BinaryLabelCounter(
var numPositives: Long = 0L,
var numNegatives: Long = 0L) extends Serializable {
var weightedNumPositives: Double = 0.0,
var weightedNumNegatives: Double = 0.0) extends Serializable {

/** Processes a label with an implicit weight of 1.0. */
def +=(label: Double): BinaryLabelCounter = {
// Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
// -1.0 for negative as well.
if (label > 0.5) numPositives += 1L else numNegatives += 1L
if (label > 0.5) weightedNumPositives += 1.0 else weightedNumNegatives += 1.0
this
}

/**
 * Adds one labeled instance to this counter, contributing `weight` to the matching total.
 *
 * @param label  instance label; values above 0.5 are counted as positive
 * @param weight amount added to the positive or negative total
 * @return this counter, so calls can be chained
 */
def +=(label: Double, weight: Double): BinaryLabelCounter = {
  // Positives are expected to be 1.0 and negatives 0.0; thresholding at 0.5 also
  // classifies a -1.0 label as negative.
  val isPositive = label > 0.5
  if (isPositive) {
    weightedNumPositives += weight
  } else {
    weightedNumNegatives += weight
  }
  this
}

/** Merges another counter. */
def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
numPositives += other.numPositives
numNegatives += other.numNegatives
weightedNumPositives += other.weightedNumPositives
weightedNumNegatives += other.weightedNumNegatives
this
}

override def clone: BinaryLabelCounter = {
new BinaryLabelCounter(numPositives, numNegatives)
new BinaryLabelCounter(weightedNumPositives, weightedNumNegatives)
}

override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
override def toString: String = s"{numPos: $weightedNumPositives, numNeg: $weightedNumNegatives}"
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,23 @@ class BinaryClassificationEvaluatorSuite
.setMetricName("areaUnderPR")

val vectorDF = Seq(
(0d, Vectors.dense(12, 2.5)),
(1d, Vectors.dense(1, 3)),
(0d, Vectors.dense(10, 2))
(0.0, Vectors.dense(12, 2.5)),
(1.0, Vectors.dense(1, 3)),
(0.0, Vectors.dense(10, 2))
).toDF("label", "rawPrediction")
assert(evaluator.evaluate(vectorDF) === 1.0)

val doubleDF = Seq(
(0d, 0d),
(1d, 1d),
(0d, 0d)
(0.0, 0.0),
(1.0, 1.0),
(0.0, 0.0)
).toDF("label", "rawPrediction")
assert(evaluator.evaluate(doubleDF) === 1.0)

val stringDF = Seq(
(0d, "0d"),
(1d, "1d"),
(0d, "0d")
(0.0, "0.0"),
(1.0, "1.0"),
(0.0, "0.0")
).toDF("label", "rawPrediction")
val thrown = intercept[IllegalArgumentException] {
evaluator.evaluate(stringDF)
Expand All @@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite
assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.")
}

test("should accept weight column") {
  val weightColName = "weight"

  // The same three (label, rawPrediction) rows are used throughout; only the weights vary.
  def buildDF(w0: Double, w1: Double, w2: Double) = Seq(
    (0.0, Vectors.dense(2.5, 12), w0),
    (1.0, Vectors.dense(1, 3), w1),
    (0.0, Vectors.dense(10, 2), w2)
  ).toDF("label", "rawPrediction", weightColName)

  val weightedEvaluator = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderROC")
    .setWeightCol(weightColName)
  val unweightedEvaluator = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderROC")

  // With every weight equal to 1.0, using the weight column must not change the metric.
  val unitWeightDF = buildDF(1.0, 1.0, 1.0)
  val aucUnitWeights = weightedEvaluator.evaluate(unitWeightDF)
  val aucNoWeightCol = unweightedEvaluator.evaluate(unitWeightDF)
  assert(aucUnitWeights === aucNoWeightCol)

  // The negative that outscores the only positive gets weight 2.5, more than the
  // correctly ranked negative's 2.0, so the weighted metric is expected to drop.
  val aucSkewedWeights = weightedEvaluator.evaluate(buildDF(2.5, 0.1, 2.0))
  assert(aucSkewedWeights < aucUnitWeights)
}

test("should support all NumericType labels and not support other types") {
val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
MLTestingUtils.checkNumericTypes(evaluator, spark)
Expand Down
Loading