From b5c9765fbaaeead6c6c121c49399b327174dfae9 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 27 Feb 2017 13:28:08 -0500 Subject: [PATCH 01/10] Added weight column for multiclass classification evaluator --- .../MulticlassClassificationEvaluator.scala | 22 +++-- .../mllib/evaluation/MulticlassMetrics.scala | 51 +++++++---- .../evaluation/MulticlassMetricsSuite.scala | 91 +++++++++++++++++++ 3 files changed, 139 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d88..b31962936269 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("2.2.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { + case Row(prediction: Double, label: Double) => (prediction, label) + }.values.countByValue() + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 980e0c92531a..12987ff3de1c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,11 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_]) { /** * An auxiliary constructor taking a DataFrame. @@ -39,21 +40,33 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double) => + (label, 1.0) + case (prediction: Double, label: Double, weight: Double) => + (label, weight) + }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + (label, if (label == prediction) 1.0 else 0.0) + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + (prediction, if (prediction != label) 1.0 else 0.0) + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + ((label, prediction), 1.0) + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +84,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +105,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +116,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +125,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -140,7 +153,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) */ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 5394baab94bc..ca0373436ea4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -92,4 +92,95 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) assert(metrics.labels.sameElements(labels)) } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val delta = 0.0000001 + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) + assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) + assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) + assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) + assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) + assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) + assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + + assert(math.abs(metrics.accuracy - + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw) < delta) + assert(math.abs(metrics.accuracy - metrics.precision) < delta) + assert(math.abs(metrics.accuracy - metrics.recall) < delta) + assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) + assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) + assert(math.abs(metrics.weightedTruePositiveRate - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * tpRate0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * tpRate1 + + (1 * w2 / tw) * tpRate2)) < delta) + assert(math.abs(metrics.weightedFalsePositiveRate - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * fpRate0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * fpRate1 + + (1 * w2 / tw) * fpRate2)) < delta) + assert(math.abs(metrics.weightedPrecision - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * precision0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * precision1 + + (1 * w2 / tw) * precision2)) < delta) + assert(math.abs(metrics.weightedRecall - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * recall0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * recall1 + + (1 * w2 / tw) * recall2)) < delta) + assert(math.abs(metrics.weightedFMeasure - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f1measure0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f1measure1 + + (1 * w2 / tw) * f1measure2)) < delta) + assert(math.abs(metrics.weightedFMeasure(2.0) - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f2measure0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f2measure1 + + (1 * w2 / tw) * f2measure2)) < delta) + assert(metrics.labels.sameElements(labels)) + } } From ef9440fa55b7738a62f872c553926fa1f2432a3c Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 17 Apr 2018 23:52:29 -0400 Subject: [PATCH 02/10] Updating based on comments, fixed since tag, removed useless code, fixed constructor --- .../MulticlassClassificationEvaluator.scala | 5 +-- .../mllib/evaluation/MulticlassMetrics.scala | 39 +++++++++---------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index b31962936269..f678220bceaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -69,7 +69,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ - @Since("2.2.0") + @Since("2.4.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(metricName -> "f1") @@ -86,9 +86,6 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid .rdd.map { case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) - }.values.countByValue() val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 12987ff3de1c..24a78ca66ccb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,11 +27,13 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or - * (prediction, label) pairs. + * @param predLabelsWeight an RDD of (prediction, label, weight). */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_]) { +class MulticlassMetrics @Since("2.4.0") (predLabelsWeight: RDD[(Double, Double, Double)]) { + @Since("1.1.0") + def this(predAndLabels: RDD[(Double, Double)]) = + this(predAndLabels.map(r => (r._1, r._2, 1.0))) /** * An auxiliary constructor taking a DataFrame. @@ -41,32 +43,27 @@ class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_]) { this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) private lazy val labelCountByClass: Map[Double, Double] = - predAndLabelsWithOptWeight.map { - case (prediction: Double, label: Double) => - (label, 1.0) + predLabelsWeight.map { case (prediction: Double, label: Double, weight: Double) => (label, weight) }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap private lazy val labelCount: Double = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Double] = predAndLabelsWithOptWeight - .map { case (prediction: Double, label: Double) => - (label, if (label == prediction) 1.0 else 0.0) - case (prediction: Double, label: Double, weight: Double) => - (label, if (label == prediction) weight else 0.0) + private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Double] = predAndLabelsWithOptWeight - .map { case (prediction: Double, label: Double) => - (prediction, if (prediction != label) 1.0 else 0.0) - case (prediction: Double, label: Double, weight: Double) => - (prediction, if (prediction != label) weight else 0.0) + private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predAndLabelsWithOptWeight - .map { case (prediction: Double, label: Double) => - ((label, prediction), 1.0) - case (prediction: Double, label: Double, weight: Double) => - ((label, prediction), weight) + private lazy val confusions = predLabelsWeight + .map { + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() From f181eb0cc8b848ede8b929ae8b4c7f1025142dd4 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 19 Apr 2018 00:43:42 -0400 Subject: [PATCH 03/10] updated based on comment, fixed build failure --- .../spark/mllib/evaluation/MulticlassMetrics.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 24a78ca66ccb..cbe80981f3de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,13 +27,17 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predLabelsWeight an RDD of (prediction, label, weight). + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("2.4.0") (predLabelsWeight: RDD[(Double, Double, Double)]) { - @Since("1.1.0") - def this(predAndLabels: RDD[(Double, Double)]) = - this(predAndLabels.map(r => (r._1, r._2, 1.0))) +class MulticlassMetrics @Since("2.4.0") (predAndLabelsWithOptWeight: RDD[_]) { + val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) + case (prediction: Double, label: Double) => + (prediction, label, 1.0) + } /** * An auxiliary constructor taking a DataFrame. From fcf333ee1f2c2e64963be6371c38e3be2d8e355e Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 26 Apr 2018 11:25:44 -0400 Subject: [PATCH 04/10] updated based on comments --- .../mllib/evaluation/MulticlassMetrics.scala | 2 +- .../evaluation/MulticlassMetricsSuite.scala | 172 +++++++++--------- 2 files changed, 86 insertions(+), 88 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index cbe80981f3de..f8a0d7e4918d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -50,7 +50,7 @@ class MulticlassMetrics @Since("2.4.0") (predAndLabelsWithOptWeight: RDD[_]) { predLabelsWeight.map { case (prediction: Double, label: Double, weight: Double) => (label, weight) - }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap + }.reduceByKey(_ + _).collect().toMap private lazy val labelCount: Double = labelCountByClass.values.sum private lazy val tpByClass: Map[Double, Double] = predLabelsWeight .map { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index ca0373436ea4..dd71c75da48f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,10 +18,16 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + + import testImplicits._ + + val delta = 1e-7 + test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: @@ -35,7 +41,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) val metrics = new MulticlassMetrics(predictionAndLabels) - val delta = 0.0000001 val tpRate0 = 2.0 / (2 + 2) val tpRate1 = 3.0 / (3 + 1) val tpRate2 = 1.0 / (1 + 0) @@ -55,42 +60,45 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) + assert(metrics.precision(0.0) ~== precision0 absTol delta) + assert(metrics.precision(1.0) ~== precision1 absTol delta) + assert(metrics.precision(2.0) ~== precision2 absTol delta) + assert(metrics.recall(0.0) ~== recall0 absTol delta) + assert(metrics.recall(1.0) ~== recall1 absTol delta) + assert(metrics.recall(2.0) ~== recall2 absTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) - assert(math.abs(metrics.accuracy - - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - ((4.0 / 9) * tpRate0 + (4.0 / 9) * tpRate1 + (1.0 / 9) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) absTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + val weight0 = 4.0 / 9 + val weight1 = 4.0 / 9 + val weight2 = 1.0 / 9 + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) + assert(metrics.labels === labels) } test("Multiclass evaluation metrics with weights") { @@ -111,7 +119,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), (2.0, 0.0, w1)), 2) val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) - val delta = 0.0000001 val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) @@ -131,56 +138,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) - assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) - assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) - assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) - assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) - assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) - assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) - assert(math.abs(metrics.precision(0.0) - precision0) < delta) - assert(math.abs(metrics.precision(1.0) - precision1) < delta) - assert(math.abs(metrics.precision(2.0) - precision2) < delta) - assert(math.abs(metrics.recall(0.0) - recall0) < delta) - assert(math.abs(metrics.recall(1.0) - recall1) < delta) - assert(math.abs(metrics.recall(2.0) - recall2) < delta) - assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) - assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) - assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) - assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) + assert(metrics.precision(0.0) ~== precision0 absTol delta) + assert(metrics.precision(1.0) ~== precision1 absTol delta) + assert(metrics.precision(2.0) ~== precision2 absTol delta) + assert(metrics.recall(0.0) ~== recall0 absTol delta) + assert(metrics.recall(1.0) ~== recall1 absTol delta) + assert(metrics.recall(2.0) ~== recall2 absTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) - assert(math.abs(metrics.accuracy - - (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw) < delta) - assert(math.abs(metrics.accuracy - metrics.precision) < delta) - assert(math.abs(metrics.accuracy - metrics.recall) < delta) - assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) - assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) - assert(math.abs(metrics.weightedTruePositiveRate - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * tpRate0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * tpRate1 + - (1 * w2 / tw) * tpRate2)) < delta) - assert(math.abs(metrics.weightedFalsePositiveRate - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * fpRate0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * fpRate1 + - (1 * w2 / tw) * fpRate2)) < delta) - assert(math.abs(metrics.weightedPrecision - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * precision0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * precision1 + - (1 * w2 / tw) * precision2)) < delta) - assert(math.abs(metrics.weightedRecall - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * recall0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * recall1 + - (1 * w2 / tw) * recall2)) < delta) - assert(math.abs(metrics.weightedFMeasure - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f1measure0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f1measure1 + - (1 * w2 / tw) * f1measure2)) < delta) - assert(math.abs(metrics.weightedFMeasure(2.0) - - (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f2measure0 + - ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f2measure1 + - (1 * w2 / tw) * f2measure2)) < delta) - assert(metrics.labels.sameElements(labels)) + assert(metrics.accuracy ~== + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw absTol delta) + assert(metrics.accuracy ~== metrics.precision absTol delta) + assert(metrics.accuracy ~== metrics.recall absTol delta) + assert(metrics.accuracy ~== metrics.fMeasure absTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw + val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw + val weight2 = 1 * w2 / tw + assert(metrics.weightedTruePositiveRate ~== + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + assert(metrics.weightedFalsePositiveRate ~== + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + assert(metrics.weightedPrecision ~== + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + assert(metrics.weightedRecall ~== + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + assert(metrics.weightedFMeasure ~== + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + assert(metrics.weightedFMeasure(2.0) ~== + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) + assert(metrics.labels === labels) } } From a95e33f1eea9d25675990508ae9bcd390a8bdde4 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 1 May 2018 00:14:59 -0400 Subject: [PATCH 05/10] updated based on latest comments --- .../evaluation/MulticlassMetricsSuite.scala | 118 +++++++++--------- 1 file changed, 58 insertions(+), 60 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index dd71c75da48f..4bc433e72c61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -18,14 +18,12 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Matrices import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - import testImplicits._ - val delta = 1e-7 test("Multiclass evaluation metrics") { @@ -60,44 +58,44 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) - assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) - assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) - assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) - assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) - assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) - assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) - assert(metrics.precision(0.0) ~== precision0 absTol delta) - assert(metrics.precision(1.0) ~== precision1 absTol delta) - assert(metrics.precision(2.0) ~== precision2 absTol delta) - assert(metrics.recall(0.0) ~== recall0 absTol delta) - assert(metrics.recall(1.0) ~== recall1 absTol delta) - assert(metrics.recall(2.0) ~== recall2 absTol delta) - assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) - assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) - assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) - assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) - assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) - assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) assert(metrics.accuracy ~== - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) absTol delta) - assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1)) relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) val weight0 = 4.0 / 9 val weight1 = 4.0 / 9 val weight2 = 1.0 / 9 assert(metrics.weightedTruePositiveRate ~== - (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) assert(metrics.weightedFalsePositiveRate ~== - (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) assert(metrics.weightedPrecision ~== - (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) assert(metrics.weightedRecall ~== - (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) assert(metrics.weightedFMeasure ~== - (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) assert(metrics.weightedFMeasure(2.0) ~== - (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) assert(metrics.labels === labels) } @@ -138,47 +136,47 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) - assert(metrics.confusionMatrix.asML ~== confusionMatrix.asML relTol delta) - assert(metrics.truePositiveRate(0.0) ~== tpRate0 absTol delta) - assert(metrics.truePositiveRate(1.0) ~== tpRate1 absTol delta) - assert(metrics.truePositiveRate(2.0) ~== tpRate2 absTol delta) - assert(metrics.falsePositiveRate(0.0) ~== fpRate0 absTol delta) - assert(metrics.falsePositiveRate(1.0) ~== fpRate1 absTol delta) - assert(metrics.falsePositiveRate(2.0) ~== fpRate2 absTol delta) - assert(metrics.precision(0.0) ~== precision0 absTol delta) - assert(metrics.precision(1.0) ~== precision1 absTol delta) - assert(metrics.precision(2.0) ~== precision2 absTol delta) - assert(metrics.recall(0.0) ~== recall0 absTol delta) - assert(metrics.recall(1.0) ~== recall1 absTol delta) - assert(metrics.recall(2.0) ~== recall2 absTol delta) - assert(metrics.fMeasure(0.0) ~== f1measure0 absTol delta) - assert(metrics.fMeasure(1.0) ~== f1measure1 absTol delta) - assert(metrics.fMeasure(2.0) ~== f1measure2 absTol delta) - assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 absTol delta) - assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 absTol delta) - assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 absTol delta) + assert(metrics.confusionMatrix.asML ~== confusionMatrix relTol delta) + assert(metrics.truePositiveRate(0.0) ~== tpRate0 relTol delta) + assert(metrics.truePositiveRate(1.0) ~== tpRate1 relTol delta) + assert(metrics.truePositiveRate(2.0) ~== tpRate2 relTol delta) + assert(metrics.falsePositiveRate(0.0) ~== fpRate0 relTol delta) + assert(metrics.falsePositiveRate(1.0) ~== fpRate1 relTol delta) + assert(metrics.falsePositiveRate(2.0) ~== fpRate2 relTol delta) + assert(metrics.precision(0.0) ~== precision0 relTol delta) + assert(metrics.precision(1.0) ~== precision1 relTol delta) + assert(metrics.precision(2.0) ~== precision2 relTol delta) + assert(metrics.recall(0.0) ~== recall0 relTol delta) + assert(metrics.recall(1.0) ~== recall1 relTol delta) + assert(metrics.recall(2.0) ~== recall2 relTol delta) + assert(metrics.fMeasure(0.0) ~== f1measure0 relTol delta) + assert(metrics.fMeasure(1.0) ~== f1measure1 relTol delta) + assert(metrics.fMeasure(2.0) ~== f1measure2 relTol delta) + assert(metrics.fMeasure(0.0, 2.0) ~== f2measure0 relTol delta) + assert(metrics.fMeasure(1.0, 2.0) ~== f2measure1 relTol delta) + assert(metrics.fMeasure(2.0, 2.0) ~== f2measure2 relTol delta) assert(metrics.accuracy ~== - (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw absTol delta) - assert(metrics.accuracy ~== metrics.precision absTol delta) - assert(metrics.accuracy ~== metrics.recall absTol delta) - assert(metrics.accuracy ~== metrics.fMeasure absTol delta) - assert(metrics.accuracy ~== metrics.weightedRecall absTol delta) + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) + assert(metrics.accuracy ~== metrics.precision relTol delta) + assert(metrics.accuracy ~== metrics.recall relTol delta) + assert(metrics.accuracy ~== metrics.fMeasure relTol delta) + assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw val weight2 = 1 * w2 / tw assert(metrics.weightedTruePositiveRate ~== - (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) absTol delta) + (weight0 * tpRate0 + weight1 * tpRate1 + weight2 * tpRate2) relTol delta) assert(metrics.weightedFalsePositiveRate ~== - (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) absTol delta) + (weight0 * fpRate0 + weight1 * fpRate1 + weight2 * fpRate2) relTol delta) assert(metrics.weightedPrecision ~== - (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) absTol delta) + (weight0 * precision0 + weight1 * precision1 + weight2 * precision2) relTol delta) assert(metrics.weightedRecall ~== - (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) absTol delta) + (weight0 * recall0 + weight1 * recall1 + weight2 * recall2) relTol delta) assert(metrics.weightedFMeasure ~== - (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) absTol delta) + (weight0 * f1measure0 + weight1 * f1measure1 + weight2 * f1measure2) relTol delta) assert(metrics.weightedFMeasure(2.0) ~== - (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) absTol delta) + (weight0 * f2measure0 + weight1 * f2measure1 + weight2 * f2measure2) relTol delta) assert(metrics.labels === labels) } } From 32734a020b037baa5c7b5f010d084e2d79711876 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Mon, 5 Nov 2018 22:52:38 -0500 Subject: [PATCH 06/10] updated based on comments --- .../spark/ml/evaluation/MulticlassClassificationEvaluator.scala | 2 +- .../org/apache/spark/mllib/evaluation/MulticlassMetrics.scala | 2 +- .../apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index f678220bceaf..f1602c1bc533 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -69,7 +69,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid def setLabelCol(value: String): this.type = set(labelCol, value) /** @group setParam */ - @Since("2.4.0") + @Since("3.0.0") def setWeightCol(value: String): this.type = set(weightCol, value) setDefault(metricName -> "f1") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index f8a0d7e4918d..9dd2b0611973 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.DataFrame * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("2.4.0") (predAndLabelsWithOptWeight: RDD[_]) { +class MulticlassMetrics @Since("3.0.0") (predAndLabelsWithOptWeight: RDD[_]) { val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { case (prediction: Double, label: Double, weight: Double) => (prediction, label, weight) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 4bc433e72c61..96d14a6764d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - val delta = 1e-7 + private val delta = 1e-7 test("Multiclass evaluation metrics") { /* From aff0f51138ce067f3b36a6653c3491d7a678e4df Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 6 Nov 2018 15:08:03 -0500 Subject: [PATCH 07/10] updated based on new comments --- .../mllib/evaluation/MulticlassMetrics.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 9dd2b0611973..0d8603db96fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -48,20 +48,23 @@ class MulticlassMetrics @Since("3.0.0") (predAndLabelsWithOptWeight: RDD[_]) { private lazy val labelCountByClass: Map[Double, Double] = predLabelsWeight.map { - case (prediction: Double, label: Double, weight: Double) => + case (_: Double, label: Double, weight: Double) => (label, weight) - }.reduceByKey(_ + _).collect().toMap + }.reduceByKey(_ + _) + .collectAsMap() private lazy val labelCount: Double = labelCountByClass.values.sum private lazy val tpByClass: Map[Double, Double] = predLabelsWeight + .filter(predLabelWeight => predLabelWeight._1 == predLabelWeight._2) .map { - case (prediction: Double, label: Double, weight: Double) => - (label, if (label == prediction) weight else 0.0) + case (_: Double, label: Double, weight: Double) => + (label, weight) }.reduceByKey(_ + _) .collectAsMap() private lazy val fpByClass: Map[Double, Double] = predLabelsWeight + .filter(predLabelWeight => predLabelWeight._1 != predLabelWeight._2) .map { - case (prediction: Double, label: Double, weight: Double) => - (prediction, if (prediction != label) weight else 0.0) + case (prediction: Double, _: Double, weight: Double) => + (prediction, weight) }.reduceByKey(_ + _) .collectAsMap() private lazy val confusions = predLabelsWeight From 07382f09c68807bc7b2705e415be115516ddeeb7 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 6 Nov 2018 16:16:44 -0500 Subject: [PATCH 08/10] undid filter --- .../spark/mllib/evaluation/MulticlassMetrics.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 0d8603db96fd..db680fbba5e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -54,17 +54,15 @@ class MulticlassMetrics @Since("3.0.0") (predAndLabelsWithOptWeight: RDD[_]) { .collectAsMap() private lazy val labelCount: Double = labelCountByClass.values.sum private lazy val tpByClass: Map[Double, Double] = predLabelsWeight - .filter(predLabelWeight => predLabelWeight._1 == predLabelWeight._2) .map { - case (_: Double, label: Double, weight: Double) => - (label, weight) + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() private lazy val fpByClass: Map[Double, Double] = predLabelsWeight - .filter(predLabelWeight => predLabelWeight._1 != predLabelWeight._2) .map { - case (prediction: Double, _: Double, weight: Double) => - (prediction, weight) + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() private lazy val confusions = predLabelsWeight From d54cc555d61a65e95c271776b85fe5d55795dc1a Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 7 Nov 2018 19:03:26 -0500 Subject: [PATCH 09/10] reverted version back, added constraint and validation --- .../org/apache/spark/mllib/evaluation/MulticlassMetrics.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index db680fbba5e7..ad83c24ede96 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -31,12 +31,14 @@ import org.apache.spark.sql.DataFrame * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("3.0.0") (predAndLabelsWithOptWeight: RDD[_]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) { val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map { case (prediction: Double, label: Double, weight: Double) => (prediction, label, weight) case (prediction: Double, label: Double) => (prediction, label, 1.0) + case other => + throw new IllegalArgumentException(s"Expected tuples, got $other") } /** From 50864497d013ba7f8a160d5142b0cfdd41f00f8d Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 8 Nov 2018 14:03:29 -0500 Subject: [PATCH 10/10] merge with latest --- .../apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 96d14a6764d6..8779de590a25 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -158,9 +158,6 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(metrics.accuracy ~== (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw relTol delta) - assert(metrics.accuracy ~== metrics.precision relTol delta) - assert(metrics.accuracy ~== metrics.recall relTol delta) - assert(metrics.accuracy ~== metrics.fMeasure relTol delta) assert(metrics.accuracy ~== metrics.weightedRecall relTol delta) val weight0 = (2 * w1 + 1 * w2 + 1 * w1) / tw val weight1 = (1 * w2 + 2 * w1 + 1 * w2) / tw