@@ -36,15 +36,21 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
3636 (1.0 , 1.0 ), (1.0 , 1.0 ), (2.0 , 2.0 ), (2.0 , 0.0 )), 2 )
3737 val metrics = new MulticlassMetrics (scoreAndLabels)
3838 val delta = 0.0000001
39- val precision0 = 2.0 / (2.0 + 1.0 )
40- val precision1 = 3.0 / (3.0 + 1.0 )
41- val precision2 = 1.0 / (1.0 + 1.0 )
42- val recall0 = 2.0 / (2.0 + 2.0 )
43- val recall1 = 3.0 / (3.0 + 1.0 )
44- val recall2 = 1.0 / (1.0 + 0.0 )
39+ val fpRate0 = 1.0 / (9 - 4 )
40+ val fpRate1 = 1.0 / (9 - 4 )
41+ val fpRate2 = 1.0 / (9 - 1 )
42+ val precision0 = 2.0 / (2 + 1 )
43+ val precision1 = 3.0 / (3 + 1 )
44+ val precision2 = 1.0 / (1 + 1 )
45+ val recall0 = 2.0 / (2 + 2 )
46+ val recall1 = 3.0 / (3 + 1 )
47+ val recall2 = 1.0 / (1 + 0 )
4548 val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
4649 val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
4750 val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
51+ assert(math.abs(metrics.falsePositiveRate(0.0 ) - fpRate0) < delta)
52+ assert(math.abs(metrics.falsePositiveRate(1.0 ) - fpRate1) < delta)
53+ assert(math.abs(metrics.falsePositiveRate(2.0 ) - fpRate2) < delta)
4854 assert(math.abs(metrics.precision(0.0 ) - precision0) < delta)
4955 assert(math.abs(metrics.precision(1.0 ) - precision1) < delta)
5056 assert(math.abs(metrics.precision(2.0 ) - precision2) < delta)
@@ -55,16 +61,16 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
5561 assert(math.abs(metrics.fMeasure(1.0 ) - f1measure1) < delta)
5662 assert(math.abs(metrics.fMeasure(2.0 ) - f1measure2) < delta)
5763 assert(math.abs(metrics.recall -
58- (2.0 + 3.0 + 1.0 ) / ((2.0 + 3.0 + 1.0 ) + (1.0 + 1.0 + 1.0 ))) < delta)
64+ (2.0 + 3.0 + 1.0 ) / ((2 + 3 + 1 ) + (1 + 1 + 1 ))) < delta)
5965 assert(math.abs(metrics.recall - metrics.precision) < delta)
6066 assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
6167 assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
6268 assert(math.abs(metrics.weightedPrecision -
63- ((4.0 / 9.0 ) * precision0 + (4.0 / 9.0 ) * precision1 + (1.0 / 9.0 ) * precision2)) < delta)
69+ ((4.0 / 9 ) * precision0 + (4.0 / 9 ) * precision1 + (1.0 / 9 ) * precision2)) < delta)
6470 assert(math.abs(metrics.weightedRecall -
65- ((4.0 / 9.0 ) * recall0 + (4.0 / 9.0 ) * recall1 + (1.0 / 9.0 ) * recall2)) < delta)
66- assert(math.abs(metrics.weightedF1Measure -
67- ((4.0 / 9.0 ) * f1measure0 + (4.0 / 9.0 ) * f1measure1 + (1.0 / 9.0 ) * f1measure2)) < delta)
71+ ((4.0 / 9 ) * recall0 + (4.0 / 9 ) * recall1 + (1.0 / 9 ) * recall2)) < delta)
72+ assert(math.abs(metrics.weightedFMeasure -
73+ ((4.0 / 9 ) * f1measure0 + (4.0 / 9 ) * f1measure1 + (1.0 / 9 ) * f1measure2)) < delta)
6874 assert(metrics.labels.sameElements(labels))
6975 }
7076}
0 commit comments