Skip to content

Commit c3a77ad

Browse files
committed
Addressing reviewers comments mengxr
1 parent e2c91c3 commit c3a77ad

File tree

2 files changed

+59
-75
lines changed

2 files changed

+59
-75
lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala

Lines changed: 47 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -17,118 +17,103 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20+
import org.apache.spark.annotation.Experimental
2021
import org.apache.spark.rdd.RDD
2122
import org.apache.spark.Logging
2223
import org.apache.spark.SparkContext._
2324

2425
/**
2526
* Evaluator for multiclass classification.
26-
* NB: type Double both for prediction and label is retained
27-
* for compatibility with model.predict that returns Double
28-
* and MLUtils.loadLibSVMFile that loads class labels as Double
2927
*
3028
* @param predictionsAndLabels an RDD of (prediction, label) pairs.
3129
*/
30+
@Experimental
3231
class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging {
3332

34-
/* class = category; label = instance of class; prediction = instance of class */
35-
3633
private lazy val labelCountByClass = predictionsAndLabels.values.countByValue()
37-
private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count}
38-
private lazy val tpByClass = predictionsAndLabels.map{ case (prediction, label) =>
39-
(label, if(label == prediction) 1 else 0) }.reduceByKey{_ + _}.collectAsMap
40-
private lazy val fpByClass = predictionsAndLabels.map{ case (prediction, label) =>
41-
(prediction, if(prediction != label) 1 else 0) }.reduceByKey{_ + _}.collectAsMap
34+
private lazy val labelCount = labelCountByClass.values.sum
35+
private lazy val tpByClass = predictionsAndLabels
36+
.map{ case (prediction, label) =>
37+
(label, if (label == prediction) 1 else 0)
38+
}.reduceByKey(_ + _)
39+
.collectAsMap()
40+
private lazy val fpByClass = predictionsAndLabels
41+
.map{ case (prediction, label) =>
42+
(prediction, if (prediction != label) 1 else 0)
43+
}.reduceByKey(_ + _)
44+
.collectAsMap()
4245

4346
/**
44-
* Returns Precision for a given label (category)
47+
* Returns precision for a given label (category)
4548
* @param label the label.
46-
* @return Precision.
4749
*/
48-
def precision(label: Double): Double = if(tpByClass(label) + fpByClass.getOrElse(label, 0) == 0) 0
49-
else tpByClass(label).toDouble / (tpByClass(label) + fpByClass.getOrElse(label, 0)).toDouble
50+
def precision(label: Double): Double = {
51+
val tp = tpByClass(label)
52+
val fp = fpByClass.getOrElse(label, 0)
53+
if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
54+
}
5055

5156
/**
52-
* Returns Recall for a given label (category)
57+
* Returns recall for a given label (category)
5358
* @param label the label.
54-
* @return Recall.
5559
*/
56-
def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label).toDouble
60+
def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
5761

5862
/**
59-
* Returns F1-measure for a given label (category)
63+
* Returns f-measure for a given label (category); beta weights recall relative to precision (the default beta = 1.0 gives the F1-measure)
6064
* @param label the label.
61-
* @return F1-measure.
6265
*/
63-
def f1Measure(label: Double): Double ={
66+
def fMeasure(label: Double, beta:Double = 1.0): Double = {
6467
val p = precision(label)
6568
val r = recall(label)
66-
if((p + r) == 0) 0 else 2 * p * r / (p + r)
69+
val betaSqrd = beta * beta
70+
if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r)
6771
}
6872

6973
/**
70-
* Returns micro-averaged Recall
74+
* Returns micro-averaged recall
7175
* (equal to micro-averaged precision and f-measure for a multiclass classifier)
72-
* @return microRecall.
7376
*/
74-
lazy val microRecall: Double =
75-
tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount
77+
lazy val recall: Double =
78+
tpByClass.values.sum.toDouble / labelCount
7679

7780
/**
78-
* Returns micro-averaged Precision
81+
* Returns micro-averaged precision
7982
* (equal to micro-averaged recall and f-measure for a multiclass classifier)
80-
* @return microPrecision.
8183
*/
82-
lazy val microPrecision: Double = microRecall
84+
lazy val precision: Double = recall
8385

8486
/**
85-
* Returns micro-averaged F1-measure
87+
* Returns micro-averaged f-measure
8688
* (equal to micro-averaged precision and recall for a multiclass classifier)
87-
* @return microF1measure.
8889
*/
89-
lazy val microF1Measure: Double = microRecall
90+
lazy val fMeasure: Double = recall
9091

9192
/**
92-
* Returns weighted averaged Recall
93-
* @return weightedRecall.
93+
* Returns weighted averaged recall
94+
* (equal to micro-averaged precision, recall and f-measure)
9495
*/
95-
lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) =>
96-
wRecall + recall(category) * count.toDouble / labelCount}
96+
lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) =>
97+
recall(category) * count.toDouble / labelCount
98+
}.sum
9799

98100
/**
99-
* Returns weighted averaged Precision
100-
* @return weightedPrecision.
101+
* Returns weighted averaged precision
101102
*/
102-
lazy val weightedPrecision: Double =
103-
labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) =>
104-
wPrecision + precision(category) * count.toDouble / labelCount}
103+
lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) =>
104+
precision(category) * count.toDouble / labelCount
105+
}.sum
105106

106107
/**
107-
* Returns weighted averaged F1-measure
108-
* @return weightedF1Measure.
108+
* Returns weighted averaged f1-measure
109109
*/
110-
lazy val weightedF1Measure: Double =
111-
labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) =>
112-
wF1measure + f1Measure(category) * count.toDouble / labelCount}
110+
lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) =>
111+
fMeasure(category) * count.toDouble / labelCount
112+
}.sum
113113

114114
/**
115-
* Returns map with Precisions for individual classes
116-
* @return precisionPerClass.
115+
* Returns the sequence of labels in ascending order
117116
*/
118-
lazy val precisionPerClass =
119-
labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap
117+
lazy val labels = tpByClass.unzip._1.toSeq.sorted
120118

121-
/**
122-
* Returns map with Recalls for individual classes
123-
* @return recallPerClass.
124-
*/
125-
lazy val recallPerClass =
126-
labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap
127-
128-
/**
129-
* Returns map with F1-measures for individual classes
130-
* @return f1MeasurePerClass.
131-
*/
132-
lazy val f1MeasurePerClass =
133-
labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap
134119
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717

1818
package org.apache.spark.mllib.evaluation
1919

20-
import org.apache.spark.mllib.util.LocalSparkContext
2120
import org.scalatest.FunSuite
2221

22+
import org.apache.spark.mllib.util.LocalSparkContext
23+
2324
class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
2425
test("Multiclass evaluation metrics") {
2526
/*
@@ -29,12 +30,12 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
2930
* |0|0|1| true class2 (1 instance)
3031
*
3132
*/
33+
val labels = Seq(0.0, 1.0, 2.0)
3234
val scoreAndLabels = sc.parallelize(
3335
Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
3436
(1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
3537
val metrics = new MulticlassMetrics(scoreAndLabels)
36-
37-
val delta = 0.00001
38+
val delta = 0.0000001
3839
val precision0 = 2.0 / (2.0 + 1.0)
3940
val precision1 = 3.0 / (3.0 + 1.0)
4041
val precision2 = 1.0 / (1.0 + 1.0)
@@ -44,28 +45,26 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
4445
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
4546
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
4647
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
47-
4848
assert(math.abs(metrics.precision(0.0) - precision0) < delta)
4949
assert(math.abs(metrics.precision(1.0) - precision1) < delta)
5050
assert(math.abs(metrics.precision(2.0) - precision2) < delta)
5151
assert(math.abs(metrics.recall(0.0) - recall0) < delta)
5252
assert(math.abs(metrics.recall(1.0) - recall1) < delta)
5353
assert(math.abs(metrics.recall(2.0) - recall2) < delta)
54-
assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
55-
assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
56-
assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
57-
58-
assert(math.abs(metrics.microRecall -
54+
assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta)
55+
assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta)
56+
assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta)
57+
assert(math.abs(metrics.recall -
5958
(2.0 + 3.0 + 1.0) / ((2.0 + 3.0 + 1.0) + (1.0 + 1.0 + 1.0))) < delta)
60-
assert(math.abs(metrics.microRecall - metrics.microPrecision) < delta)
61-
assert(math.abs(metrics.microRecall - metrics.microF1Measure) < delta)
62-
assert(math.abs(metrics.microRecall - metrics.weightedRecall) < delta)
59+
assert(math.abs(metrics.recall - metrics.precision) < delta)
60+
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
61+
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
6362
assert(math.abs(metrics.weightedPrecision -
6463
((4.0 / 9.0) * precision0 + (4.0 / 9.0) * precision1 + (1.0 / 9.0) * precision2)) < delta)
6564
assert(math.abs(metrics.weightedRecall -
6665
((4.0 / 9.0) * recall0 + (4.0 / 9.0) * recall1 + (1.0 / 9.0) * recall2)) < delta)
6766
assert(math.abs(metrics.weightedF1Measure -
6867
((4.0 / 9.0) * f1measure0 + (4.0 / 9.0) * f1measure1 + (1.0 / 9.0) * f1measure2)) < delta)
69-
68+
assert(metrics.labels == labels)
7069
}
7170
}

0 commit comments

Comments
 (0)