Skip to content

Commit a416fa0

Browse files
committed
updated based on latest comments
1 parent a80d890 commit a416fa0

File tree

1 file changed

+61
-63
lines changed

1 file changed

+61
-63
lines changed

mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala

Lines changed: 61 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@
1818
package org.apache.spark.mllib.evaluation
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.ml.linalg.Matrices
2122
import org.apache.spark.ml.util.TestingUtils._
22-
import org.apache.spark.mllib.linalg.Matrices
2323
import org.apache.spark.mllib.util.MLlibTestSparkContext
2424

2525
class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
2626

27-
import testImplicits._
28-
2927
val delta = 1e-7
3028

3129
test("Multiclass evaluation metrics") {
@@ -60,47 +58,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
6058
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
6159
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
6260

63-
assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta)
64-
assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta)
65-
assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta)
66-
assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta)
67-
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta)
68-
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta)
69-
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta)
70-
assert(metrics.precision(0.0) ~== precision0 absTol delta)
71-
assert(metrics.precision(1.0) ~== precision1 absTol delta)
72-
assert(metrics.precision(2.0) ~== precision2 absTol delta)
73-
assert(metrics.recall(0.0) ~== recall0 absTol delta)
74-
assert(metrics.recall(1.0) ~== recall1 absTol delta)
75-
assert(metrics.recall(2.0) ~== recall2 absTol delta)
76-
assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta)
77-
assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta)
78-
assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta)
79-
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta)
80-
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta)
81-
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta)
61+
assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
62+
assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
63+
assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
64+
assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
65+
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
66+
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
67+
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
68+
assert(metrics.precision(0.0) ~== precision0 relTol delta)
69+
assert(metrics.precision(1.0) ~== precision1 relTol delta)
70+
assert(metrics.precision(2.0) ~== precision2 relTol delta)
71+
assert(metrics.recall(0.0) ~== recall0 relTol delta)
72+
assert(metrics.recall(1.0) ~== recall1 relTol delta)
73+
assert(metrics.recall(2.0) ~== recall2 relTol delta)
74+
assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
75+
assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
76+
assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
77+
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
78+
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
79+
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)
8280

8381
assert(metrics.accuracy ~==
84-
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) absTol delta)
85-
assert(metrics.accuracy ~== metrics.precision absTol delta)
86-
assert(metrics.accuracy ~== metrics.recall absTol delta)
87-
assert(metrics.accuracy ~== metrics.fMeasure absTol delta)
88-
assert(metrics.accuracy ~== metrics.weightedRecall absTol delta)
82+
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta)
83+
assert(metrics.accuracy ~== metrics.precision relTol delta)
84+
assert(metrics.accuracy ~== metrics.recall relTol delta)
85+
assert(metrics.accuracy ~== metrics.fMeasure relTol delta)
86+
assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
8987
val weight0 = 4.0 / 9
9088
val weight1 = 4.0 / 9
9189
val weight2 = 1.0 / 9
9290
assert(metrics.weightedTruePositiveRate ~==
93-
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta)
91+
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
9492
assert(metrics.weightedFalsePositiveRate ~==
95-
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta)
93+
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
9694
assert(metrics.weightedPrecision ~==
97-
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta)
95+
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
9896
assert(metrics.weightedRecall ~==
99-
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta)
97+
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
10098
assert(metrics.weightedFMeasure ~==
101-
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta)
99+
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
102100
assert(metrics.weightedFMeasure(2.0) ~==
103-
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta)
101+
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
104102
assert(metrics.labels === labels)
105103
}
106104

@@ -141,47 +139,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
141139
val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
142140
val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)
143141

144-
assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta)
145-
assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta)
146-
assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta)
147-
assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta)
148-
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta)
149-
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta)
150-
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta)
151-
assert(metrics.precision(0.0) ~== precision0 absTol delta)
152-
assert(metrics.precision(1.0) ~== precision1 absTol delta)
153-
assert(metrics.precision(2.0) ~== precision2 absTol delta)
154-
assert(metrics.recall(0.0) ~== recall0 absTol delta)
155-
assert(metrics.recall(1.0) ~== recall1 absTol delta)
156-
assert(metrics.recall(2.0) ~== recall2 absTol delta)
157-
assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta)
158-
assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta)
159-
assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta)
160-
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta)
161-
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta)
162-
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta)
142+
assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta)
143+
assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta)
144+
assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta)
145+
assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta)
146+
assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta)
147+
assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta)
148+
assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta)
149+
assert(metrics.precision(0.0) ~== precision0 relTol delta)
150+
assert(metrics.precision(1.0) ~== precision1 relTol delta)
151+
assert(metrics.precision(2.0) ~== precision2 relTol delta)
152+
assert(metrics.recall(0.0) ~== recall0 relTol delta)
153+
assert(metrics.recall(1.0) ~== recall1 relTol delta)
154+
assert(metrics.recall(2.0) ~== recall2 relTol delta)
155+
assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta)
156+
assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta)
157+
assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta)
158+
assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta)
159+
assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta)
160+
assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta)
163161

164162
assert(metrics.accuracy ~==
165-
(2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw absTol delta)
166-
assert(metrics.accuracy ~== metrics.precision absTol delta)
167-
assert(metrics.accuracy ~== metrics.recall absTol delta)
168-
assert(metrics.accuracy ~== metrics.fMeasure absTol delta)
169-
assert(metrics.accuracy ~== metrics.weightedRecall absTol delta)
163+
(2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta)
164+
assert(metrics.accuracy ~== metrics.precision relTol delta)
165+
assert(metrics.accuracy ~== metrics.recall relTol delta)
166+
assert(metrics.accuracy ~== metrics.fMeasure relTol delta)
167+
assert(metrics.accuracy ~== metrics.weightedRecall relTol delta)
170168
val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw
171169
val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw
172170
val weight2 = 1 * w2 / tw
173171
assert(metrics.weightedTruePositiveRate ~==
174-
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta)
172+
(weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta)
175173
assert(metrics.weightedFalsePositiveRate ~==
176-
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta)
174+
(weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta)
177175
assert(metrics.weightedPrecision ~==
178-
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta)
176+
(weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta)
179177
assert(metrics.weightedRecall ~==
180-
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta)
178+
(weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta)
181179
assert(metrics.weightedFMeasure ~==
182-
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta)
180+
(weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta)
183181
assert(metrics.weightedFMeasure(2.0) ~==
184-
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta)
182+
(weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta)
185183
assert(metrics.labels === labels)
186184
}
187185
}

0 commit comments

Comments
 (0)