|
17 | 17 |
|
18 | 18 | package org.apache.spark.mllib.evaluation |
19 | 19 |
|
| 20 | +import org.apache.spark.annotation.Experimental |
20 | 21 | import org.apache.spark.rdd.RDD |
21 | 22 | import org.apache.spark.Logging |
22 | 23 | import org.apache.spark.SparkContext._ |
23 | 24 |
|
24 | 25 | /** |
25 | 26 | * Evaluator for multiclass classification. |
26 | | - * NB: type Double both for prediction and label is retained |
27 | | - * for compatibility with model.predict that returns Double |
28 | | - * and MLUtils.loadLibSVMFile that loads class labels as Double |
29 | 27 | * |
30 | 28 | * @param predictionsAndLabels an RDD of (prediction, label) pairs. |
31 | 29 | */ |
| 30 | +@Experimental |
32 | 31 | class MulticlassMetrics(predictionsAndLabels: RDD[(Double, Double)]) extends Logging { |
33 | 32 |
|
34 | | - /* class = category; label = instance of class; prediction = instance of class */ |
35 | | - |
36 | 33 | private lazy val labelCountByClass = predictionsAndLabels.values.countByValue() |
37 | | - private lazy val labelCount = labelCountByClass.foldLeft(0L){case(sum, (_, count)) => sum + count} |
38 | | - private lazy val tpByClass = predictionsAndLabels.map{ case (prediction, label) => |
39 | | - (label, if(label == prediction) 1 else 0) }.reduceByKey{_ + _}.collectAsMap |
40 | | - private lazy val fpByClass = predictionsAndLabels.map{ case (prediction, label) => |
41 | | - (prediction, if(prediction != label) 1 else 0) }.reduceByKey{_ + _}.collectAsMap |
| 34 | + private lazy val labelCount = labelCountByClass.values.sum |
| 35 | + private lazy val tpByClass = predictionsAndLabels |
| 36 | + .map{ case (prediction, label) => |
| 37 | + (label, if (label == prediction) 1 else 0) |
| 38 | + }.reduceByKey(_ + _) |
| 39 | + .collectAsMap() |
| 40 | + private lazy val fpByClass = predictionsAndLabels |
| 41 | + .map{ case (prediction, label) => |
| 42 | + (prediction, if (prediction != label) 1 else 0) |
| 43 | + }.reduceByKey(_ + _) |
| 44 | + .collectAsMap() |
42 | 45 |
|
43 | 46 | /** |
44 | | - * Returns Precision for a given label (category) |
| 47 | + * Returns precision for a given label (category) |
45 | 48 | * @param label the label. |
46 | | - * @return Precision. |
47 | 49 | */ |
48 | | - def precision(label: Double): Double = if(tpByClass(label) + fpByClass.getOrElse(label, 0) == 0) 0 |
49 | | - else tpByClass(label).toDouble / (tpByClass(label) + fpByClass.getOrElse(label, 0)).toDouble |
| 50 | + def precision(label: Double): Double = { |
| 51 | + val tp = tpByClass(label) |
| 52 | + val fp = fpByClass.getOrElse(label, 0) |
| 53 | + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) |
| 54 | + } |
50 | 55 |
|
51 | 56 | /** |
52 | | - * Returns Recall for a given label (category) |
| 57 | + * Returns recall for a given label (category) |
53 | 58 | * @param label the label. |
54 | | - * @return Recall. |
55 | 59 | */ |
56 | | - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label).toDouble |
| 60 | + def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) |
57 | 61 |
|
58 | 62 | /** |
59 | | - * Returns F1-measure for a given label (category) |
| 63 | + * Returns f-measure for a given label (category) |
60 | 64 | * @param label the label. |
61 | | - * @return F1-measure. |
62 | 65 | */ |
63 | | - def f1Measure(label: Double): Double ={ |
| 66 | + def fMeasure(label: Double, beta:Double = 1.0): Double = { |
64 | 67 | val p = precision(label) |
65 | 68 | val r = recall(label) |
66 | | - if((p + r) == 0) 0 else 2 * p * r / (p + r) |
| 69 | + val betaSqrd = beta * beta |
| 70 | + if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r) |
67 | 71 | } |
68 | 72 |
|
69 | 73 | /** |
70 | | - * Returns micro-averaged Recall |
| 74 | + * Returns micro-averaged recall |
71 | 75 | * (equals to microPrecision and microF1measure for multiclass classifier) |
72 | | - * @return microRecall. |
73 | 76 | */ |
74 | | - lazy val microRecall: Double = |
75 | | - tpByClass.foldLeft(0L){case (sum,(_, tp)) => sum + tp}.toDouble / labelCount |
| 77 | + lazy val recall: Double = |
| 78 | + tpByClass.values.sum.toDouble / labelCount |
76 | 79 |
|
77 | 80 | /** |
78 | | - * Returns micro-averaged Precision |
| 81 | + * Returns micro-averaged precision |
79 | 82 | * (equals to microPrecision and microF1measure for multiclass classifier) |
80 | | - * @return microPrecision. |
81 | 83 | */ |
82 | | - lazy val microPrecision: Double = microRecall |
| 84 | + lazy val precision: Double = recall |
83 | 85 |
|
84 | 86 | /** |
85 | | - * Returns micro-averaged F1-measure |
| 87 | + * Returns micro-averaged f-measure |
86 | 88 | * (equals to microPrecision and microRecall for multiclass classifier) |
87 | | - * @return microF1measure. |
88 | 89 | */ |
89 | | - lazy val microF1Measure: Double = microRecall |
| 90 | + lazy val fMeasure: Double = recall |
90 | 91 |
|
91 | 92 | /** |
92 | | - * Returns weighted averaged Recall |
93 | | - * @return weightedRecall. |
| 93 | + * Returns weighted averaged recall |
| 94 | + * (equals to micro-averaged precision, recall and f-measure) |
94 | 95 | */ |
95 | | - lazy val weightedRecall: Double = labelCountByClass.foldLeft(0.0){case(wRecall, (category, count)) => |
96 | | - wRecall + recall(category) * count.toDouble / labelCount} |
| 96 | + lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => |
| 97 | + recall(category) * count.toDouble / labelCount |
| 98 | + }.sum |
97 | 99 |
|
98 | 100 | /** |
99 | | - * Returns weighted averaged Precision |
100 | | - * @return weightedPrecision. |
| 101 | + * Returns weighted averaged precision |
101 | 102 | */ |
102 | | - lazy val weightedPrecision: Double = |
103 | | - labelCountByClass.foldLeft(0.0){case(wPrecision, (category, count)) => |
104 | | - wPrecision + precision(category) * count.toDouble / labelCount} |
| 103 | + lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => |
| 104 | + precision(category) * count.toDouble / labelCount |
| 105 | + }.sum |
105 | 106 |
|
106 | 107 | /** |
107 | | - * Returns weighted averaged F1-measure |
108 | | - * @return weightedF1Measure. |
| 108 | + * Returns weighted averaged f1-measure |
109 | 109 | */ |
110 | | - lazy val weightedF1Measure: Double = |
111 | | - labelCountByClass.foldLeft(0.0){case(wF1measure, (category, count)) => |
112 | | - wF1measure + f1Measure(category) * count.toDouble / labelCount} |
| 110 | + lazy val weightedF1Measure: Double = labelCountByClass.map { case (category, count) => |
| 111 | + fMeasure(category) * count.toDouble / labelCount |
| 112 | + }.sum |
113 | 113 |
|
114 | 114 | /** |
115 | | - * Returns map with Precisions for individual classes |
116 | | - * @return precisionPerClass. |
| 115 | + * Returns the sequence of labels in ascending order |
117 | 116 | */ |
118 | | - lazy val precisionPerClass = |
119 | | - labelCountByClass.map{case (category, _) => (category, precision(category))}.toMap |
| 117 | + lazy val labels = tpByClass.unzip._1.toSeq.sorted |
120 | 118 |
|
121 | | - /** |
122 | | - * Returns map with Recalls for individual classes |
123 | | - * @return recallPerClass. |
124 | | - */ |
125 | | - lazy val recallPerClass = |
126 | | - labelCountByClass.map{case (category, _) => (category, recall(category))}.toMap |
127 | | - |
128 | | - /** |
129 | | - * Returns map with F1-measures for individual classes |
130 | | - * @return f1MeasurePerClass. |
131 | | - */ |
132 | | - lazy val f1MeasurePerClass = |
133 | | - labelCountByClass.map{case (category, _) => (category, f1Measure(category))}.toMap |
134 | 119 | } |
0 commit comments