@@ -23,18 +23,21 @@ import org.apache.spark.SparkContext._
2323
2424/**
2525 * Evaluator for multiclass classification.
26+ * NB: type Double both for prediction and label is retained
27+ * for compatibility with model.predict that returns Double
28+ * and MLUtils.loadLibSVMFile that loads class labels as Double
2629 *
27- * @param scoreAndLabels an RDD of (score , label) pairs.
30+ * @param predictionsAndLabels an RDD of (prediction , label) pairs.
2831 */
29- class MulticlassMetrics (scoreAndLabels : RDD [(Double , Double )]) extends Logging {
32+ class MulticlassMetrics (predictionsAndLabels : RDD [(Double , Double )]) extends Logging {
3033
3134 /* class = category; label = instance of class; prediction = instance of class */
3235
33- private lazy val labelCountByClass = scoreAndLabels .values.countByValue()
36+ private lazy val labelCountByClass = predictionsAndLabels .values.countByValue()
3437 private lazy val labelCount = labelCountByClass.foldLeft(0L ){case (sum, (_, count)) => sum + count}
35- private lazy val tpByClass = scoreAndLabels .map{ case (prediction, label) =>
38+ private lazy val tpByClass = predictionsAndLabels .map{ case (prediction, label) =>
3639 (label, if (label == prediction) 1 else 0 ) }.reduceByKey{_ + _}.collectAsMap
37- private lazy val fpByClass = scoreAndLabels .map{ case (prediction, label) =>
40+ private lazy val fpByClass = predictionsAndLabels .map{ case (prediction, label) =>
3841 (prediction, if (prediction != label) 1 else 0 ) }.reduceByKey{_ + _}.collectAsMap
3942
4043 /**
0 commit comments