Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.{Dataset, Row}
Expand All @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType
@Since("1.5.0")
@Experimental
class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
extends Evaluator with HasPredictionCol with HasLabelCol
with HasWeightCol with DefaultParamsWritable {

@Since("1.5.0")
def this() = {
  // Delegate to the primary constructor with a uid obtained from Identifiable.randomUID("mcEval").
  this(Identifiable.randomUID("mcEval"))
}
Expand Down Expand Up @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
@Since("1.5.0")
def setLabelCol(value: String): this.type = {
  // Store `value` under the labelCol param and return this evaluator for chaining.
  set(labelCol, value)
}

/**
 * Sets the name of the column that supplies per-row weights.
 *
 * @group setParam
 */
@Since("3.0.0")
def setWeightCol(value: String): this.type = {
  set(weightCol, value)
}

setDefault(metricName -> "f1")

@Since("2.0.0")
Expand All @@ -75,11 +80,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(labelCol))

val predictionAndLabels =
dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
case Row(prediction: Double, label: Double) => (prediction, label)
// Build (prediction, label, weight) triples for MulticlassMetrics. When no weight column is
// set (or it is the empty string) every row gets a constant weight of 1.0.
val predictionAndLabelsWithWeights =
  dataset.select(
    col($(predictionCol)),
    col($(labelCol)).cast(DoubleType),
    // Cast the weight the same way as the label: the Row pattern below only matches Double
    // values, so a numeric-but-not-Double weight column (Int, Long, Float, ...) would
    // otherwise fail at runtime with a MatchError.
    if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
    else col($(weightCol)).cast(DoubleType))
    .rdd.map {
      case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
    }
val metrics = new MulticlassMetrics(predictionAndLabels)
val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "f1" => metrics.weightedFMeasure
case "weightedPrecision" => metrics.weightedPrecision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,19 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for multiclass classification.
*
* @param predictionAndLabels an RDD of (prediction, label) pairs.
* @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) triples or
* (prediction, label) pairs; a pair is treated as having weight 1.0.
*/
@Since("1.1.0")
class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) {
class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
// Normalize the constructor input into (prediction, label, weight) triples: a 3-tuple of
// Doubles is passed through, a 2-tuple of Doubles gets weight 1.0, anything else is rejected.
// NOTE(review): this val is public on a public class and carries no @Since annotation —
// confirm whether it is meant to be part of the API or should be private.
val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
  case (p: Double, l: Double, w: Double) => (p, l, w)
  case (p: Double, l: Double) => (p, l, 1.0)
  case other => throw new IllegalArgumentException(s"Expected tuples, got $other")
}

/**
* An auxiliary constructor taking a DataFrame.
Expand All @@ -39,21 +48,29 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
private[mllib] def this(predictionAndLabels: DataFrame) =
  this(predictionAndLabels.rdd.map { row =>
    // Columns 0 and 1 are read as Doubles and handed to the primary constructor as a pair.
    (row.getDouble(0), row.getDouble(1))
  })

private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
.map { case (prediction, label) =>
(label, if (label == prediction) 1 else 0)
// Sum of instance weights for each true label.
private lazy val labelCountByClass: Map[Double, Double] = {
  val weightPerLabel = predLabelsWeight.map { case (_, label, weight) => label -> weight }
  weightPerLabel.reduceByKey(_ + _).collectAsMap()
}
// Sum of the per-label weight totals above.
private lazy val labelCount: Double = labelCountByClass.values.sum
// Per true label: total weight of the instances whose prediction equals the label.
// Misclassified instances contribute 0.0, so every observed label still gets an entry.
private lazy val tpByClass: Map[Double, Double] = {
  val hitWeights = predLabelsWeight.map { case (prediction, label, weight) =>
    val credited = if (label == prediction) weight else 0.0
    label -> credited
  }
  hitWeights.reduceByKey(_ + _).collectAsMap()
}
private lazy val fpByClass: Map[Double, Int] = predictionAndLabels
.map { case (prediction, label) =>
(prediction, if (prediction != label) 1 else 0)
// Per predicted class: total weight of the instances whose true label differs from the
// prediction. Keyed by prediction, so only classes that were predicted at least once appear.
private lazy val fpByClass: Map[Double, Double] = {
  val missWeights = predLabelsWeight.map { case (prediction, label, weight) =>
    val charged = if (prediction != label) weight else 0.0
    prediction -> charged
  }
  missWeights.reduceByKey(_ + _).collectAsMap()
}
private lazy val confusions = predictionAndLabels
.map { case (prediction, label) =>
((label, prediction), 1)
// Total weight for every observed (true label, prediction) combination.
private lazy val confusions = {
  val cellWeights = predLabelsWeight.map { case (prediction, label, weight) =>
    (label, prediction) -> weight
  }
  cellWeights.reduceByKey(_ + _).collectAsMap()
}

Expand All @@ -71,7 +88,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
while (i < n) {
var j = 0
while (j < n) {
values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble
values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0)
j += 1
}
i += 1
Expand All @@ -92,8 +109,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
*/
@Since("1.1.0")
def falsePositiveRate(label: Double): Double = {
  // Weighted false positives for `label`; 0.0 when fpByClass has no entry for it.
  val fp = fpByClass.getOrElse(label, 0.0)
  // Denominator is the total weight of instances whose true label is NOT `label`.
  // labelCountByClass is looked up directly, so `label` must be one of its keys.
  fp / (labelCount - labelCountByClass(label))
}

/**
Expand All @@ -103,7 +120,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
@Since("1.1.0")
def precision(label: Double): Double = {
  // Weighted true positives; a direct lookup, so `label` must be a key of tpByClass.
  val tp = tpByClass(label)
  // Weighted false positives; 0.0 when fpByClass has no entry for `label`.
  val fp = fpByClass.getOrElse(label, 0.0)
  val predictedWeight = tp + fp
  // Return 0.0 instead of dividing by zero when no weight was predicted as `label`.
  if (predictedWeight == 0) 0.0 else tp / predictedWeight
}

Expand All @@ -112,7 +129,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* @param label the label.
*/
@Since("1.1.0")
// Weighted true positives divided by the total weight of instances whose true label is `label`.
// Both maps are looked up directly, so `label` must be a key of each.
def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label)

/**
* Returns f-measure for a given label (category)
Expand Down Expand Up @@ -140,7 +157,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* out of the total number of instances.)
*/
@Since("2.0.0")
// Sum of the weighted true positives over all labels, divided by the total weight (labelCount).
lazy val accuracy: Double = tpByClass.values.sum / labelCount

/**
* Returns weighted true positive rate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
package org.apache.spark.mllib.evaluation

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.ml.linalg.Matrices
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext

class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

private val delta = 1e-7

test("Multiclass evaluation metrics") {
/*
* Confusion matrix for 3-class classification with total 9 instances:
Expand All @@ -35,7 +39,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
val metrics = new MulticlassMetrics(predictionAndLabels)
val delta = 0.0000001
val tpRate0 = 2.0 / (2 + 2)
val tpRate1 = 3.0 / (3 + 1)
val tpRate2 = 1.0 / (1 + 0)
Expand All @@ -55,41 +58,122 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta)
assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta)
assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta)
assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta)
assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta)
assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta)
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
assert(math.abs(metrics.recall(0.0) - recall0) < delta)
assert(math.abs(metrics.recall(1.0) - recall1) < delta)
assert(math.abs(metrics.recall(2.0) - recall2) < delta)
assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta)
assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
assert(metrics.precision(0.0) ~== precision0 relTol delta)
assert(metrics.precision(1.0) ~== precision1 relTol delta)
assert(metrics.precision(2.0) ~== precision2 relTol delta)
assert(metrics.recall(0.0) ~== recall0 relTol delta)
assert(metrics.recall(1.0) ~== recall1 relTol delta)
assert(metrics.recall(2.0) ~== recall2 relTol delta)
assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)

assert(metrics.accuracy ~==
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta)
assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
val weight0 = 4.0 / 9
val weight1 = 4.0 / 9
val weight2 = 1.0 / 9
assert(metrics.weightedTruePositiveRate ~==
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
assert(metrics.weightedFalsePositiveRate ~==
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
assert(metrics.weightedPrecision ~==
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
assert(metrics.weightedRecall ~==
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
assert(metrics.weightedFMeasure ~==
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
assert(metrics.weightedFMeasure(2.0) ~==
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
assert(metrics.labels === labels)
}

test("Multiclass evaluation metrics with weights") {
  /*
   * Confusion matrix for 3-class classification with 9 instances in total, each instance
   * carrying one of two weights (w1 or w2):
   * |2 * w1|1 * w2         |1 * w1| true class0 (4 instances)
   * |1 * w2|2 * w1 + 1 * w2|0     | true class1 (4 instances)
   * |0     |0              |1 * w2| true class2 (1 instance)
   */
  val w1 = 2.2
  val w2 = 1.5
  // Total weight of all nine instances.
  val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2
  // Matrices.dense takes values in column-major order.
  val confusionMatrix = Matrices.dense(3, 3,
    Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2))
  val labels = Array(0.0, 1.0, 2.0)
  // Each element is (prediction, label, weight).
  val predictionAndLabelsWithWeights = sc.parallelize(
    Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2),
      (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2),
      (2.0, 0.0, w1)), 2)
  val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights)
  val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
  val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
  val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0)
  val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1))
  val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2))
  val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2))
  val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2)
  val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
  val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2)
  val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1)
  val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2)
  val recall2 = (1.0 * w2) / (1.0 * w2 + 0)
  val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
  val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
  val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
  val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
  val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
  val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

  assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
  assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
  assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
  assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
  assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
  assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
  assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
  assert(metrics.precision(0.0) ~== precision0 relTol delta)
  assert(metrics.precision(1.0) ~== precision1 relTol delta)
  assert(metrics.precision(2.0) ~== precision2 relTol delta)
  assert(metrics.recall(0.0) ~== recall0 relTol delta)
  assert(metrics.recall(1.0) ~== recall1 relTol delta)
  assert(metrics.recall(2.0) ~== recall2 relTol delta)
  assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
  assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
  assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
  assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
  assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
  assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)

  // Accuracy is the weight on the confusion-matrix diagonal over the total weight.
  assert(metrics.accuracy ~==
    (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta)
  assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
  // Class weights are each true class's share of the total instance weight.
  val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw
  val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw
  val weight2 = 1 * w2 / tw
  assert(metrics.weightedTruePositiveRate ~==
    (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
  assert(metrics.weightedFalsePositiveRate ~==
    (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
  assert(metrics.weightedPrecision ~==
    (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
  assert(metrics.weightedRecall ~==
    (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
  assert(metrics.weightedFMeasure ~==
    (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
  assert(metrics.weightedFMeasure(2.0) ~==
    (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
  assert(metrics.labels === labels)
}
}