|
18 | 18 | package org.apache.spark.mllib.evaluation |
19 | 19 |
|
20 | 20 | import org.apache.spark.SparkFunSuite |
| 21 | +import org.apache.spark.ml.linalg.Matrices |
21 | 22 | import org.apache.spark.ml.util.TestingUtils._ |
22 | | -import org.apache.spark.mllib.linalg.Matrices |
23 | 23 | import org.apache.spark.mllib.util.MLlibTestSparkContext |
24 | 24 |
|
25 | 25 | class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { |
26 | 26 |
|
27 | | - import testImplicits._ |
28 | | - |
29 | 27 | val delta = 1e-7 |
30 | 28 |
|
31 | 29 | test("Multiclass evaluation metrics") { |
@@ -60,47 +58,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { |
60 | 58 | val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) |
61 | 59 | val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) |
62 | 60 |
|
63 | | - assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) |
64 | | - assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) |
65 | | - assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) |
66 | | - assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) |
67 | | - assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) |
68 | | - assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) |
69 | | - assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) |
70 | | - assert(metrics.precision(0.0) ~== precision0 absTol delta) |
71 | | - assert(metrics.precision(1.0) ~== precision1 absTol delta) |
72 | | - assert(metrics.precision(2.0) ~== precision2 absTol delta) |
73 | | - assert(metrics.recall(0.0) ~== recall0 absTol delta) |
74 | | - assert(metrics.recall(1.0) ~== recall1 absTol delta) |
75 | | - assert(metrics.recall(2.0) ~== recall2 absTol delta) |
76 | | - assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) |
77 | | - assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) |
78 | | - assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) |
79 | | - assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) |
80 | | - assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) |
81 | | - assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) |
| 61 | + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) |
| 62 | + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) |
| 63 | + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) |
| 64 | + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) |
| 65 | + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) |
| 66 | + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) |
| 67 | + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) |
| 68 | + assert(metrics.precision(0.0) ~== precision0 relTol delta) |
| 69 | + assert(metrics.precision(1.0) ~== precision1 relTol delta) |
| 70 | + assert(metrics.precision(2.0) ~== precision2 relTol delta) |
| 71 | + assert(metrics.recall(0.0) ~== recall0 relTol delta) |
| 72 | + assert(metrics.recall(1.0) ~== recall1 relTol delta) |
| 73 | + assert(metrics.recall(2.0) ~== recall2 relTol delta) |
| 74 | + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) |
| 75 | + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) |
| 76 | + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) |
| 77 | + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) |
| 78 | + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) |
| 79 | + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) |
82 | 80 |
|
83 | 81 | assert(metrics.accuracy ~== |
84 | | - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) absTol delta) |
85 | | - assert(metrics.accuracy ~== metrics.precision absTol delta) |
86 | | - assert(metrics.accuracy ~== metrics.recall absTol delta) |
87 | | - assert(metrics.accuracy ~== metrics.fMeasure absTol delta) |
88 | | - assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) |
| 82 | + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) |
| 83 | + assert(metrics.accuracy ~== metrics.precision relTol delta) |
| 84 | + assert(metrics.accuracy ~== metrics.recall relTol delta) |
| 85 | + assert(metrics.accuracy ~== metrics.fMeasure relTol delta) |
| 86 | + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) |
89 | 87 | val weight0 = 4.0 / 9 |
90 | 88 | val weight1 = 4.0 / 9 |
91 | 89 | val weight2 = 1.0 / 9 |
92 | 90 | assert(metrics.weightedTruePositiveRate ~== |
93 | | - (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) |
| 91 | + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) |
94 | 92 | assert(metrics.weightedFalsePositiveRate ~== |
95 | | - (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) |
| 93 | + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) |
96 | 94 | assert(metrics.weightedPrecision ~== |
97 | | - (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) |
| 95 | + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) |
98 | 96 | assert(metrics.weightedRecall ~== |
99 | | - (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) |
| 97 | + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) |
100 | 98 | assert(metrics.weightedFMeasure ~== |
101 | | - (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) |
| 99 | + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) |
102 | 100 | assert(metrics.weightedFMeasure(2.0) ~== |
103 | | - (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) |
| 101 | + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) |
104 | 102 | assert(metrics.labels === labels) |
105 | 103 | } |
106 | 104 |
|
@@ -141,47 +139,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { |
141 | 139 | val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) |
142 | 140 | val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) |
143 | 141 |
|
144 | | - assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) |
145 | | - assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) |
146 | | - assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) |
147 | | - assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) |
148 | | - assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) |
149 | | - assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) |
150 | | - assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) |
151 | | - assert(metrics.precision(0.0) ~== precision0 absTol delta) |
152 | | - assert(metrics.precision(1.0) ~== precision1 absTol delta) |
153 | | - assert(metrics.precision(2.0) ~== precision2 absTol delta) |
154 | | - assert(metrics.recall(0.0) ~== recall0 absTol delta) |
155 | | - assert(metrics.recall(1.0) ~== recall1 absTol delta) |
156 | | - assert(metrics.recall(2.0) ~== recall2 absTol delta) |
157 | | - assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) |
158 | | - assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) |
159 | | - assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) |
160 | | - assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) |
161 | | - assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) |
162 | | - assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) |
| 142 | + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) |
| 143 | + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) |
| 144 | + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) |
| 145 | + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) |
| 146 | + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) |
| 147 | + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) |
| 148 | + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) |
| 149 | + assert(metrics.precision(0.0) ~== precision0 relTol delta) |
| 150 | + assert(metrics.precision(1.0) ~== precision1 relTol delta) |
| 151 | + assert(metrics.precision(2.0) ~== precision2 relTol delta) |
| 152 | + assert(metrics.recall(0.0) ~== recall0 relTol delta) |
| 153 | + assert(metrics.recall(1.0) ~== recall1 relTol delta) |
| 154 | + assert(metrics.recall(2.0) ~== recall2 relTol delta) |
| 155 | + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) |
| 156 | + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) |
| 157 | + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) |
| 158 | + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) |
| 159 | + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) |
| 160 | + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) |
163 | 161 |
|
164 | 162 | assert(metrics.accuracy ~== |
165 | | - (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw absTol delta) |
166 | | - assert(metrics.accuracy ~== metrics.precision absTol delta) |
167 | | - assert(metrics.accuracy ~== metrics.recall absTol delta) |
168 | | - assert(metrics.accuracy ~== metrics.fMeasure absTol delta) |
169 | | - assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) |
| 163 | + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) |
| 164 | + assert(metrics.accuracy ~== metrics.precision relTol delta) |
| 165 | + assert(metrics.accuracy ~== metrics.recall relTol delta) |
| 166 | + assert(metrics.accuracy ~== metrics.fMeasure relTol delta) |
| 167 | + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) |
170 | 168 | val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw |
171 | 169 | val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw |
172 | 170 | val weight2 = 1 * w2 / tw |
173 | 171 | assert(metrics.weightedTruePositiveRate ~== |
174 | | - (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) |
| 172 | + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) |
175 | 173 | assert(metrics.weightedFalsePositiveRate ~== |
176 | | - (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) |
| 174 | + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) |
177 | 175 | assert(metrics.weightedPrecision ~== |
178 | | - (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) |
| 176 | + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) |
179 | 177 | assert(metrics.weightedRecall ~== |
180 | | - (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) |
| 178 | + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) |
181 | 179 | assert(metrics.weightedFMeasure ~== |
182 | | - (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) |
| 180 | + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) |
183 | 181 | assert(metrics.weightedFMeasure(2.0) ~== |
184 | | - (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) |
| 182 | + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) |
185 | 183 | assert(metrics.labels === labels) |
186 | 184 | } |
187 | 185 | } |
0 commit comments