diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index bff72b20e1c3f..fc7cdd0b3042f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -36,11 +36,17 @@ import org.apache.spark.sql.types.DoubleType @Since("1.2.0") @Experimental class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasRawPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.2.0") def this() = this(Identifiable.randomUID("binEval")) + /** + * Default number of bins to use for binary classification evaluation. + */ + val defaultNumberOfBins = 1000 + /** * param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`) * @group param @@ -68,6 +74,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va @Since("1.2.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("2.2.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "areaUnderROC") @Since("2.0.0") @@ -77,12 +87,16 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va SchemaUtils.checkNumericType(schema, $(labelCol)) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. - val scoreAndLabels = - dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label) - case Row(rawPrediction: Double, label: Double) => (rawPrediction, label) + val scoreAndLabelsWithWeights = + dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(rawPrediction: Vector, label: Double, weight: Double) => + (rawPrediction(1), (label, weight)) + case Row(rawPrediction: Double, label: Double, weight: Double) => + (rawPrediction, (label, weight)) } - val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val metrics = new BinaryClassificationMetrics(defaultNumberOfBins, scoreAndLabelsWithWeights) val metric = $(metricName) match { case "areaUnderROC" => metrics.areaUnderROC() case "areaUnderPR" => metrics.areaUnderPR() diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 794b1e7d9d881..b319629362696 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.DoubleType @Since("1.5.0") @Experimental class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mcEval")) @@ -67,6 +68,10 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("2.2.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "f1") @Since("2.0.0") @@ -75,11 +80,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = - dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { - case Row(prediction: Double, label: Double) => (prediction, label) + val predictionAndLabelsWithWeights = + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + .rdd.map { + case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } - val metrics = new MulticlassMetrics(predictionAndLabels) + dataset.select(col($(predictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map { + case Row(prediction: Double, label: Double) => (prediction, label) + }.values.countByValue() + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure case "weightedPrecision" => metrics.weightedPrecision diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 031cd0d635bf4..7589eedb8ce61 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.sql.{Dataset, Row} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType} @Since("1.4.0") @Experimental final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable { + extends Evaluator with HasPredictionCol with HasLabelCol + with HasWeightCol with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("regEval")) @@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") def setLabelCol(value: String): this.type = set(labelCol, value) + /** @group setParam */ + @Since("2.2.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(metricName -> "rmse") @Since("2.0.0") @@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType)) SchemaUtils.checkNumericType(schema, $(labelCol)) - val predictionAndLabels = dataset - .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) + val predictionAndLabelsWithWeights = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) .rdd - .map { case Row(prediction: Double, label: Double) => (prediction, label) } - val metrics = new RegressionMetrics(predictionAndLabels) + .map { case Row(prediction: Double, label: Double, weight: Double) => + (prediction, label, weight) } + val metrics = new RegressionMetrics(false, predictionAndLabelsWithWeights) val metric = $(metricName) match { case "rmse" => metrics.rootMeanSquaredError case "mse" => metrics.meanSquaredError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 9b7cd0427f5ed..cebef17f53025 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for binary classification. * - * @param scoreAndLabels an RDD of (score, label) pairs. + * @param scoreAndLabelsWithWeights an RDD of (score, (label, weight)) pairs. * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally * will be down-sampled to this many "bins". If 0, no down-sampling will occur. * This is useful because the curve contains a point for each distinct score @@ -41,12 +41,26 @@ import org.apache.spark.sql.DataFrame * partition boundaries. */ @Since("1.0.0") -class BinaryClassificationMetrics @Since("1.3.0") ( - @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], - @Since("1.3.0") val numBins: Int) extends Logging { +class BinaryClassificationMetrics @Since("2.2.0") ( + val numBins: Int, + @Since("2.2.0") val scoreAndLabelsWithWeights: RDD[(Double, (Double, Double))]) + extends Logging { require(numBins >= 0, "numBins must be nonnegative") + /** + * Retrieves the score and labels (for binary compatibility). + * @return The score and labels. + */ + @Since("1.0.0") + def scoreAndLabels: RDD[(Double, Double)] = { + scoreAndLabelsWithWeights.map(values => (values._1, values._2._1)) + } + + @Since("1.0.0") + def this(@Since("1.3.0") scoreAndLabels: RDD[(Double, Double)], @Since("1.3.0") numBins: Int) = + this(numBins, scoreAndLabels.map(scoreAndLabel => (scoreAndLabel._1, (scoreAndLabel._2, 1.0)))) + /** * Defaults `numBins` to 0. */ @@ -146,11 +160,13 @@ class BinaryClassificationMetrics @Since("1.3.0") ( private lazy val ( cumulativeCounts: RDD[(Double, BinaryLabelCounter)], confusions: RDD[(Double, BinaryConfusionMatrix)]) = { - // Create a bin for each distinct score value, count positives and negatives within each bin, - // and then sort by score values in descending order. - val counts = scoreAndLabels.combineByKey( - createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label, - mergeValue = (c: BinaryLabelCounter, label: Double) => c += label, + // Create a bin for each distinct score value, count weighted positives and + // negatives within each bin, and then sort by score values in descending order. + val counts = scoreAndLabelsWithWeights.combineByKey( + createCombiner = (labelAndWeight: (Double, Double)) => + new BinaryLabelCounter(0L, 0L) += (labelAndWeight._1, labelAndWeight._2), + mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) => + c += (labelAndWeight._1, labelAndWeight._2), mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 ).sortByKey(ascending = false) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 9a6a8dbdccbf3..39d4b7b58cb77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -27,10 +27,11 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for multiclass classification. * - * @param predictionAndLabels an RDD of (prediction, label) pairs. + * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or + * (prediction, label) pairs. */ @Since("1.1.0") -class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_]) { /** * An auxiliary constructor taking a DataFrame. @@ -39,21 +40,33 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl private[mllib] def this(predictionAndLabels: DataFrame) = this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) - private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() - private lazy val labelCount: Long = labelCountByClass.values.sum - private lazy val tpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (label, if (label == prediction) 1 else 0) + private lazy val labelCountByClass: Map[Double, Double] = + predAndLabelsWithOptWeight.map { + case (prediction: Double, label: Double) => + (label, 1.0) + case (prediction: Double, label: Double, weight: Double) => + (label, weight) + }.mapValues(weight => weight).reduceByKey(_ + _).collect().toMap + private lazy val labelCount: Double = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Double] = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + (label, if (label == prediction) 1.0 else 0.0) + case (prediction: Double, label: Double, weight: Double) => + (label, if (label == prediction) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val fpByClass: Map[Double, Int] = predictionAndLabels - .map { case (prediction, label) => - (prediction, if (prediction != label) 1 else 0) + private lazy val fpByClass: Map[Double, Double] = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + (prediction, if (prediction != label) 1.0 else 0.0) + case (prediction: Double, label: Double, weight: Double) => + (prediction, if (prediction != label) weight else 0.0) }.reduceByKey(_ + _) .collectAsMap() - private lazy val confusions = predictionAndLabels - .map { case (prediction, label) => - ((label, prediction), 1) + private lazy val confusions = predAndLabelsWithOptWeight + .map { case (prediction: Double, label: Double) => + ((label, prediction), 1.0) + case (prediction: Double, label: Double, weight: Double) => + ((label, prediction), weight) }.reduceByKey(_ + _) .collectAsMap() @@ -71,7 +84,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl while (i < n) { var j = 0 while (j < n) { - values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0.0) j += 1 } i += 1 @@ -92,8 +105,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl */ @Since("1.1.0") def falsePositiveRate(label: Double): Double = { - val fp = fpByClass.getOrElse(label, 0) - fp.toDouble / (labelCount - labelCountByClass(label)) + val fp = fpByClass.getOrElse(label, 0.0) + fp / (labelCount - labelCountByClass(label)) } /** @@ -103,7 +116,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) - val fp = fpByClass.getOrElse(label, 0) + val fp = fpByClass.getOrElse(label, 0.0) if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) } @@ -112,7 +125,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * @param label the label. */ @Since("1.1.0") - def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + def recall(label: Double): Double = tpByClass(label) / labelCountByClass(label) /** * Returns f-measure for a given label (category) @@ -165,7 +178,7 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * out of the total number of instances.) */ @Since("2.0.0") - lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount + lazy val accuracy: Double = tpByClass.values.sum / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index ad99b00a31fd5..6dc812955291d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -27,18 +27,30 @@ import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight) + * or (prediction, observation) pairs * @param throughOrigin True if the regression is through the origin. For example, in linear * regression, it will be true without fitting intercept. */ @Since("1.2.0") class RegressionMetrics @Since("2.0.0") ( - predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + throughOrigin: Boolean, predAndObsWithOptWeight: RDD[_]) extends Logging { @Since("1.2.0") def this(predictionAndObservations: RDD[(Double, Double)]) = - this(predictionAndObservations, false) + this(false, predictionAndObservations) + + /** + * Evaluator for regression. + * + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, it will be true without fitting intercept. + */ + @Since("2.0.0") + def this(predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) = + this(throughOrigin, predictionAndObservations) /** * An auxiliary constructor taking a DataFrame. @@ -46,16 +58,19 @@ class RegressionMetrics @Since("2.0.0") ( * prediction and observation */ private[mllib] def this(predictionAndObservations: DataFrame) = - this(predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) + this(false, predictionAndObservations.rdd.map(r => (r.getDouble(0), r.getDouble(1)))) /** * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. */ private lazy val summary: MultivariateStatisticalSummary = { - val summary: MultivariateStatisticalSummary = predictionAndObservations.map { - case (prediction, observation) => Vectors.dense(observation, observation - prediction) + val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map { + case (prediction: Double, observation: Double, weight: Double) => + (Vectors.dense(observation, observation - prediction), weight) + case (prediction: Double, observation: Double) => + (Vectors.dense(observation, observation - prediction), 1.0) }.aggregate(new MultivariateOnlineSummarizer())( - (summary, v) => summary.add(v), + (summary, sample) => summary.add(sample._1, sample._2), (sum1, sum2) => sum1.merge(sum2) ) summary @@ -63,11 +78,13 @@ class RegressionMetrics @Since("2.0.0") ( private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) - private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SStot = summary.variance(0) * (summary.weightSum - 1) private lazy val SSreg = { val yMean = summary.mean(0) - predictionAndObservations.map { - case (prediction, _) => math.pow(prediction - yMean, 2) + predAndObsWithOptWeight.map { + case (prediction: Double, _: Double, weight: Double) => + math.pow(prediction - yMean, 2) * weight + case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2) }.sum() } @@ -79,7 +96,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def explainedVariance: Double = { - SSreg / summary.count + SSreg / summary.weightSum } /** @@ -88,7 +105,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanAbsoluteError: Double = { - summary.normL1(1) / summary.count + summary.normL1(1) / summary.weightSum } /** @@ -97,7 +114,7 @@ class RegressionMetrics @Since("2.0.0") ( */ @Since("1.2.0") def meanSquaredError: Double = { - SSerr / summary.count + SSerr / summary.weightSum } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala index 559c6ef7e7251..1c0130700421e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -22,22 +22,22 @@ package org.apache.spark.mllib.evaluation.binary */ private[evaluation] trait BinaryConfusionMatrix { /** number of true positives */ - def numTruePositives: Long + def numTruePositives: Double /** number of false positives */ - def numFalsePositives: Long + def numFalsePositives: Double /** number of false negatives */ - def numFalseNegatives: Long + def numFalseNegatives: Double /** number of true negatives */ - def numTrueNegatives: Long + def numTrueNegatives: Double /** number of positives */ - def numPositives: Long = numTruePositives + numFalseNegatives + def numPositives: Double = numTruePositives + numFalseNegatives /** number of negatives */ - def numNegatives: Long = numFalsePositives + numTrueNegatives + def numNegatives: Double = numFalsePositives + numTrueNegatives } /** @@ -51,20 +51,20 @@ private[evaluation] case class BinaryConfusionMatrixImpl( totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix { /** number of true positives */ - override def numTruePositives: Long = count.numPositives + override def numTruePositives: Double = count.numPositives /** number of false positives */ - override def numFalsePositives: Long = count.numNegatives + override def numFalsePositives: Double = count.numNegatives /** number of false negatives */ - override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives + override def numFalseNegatives: Double = totalCount.numPositives - count.numPositives /** number of true negatives */ - override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives + override def numTrueNegatives: Double = totalCount.numNegatives - count.numNegatives /** number of positives */ - override def numPositives: Long = totalCount.numPositives + override def numPositives: Double = totalCount.numPositives /** number of negatives */ - override def numNegatives: Long = totalCount.numNegatives + override def numNegatives: Double = totalCount.numNegatives } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala index 1e610c20092a7..eb38af748ad9e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala @@ -24,14 +24,22 @@ package org.apache.spark.mllib.evaluation.binary * @param numNegatives number of negative labels */ private[evaluation] class BinaryLabelCounter( - var numPositives: Long = 0L, - var numNegatives: Long = 0L) extends Serializable { + var numPositives: Double = 0.0, + var numNegatives: Double = 0.0) extends Serializable { /** Processes a label. */ def +=(label: Double): BinaryLabelCounter = { // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle // -1.0 for negative as well. - if (label > 0.5) numPositives += 1L else numNegatives += 1L + if (label > 0.5) numPositives += 1.0 else numNegatives += 1.0 + this + } + + /** Processes a label with a weight. */ + def +=(label: Double, weight: Double): BinaryLabelCounter = { + // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle + // -1.0 for negative as well. + if (label > 0.5) numPositives += weight else numNegatives += weight this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7dc0c459ec032..f14dbf9810d2d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var totalCnt: Long = 0 private var totalWeightSum: Double = 0.0 private var weightSquareSum: Double = 0.0 - private var weightSum: Array[Double] = _ + private var currWeightSum: Array[Double] = _ private var nnz: Array[Long] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n = Array.ofDim[Double](n) currM2 = Array.ofDim[Double](n) currL1 = Array.ofDim[Double](n) - weightSum = Array.ofDim[Double](n) + currWeightSum = Array.ofDim[Double](n) nnz = Array.ofDim[Long](n) currMax = Array.fill[Double](n)(Double.MinValue) currMin = Array.fill[Double](n)(Double.MaxValue) @@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 - val localWeightSum = weightSum + val localWeightSum = currWeightSum val localNumNonzeros = nnz val localCurrMax = currMax val localCurrMin = currMin @@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = weightSum(i) - val otherNnz = other.weightSum(i) + val thisNnz = currWeightSum(i) + val otherNnz = other.currWeightSum(i) val totalNnz = thisNnz + otherNnz val totalCnnz = nnz(i) + other.nnz(i) if (totalNnz != 0.0) { @@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMax(i) = math.max(currMax(i), other.currMax(i)) currMin(i) = math.min(currMin(i), other.currMin(i)) } - weightSum(i) = totalNnz + currWeightSum(i) = totalNnz nnz(i) = totalCnnz i += 1 } @@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.totalCnt = other.totalCnt this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum - this.weightSum = other.weightSum.clone() + this.currWeightSum = other.currWeightSum.clone() this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) + realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum) i += 1 } Vectors.dense(realMean) @@ -213,8 +213,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) * + (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator i += 1 } } @@ -228,6 +228,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S @Since("1.1.0") override def count: Long = totalCnt + /** + * Sum of weights. + */ + override def weightSum: Double = totalWeightSum + /** * Number of nonzero elements in each dimension. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 39a16fb743d64..978842fdb27e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary { @Since("1.0.0") def count: Long + /** + * Sum of weights. + */ + @Since("2.2.0") + def weightSum: Double + /** * Number of nonzero elements (including explicitly presented zero values) in each column. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ede284712b1c0..1e7b96f94cfc8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite assert(thrown.getMessage.replace("\n", "") contains "but was actually of type StringType.") } + test("should accept weight column") { + val weightCol = "weight" + // get metric with weight column + val evaluator = new BinaryClassificationEvaluator() + .setMetricName("areaUnderROC").setWeightCol(weightCol) + val vectorDF = Seq( + (0d, Vectors.dense(2.5, 12), 1.0), + (1d, Vectors.dense(1, 3), 1.0), + (0d, Vectors.dense(10, 2), 1.0) + ).toDF("label", "rawPrediction", weightCol) + val result = evaluator.evaluate(vectorDF) + // without weight column + val evaluator2 = new BinaryClassificationEvaluator() + .setMetricName("areaUnderROC") + val result2 = evaluator2.evaluate(vectorDF) + assert(result === result2) + // use different weights, validate metrics change + val vectorDF2 = Seq( + (0d, Vectors.dense(2.5, 12), 2.5), + (1d, Vectors.dense(1, 3), 0.1), + (0d, Vectors.dense(10, 2), 2.0) + ).toDF("label", "rawPrediction", weightCol) + val result3 = evaluator.evaluate(vectorDF2) + // Since wrong result weighted more heavily, expect the score to be lower + assert(result3 < result) + } + test("should support all NumericType labels and not support other types") { val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") MLTestingUtils.checkNumericTypes(evaluator, spark) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index 99d52fabc5309..ecc60b0d4bf8e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -84,6 +84,34 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } + test("binary evaluation metrics with weights") { + val w1 = 1.5 + val w2 = 0.7 + val w3 = 0.4 + val scoreAndLabelsWithWeights = sc.parallelize( + Seq((0.1, (0.0, w1)), (0.1, (1.0, w2)), (0.4, (0.0, w1)), (0.6, (0.0, w3)), + (0.6, (1.0, w2)), (0.6, (1.0, w2)), (0.8, (1.0, w1))), 2) + val metrics = new BinaryClassificationMetrics(0, scoreAndLabelsWithWeights) + val thresholds = Seq(0.8, 0.6, 0.4, 0.1) + val numTruePositives = + Seq(1.0 * w1, 1.0 * w1 + 2.0 * w2, 1.0 * w1 + 2.0 * w2, 3.0 * w2 + 1.0 * w1) + val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1) + val numPositives = 3 * w2 + 1 * w1 + val numNegatives = 2 * w1 + w3 + val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) => + t.toDouble / (t + f) + } + val recalls = numTruePositives.map(t => t / numPositives) + val fpr = numFalsePositives.map(f => f / numNegatives) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) + } + test("binary evaluation metrics for RDD where all examples have positive label") { val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) val metrics = new BinaryClassificationMetrics(scoreAndLabels) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 142d1e9812ef1..47b1e1af59c71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -95,4 +95,95 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) assert(metrics.labels.sameElements(labels)) } + + test("Multiclass evaluation metrics with weights") { + /* + * Confusion matrix for 3-class classification with total 9 instances with 2 weights: + * |2 * w1|1 * w2 |1 * w1| true class0 (4 instances) + * |1 * w2|2 * w1 + 1 * w2|0 | true class1 (4 instances) + * |0 |0 |1 * w2| true class2 (1 instance) + */ + val w1 = 2.2 + val w2 = 1.5 + val tw = 2.0 * w1 + 1.0 * w2 + 1.0 * w1 + 1.0 * w2 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2 + val confusionMatrix = Matrices.dense(3, 3, + Array(2 * w1, 1 * w2, 0, 1 * w2, 2 * w1 + 1 * w2, 0, 1 * w1, 0, 1 * w2)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabelsWithWeights = sc.parallelize( + Seq((0.0, 0.0, w1), (0.0, 1.0, w2), (0.0, 0.0, w1), (1.0, 0.0, w2), + (1.0, 1.0, w1), (1.0, 1.0, w2), (1.0, 1.0, w1), (2.0, 2.0, w2), + (2.0, 0.0, w1)), 2) + val metrics = new MulticlassMetrics(predictionAndLabelsWithWeights) + val delta = 0.0000001 + val tpRate0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val tpRate1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val tpRate2 = (1.0 * w2) / (1.0 * w2 + 0) + val fpRate0 = (1.0 * w2) / (tw - (2.0 * w1 + 1.0 * w2 + 1.0 * w1)) + val fpRate1 = (1.0 * w2) / (tw - (1.0 * w2 + 2.0 * w1 + 1.0 * w2)) + val fpRate2 = (1.0 * w1) / (tw - (1.0 * w2)) + val precision0 = (2.0 * w1) / (2 * w1 + 1 * w2) + val precision1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val precision2 = (1.0 * w2) / (1 * w1 + 1 * w2) + val recall0 = (2.0 * w1) / (2.0 * w1 + 1.0 * w2 + 1.0 * w1) + val recall1 = (2.0 * w1 + 1.0 * w2) / (2.0 * w1 + 1.0 * w2 + 1.0 * w2) + val recall2 = (1.0 * w2) / (1.0 * w2 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) + assert(math.abs(metrics.truePositiveRate(0.0) - tpRate0) < delta) + assert(math.abs(metrics.truePositiveRate(1.0) - tpRate1) < delta) + assert(math.abs(metrics.truePositiveRate(2.0) - tpRate2) < delta) + assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) + assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) + assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + + assert(math.abs(metrics.accuracy - + (2.0 * w1 + 2.0 * w1 + 1.0 * w2 + 1.0 * w2) / tw) < delta) + assert(math.abs(metrics.accuracy - metrics.precision) < delta) + assert(math.abs(metrics.accuracy - metrics.recall) < delta) + assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) + assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) + assert(math.abs(metrics.weightedTruePositiveRate - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * tpRate0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * tpRate1 + + (1 * w2 / tw) * tpRate2)) < delta) + assert(math.abs(metrics.weightedFalsePositiveRate - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * fpRate0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * fpRate1 + + (1 * w2 / tw) * fpRate2)) < delta) + assert(math.abs(metrics.weightedPrecision - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * precision0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * precision1 + + (1 * w2 / tw) * precision2)) < delta) + assert(math.abs(metrics.weightedRecall - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * recall0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * recall1 + + (1 * w2 / tw) * recall2)) < delta) + assert(math.abs(metrics.weightedFMeasure - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f1measure0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f1measure1 + + (1 * w2 / tw) * f1measure2)) < delta) + assert(math.abs(metrics.weightedFMeasure(2.0) - + (((2 * w1 + 1 * w2 + 1 * w1) / tw) * f2measure0 + + ((1 * w2 + 2 * w1 + 1 * w2) / tw) * f2measure1 + + (1 * w2 / tw) * f2measure2)) < delta) + assert(metrics.labels.sameElements(labels)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index f1d517383643d..6de76f333b862 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { "root mean squared error mismatch") assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } + + test("regression metrics with same (1.0) weight samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2) + val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 8.79687 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch") + } + + /** + * The following values are hand calculated using the formula: + * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + * preds = c(2.25, -0.25, 1.75, 7.75) + * obs = c(3.0, -0.5, 2.0, 7.0) + * weights = c(0.1, 0.2, 0.15, 0.05) + * count = 4 + * + * Weighted metrics can be calculated with MultivariateStatisticalSummary. + * (observations, observations - predictions) + * mean (1.7, 0.05) + * variance (7.3, 0.3) + * numNonZeros (0.5, 0.5) + * max (7.0, 0.75) + * min (-0.5, -0.75) + * normL2 (2.0, 0.32596) + * normL1 (1.05, 0.2) + * + * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425 + * meanAbsoluteError: normL1(1) / weightedCount = 0.4 + * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125 + * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098 + * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910 + */ + test("regression metrics with weighted samples") { + val predictionAndObservationWithWeight = sc.parallelize( + Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2) + val metrics = new RegressionMetrics(false, predictionAndObservationWithWeight) + assert(metrics.explainedVariance ~== 5.2425 absTol eps, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 511686fb4f37f..26335b6829886 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -61,7 +61,10 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), + + // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") ) // Exclude rules for 2.1.x