diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index bff72b20e1c3f..c6b04333885ae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.2.0")
 @Experimental
 class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {
+  extends Evaluator with HasRawPredictionCol with HasLabelCol
+    with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("binEval"))
@@ -68,6 +69,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.2.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   setDefault(metricName -> "areaUnderROC")
 
   @Since("2.0.0")
@@ -75,14 +80,23 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
     val schema = dataset.schema
     SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
     SchemaUtils.checkNumericType(schema, $(labelCol))
+    if (isDefined(weightCol)) {
+      SchemaUtils.checkNumericType(schema, $(weightCol))
+    }
 
     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-    val scoreAndLabels =
-      dataset.select(col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType)).rdd.map {
-        case Row(rawPrediction: Vector, label: Double) => (rawPrediction(1), label)
-        case Row(rawPrediction: Double, label: Double) => (rawPrediction, label)
+    val scoreAndLabelsWithWeights =
+      dataset.select(
+        col($(rawPredictionCol)),
+        col($(labelCol)).cast(DoubleType),
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
+        else col($(weightCol)).cast(DoubleType)).rdd.map {
+        case Row(rawPrediction: Vector, label: Double, weight: Double) =>
+          (rawPrediction(1), label, weight)
+        case Row(rawPrediction: Double, label: Double, weight: Double) =>
+          (rawPrediction, label, weight)
       }
-    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+    val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights)
     val metric = $(metricName) match {
       case "areaUnderROC" => metrics.areaUnderROC()
       case "areaUnderPR" => metrics.areaUnderPR()
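The hunk above routes an optional weight column through to `BinaryClassificationMetrics`; when `weightCol` is unset or empty, the `lit(1.0)` branch gives every row unit weight. A minimal usage sketch (spark-shell style, assuming a `spark` session in scope; the data and the `weight` column name are illustrative, not part of the patch):

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

// (label, rawPrediction, weight) rows; weights scale each row's contribution.
val df = Seq(
  (0.0, Vectors.dense(12.0, 2.5), 1.0),
  (1.0, Vectors.dense(1.0, 3.0), 0.5),
  (0.0, Vectors.dense(10.0, 2.0), 2.0)
).toDF("label", "rawPrediction", "weight")

// Omitting setWeightCol falls back to the implicit weight of 1.0 per row.
val auc = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .setWeightCol("weight")
  .evaluate(df)
```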
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 2cfcf38eb4ca8..cc89edc286a3b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -21,12 +21,12 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.evaluation.binary._
 import org.apache.spark.rdd.{RDD, UnionRDD}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Evaluator for binary classification.
 *
- * @param scoreAndLabels an RDD of (score, label) pairs.
+ * @param scoreAndLabels an RDD of (score, label) or (score, label, weight) tuples.
  * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
  *                will be down-sampled to this many "bins". If 0, no down-sampling will occur.
  *                This is useful because the curve contains a point for each distinct score
@@ -41,9 +41,19 @@ import org.apache.spark.sql.DataFrame
  *                partition boundaries.
  */
 @Since("1.0.0")
-class BinaryClassificationMetrics @Since("1.3.0") (
-    @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)],
-    @Since("1.3.0") val numBins: Int) extends Logging {
+class BinaryClassificationMetrics @Since("3.0.0") (
+    @Since("1.3.0") val scoreAndLabels: RDD[_ <: Product],
+    @Since("1.3.0") val numBins: Int = 1000)
+  extends Logging {
+  val scoreLabelsWeight: RDD[(Double, (Double, Double))] = scoreAndLabels.map {
+    case (prediction: Double, label: Double, weight: Double) =>
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+      (prediction, (label, weight))
+    case (prediction: Double, label: Double) =>
+      (prediction, (label, 1.0))
+    case other =>
+      throw new IllegalArgumentException(s"Expected tuples, got $other")
+  }
 
   require(numBins >= 0, "numBins must be nonnegative")
 
@@ -58,7 +68,14 @@ class BinaryClassificationMetrics @Since("1.3.0") (
    * @param scoreAndLabels a DataFrame with two double columns: score and label
    */
   private[mllib] def this(scoreAndLabels: DataFrame) =
-    this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
+    this(scoreAndLabels.rdd.map {
+      case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      case Row(prediction: Double, label: Double) =>
+        (prediction, label, 1.0)
+      case other =>
+        throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
+    })
 
   /**
    * Unpersist intermediate RDDs used in the computation.
@@ -146,11 +163,13 @@ class BinaryClassificationMetrics @Since("1.3.0") (
   private lazy val (
     cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
     confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
-    // Create a bin for each distinct score value, count positives and negatives within each bin,
-    // and then sort by score values in descending order.
-    val counts = scoreAndLabels.combineByKey(
-      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
-      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
+    // Create a bin for each distinct score value, count weighted positives and
+    // negatives within each bin, and then sort by score values in descending order.
+    val counts = scoreLabelsWeight.combineByKey(
+      createCombiner = (labelAndWeight: (Double, Double)) =>
+        new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2),
+      mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) =>
+        c += (labelAndWeight._1, labelAndWeight._2),
       mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
     ).sortByKey(ascending = false)
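Widening the constructor to `RDD[_ <: Product]` lets callers pass either pairs or triples; pairs are normalized to weight 1.0 in `scoreLabelsWeight`. A hedged sketch, assuming a `SparkContext` named `sc` and made-up scores:

```scala
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// (score, label, weight) triples; negative weights are rejected by the require above.
val weighted = sc.parallelize(Seq((0.8, 1.0, 2.0), (0.6, 1.0, 1.0), (0.3, 0.0, 0.5)))
println(new BinaryClassificationMetrics(weighted).areaUnderROC())

// Plain (score, label) pairs still work and get an implicit weight of 1.0.
val unweighted = sc.parallelize(Seq((0.8, 1.0), (0.3, 0.0)))
println(new BinaryClassificationMetrics(unweighted).areaUnderROC())
```

Note that an RDD has a single element type, so pairs and triples cannot be mixed within one RDD.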
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index ad83c24ede964..a10f26ba4640e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -22,17 +22,17 @@ import scala.collection.Map
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{Matrices, Matrix}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Evaluator for multiclass classification.
 *
- * @param predAndLabelsWithOptWeight an RDD of (prediction, label, weight) or
- *                                   (prediction, label) pairs.
+ * @param predictionAndLabels an RDD of (prediction, label, weight) or
+ *                            (prediction, label) tuples.
  */
 @Since("1.1.0")
-class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Product]) {
-  val predLabelsWeight: RDD[(Double, Double, Double)] = predAndLabelsWithOptWeight.map {
+class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[_ <: Product]) {
+  val predLabelsWeight: RDD[(Double, Double, Double)] = predictionAndLabels.map {
     case (prediction: Double, label: Double, weight: Double) =>
       (prediction, label, weight)
     case (prediction: Double, label: Double) =>
@@ -46,7 +46,14 @@ class MulticlassMetrics @Since("1.1.0") (predAndLabelsWithOptWeight: RDD[_ <: Pr
   /**
    * @param predictionAndLabels a DataFrame with two double columns: prediction and label
    */
   private[mllib] def this(predictionAndLabels: DataFrame) =
-    this(predictionAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))
+    this(predictionAndLabels.rdd.map {
+      case Row(prediction: Double, label: Double, weight: Double) =>
+        (prediction, label, weight)
+      case Row(prediction: Double, label: Double) =>
+        (prediction, label, 1.0)
+      case other =>
+        throw new IllegalArgumentException(s"Expected Row of tuples, got $other")
+    })
 
   private lazy val labelCountByClass: Map[Double, Double] = predLabelsWeight.map {
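`MulticlassMetrics` already handled optional weights; this hunk renames the constructor parameter and makes the DataFrame constructor tolerate a third column. For completeness, an illustrative sketch (assuming `sc`; the data is made up):

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// (prediction, label, weight) triples; pairs default to weight 1.0.
val predAndLabels = sc.parallelize(Seq((0.0, 0.0, 1.0), (1.0, 0.0, 0.3), (1.0, 1.0, 2.0)))
val mm = new MulticlassMetrics(predAndLabels)
println(mm.accuracy)  // weighted fraction correct: (1.0 + 2.0) / 3.3
```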
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
index 5a4c6aef50b7b..d98ca2bdc9ded 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -27,11 +27,11 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
 /** Precision. Defined as 1.0 when there are no positive examples. */
 private[evaluation] object Precision extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    val totalPositives = c.numTruePositives + c.numFalsePositives
-    if (totalPositives == 0) {
+    val totalPositives = c.weightedTruePositives + c.weightedFalsePositives
+    if (totalPositives == 0.0) {
       1.0
     } else {
-      c.numTruePositives.toDouble / totalPositives
+      c.weightedTruePositives / totalPositives
     }
   }
 }
@@ -39,10 +39,10 @@ private[evaluation] object Precision extends BinaryClassificationMetricComputer
 /** False positive rate. Defined as 0.0 when there are no negative examples. */
 private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    if (c.numNegatives == 0) {
+    if (c.weightedNegatives == 0.0) {
       0.0
     } else {
-      c.numFalsePositives.toDouble / c.numNegatives
+      c.weightedFalsePositives / c.weightedNegatives
     }
   }
 }
@@ -50,10 +50,10 @@ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricC
 /** Recall. Defined as 0.0 when there are no positive examples. */
 private[evaluation] object Recall extends BinaryClassificationMetricComputer {
   override def apply(c: BinaryConfusionMatrix): Double = {
-    if (c.numPositives == 0) {
+    if (c.weightedPositives == 0.0) {
       0.0
     } else {
-      c.numTruePositives.toDouble / c.numPositives
+      c.weightedTruePositives / c.weightedPositives
     }
   }
 }
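With the counts now `Double` weights, each metric is a ratio of summed weights rather than of integer counts. A plain-Scala spot-check of the weighted definitions on made-up data (no Spark required):

```scala
// precision = wTP / (wTP + wFP); recall = wTP / (wTP + wFN).
val data = Seq((1.0, 1.0, 2.0), (1.0, 0.0, 0.5), (0.0, 1.0, 1.0)) // (pred, label, weight)
val wTP = data.collect { case (1.0, 1.0, w) => w }.sum // 2.0
val wFP = data.collect { case (1.0, 0.0, w) => w }.sum // 0.5
val wFN = data.collect { case (0.0, 1.0, w) => w }.sum // 1.0
val precision = wTP / (wTP + wFP) // 0.8
val recall = wTP / (wTP + wFN)    // 2.0 / 3.0
```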
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
index 559c6ef7e7251..192c9b1863fe7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala
@@ -21,23 +21,23 @@ package org.apache.spark.mllib.evaluation.binary
 * Trait for a binary confusion matrix.
 */
 private[evaluation] trait BinaryConfusionMatrix {
-  /** number of true positives */
-  def numTruePositives: Long
+  /** weighted number of true positives */
+  def weightedTruePositives: Double
 
-  /** number of false positives */
-  def numFalsePositives: Long
+  /** weighted number of false positives */
+  def weightedFalsePositives: Double
 
-  /** number of false negatives */
-  def numFalseNegatives: Long
+  /** weighted number of false negatives */
+  def weightedFalseNegatives: Double
 
-  /** number of true negatives */
-  def numTrueNegatives: Long
+  /** weighted number of true negatives */
+  def weightedTrueNegatives: Double
 
-  /** number of positives */
-  def numPositives: Long = numTruePositives + numFalseNegatives
+  /** weighted number of positives */
+  def weightedPositives: Double = weightedTruePositives + weightedFalseNegatives
 
-  /** number of negatives */
-  def numNegatives: Long = numFalsePositives + numTrueNegatives
+  /** weighted number of negatives */
+  def weightedNegatives: Double = weightedFalsePositives + weightedTrueNegatives
 }
 
 /**
@@ -51,20 +51,22 @@ private[evaluation] case class BinaryConfusionMatrixImpl(
     totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {
 
   /** number of true positives */
-  override def numTruePositives: Long = count.numPositives
+  override def weightedTruePositives: Double = count.weightedNumPositives
 
   /** number of false positives */
-  override def numFalsePositives: Long = count.numNegatives
+  override def weightedFalsePositives: Double = count.weightedNumNegatives
 
   /** number of false negatives */
-  override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives
+  override def weightedFalseNegatives: Double =
+    totalCount.weightedNumPositives - count.weightedNumPositives
 
   /** number of true negatives */
-  override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives
+  override def weightedTrueNegatives: Double =
+    totalCount.weightedNumNegatives - count.weightedNumNegatives
 
   /** number of positives */
-  override def numPositives: Long = totalCount.numPositives
+  override def weightedPositives: Double = totalCount.weightedNumPositives
 
   /** number of negatives */
-  override def numNegatives: Long = totalCount.numNegatives
+  override def weightedNegatives: Double = totalCount.weightedNumNegatives
 }
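The subtraction logic in `BinaryConfusionMatrixImpl` is unchanged; only the types widen. `count` holds the cumulative weighted positives and negatives at or above a score threshold, and `totalCount` holds the overall totals. A standalone illustration with made-up numbers (`Counts` is a stand-in, since the real types are `private[evaluation]`):

```scala
case class Counts(pos: Double, neg: Double)
val count = Counts(pos = 2.0, neg = 0.5) // cumulative, at or above the threshold
val total = Counts(pos = 3.0, neg = 2.5) // over the whole dataset
val weightedTruePositives = count.pos              // 2.0
val weightedFalsePositives = count.neg             // 0.5
val weightedFalseNegatives = total.pos - count.pos // 1.0
val weightedTrueNegatives = total.neg - count.neg  // 2.0
```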
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
index 1e610c20092a7..1ad91966b2141 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryLabelCounter.scala
@@ -20,31 +20,39 @@ package org.apache.spark.mllib.evaluation.binary
 /**
  * A counter for positives and negatives.
 *
- * @param numPositives number of positive labels
- * @param numNegatives number of negative labels
+ * @param weightedNumPositives weighted number of positive labels
+ * @param weightedNumNegatives weighted number of negative labels
 */
 private[evaluation] class BinaryLabelCounter(
-    var numPositives: Long = 0L,
-    var numNegatives: Long = 0L) extends Serializable {
+    var weightedNumPositives: Double = 0.0,
+    var weightedNumNegatives: Double = 0.0) extends Serializable {
 
   /** Processes a label. */
   def +=(label: Double): BinaryLabelCounter = {
     // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
     // -1.0 for negative as well.
-    if (label > 0.5) numPositives += 1L else numNegatives += 1L
+    if (label > 0.5) weightedNumPositives += 1.0 else weightedNumNegatives += 1.0
+    this
+  }
+
+  /** Processes a label with a weight. */
+  def +=(label: Double, weight: Double): BinaryLabelCounter = {
+    // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
+    // -1.0 for negative as well.
+    if (label > 0.5) weightedNumPositives += weight else weightedNumNegatives += weight
     this
   }
 
   /** Merges another counter. */
   def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
-    numPositives += other.numPositives
-    numNegatives += other.numNegatives
+    weightedNumPositives += other.weightedNumPositives
+    weightedNumNegatives += other.weightedNumNegatives
     this
   }
 
   override def clone: BinaryLabelCounter = {
-    new BinaryLabelCounter(numPositives, numNegatives)
+    new BinaryLabelCounter(weightedNumPositives, weightedNumNegatives)
   }
 
-  override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
+  override def toString: String = s"{numPos: $weightedNumPositives, numNeg: $weightedNumNegatives}"
 }
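The unweighted `+=` keeps its old behavior by adding 1.0, so existing callers are unaffected, while the new overload folds in an explicit weight for `combineByKey`. A standalone sketch of the accumulation rule (the real class is `private[evaluation]`):

```scala
// label > 0.5 counts as positive, so both 0.0 and -1.0 land in the negative bucket.
var pos = 0.0
var neg = 0.0
def add(label: Double, weight: Double = 1.0): Unit =
  if (label > 0.5) pos += weight else neg += weight

add(1.0, 0.9)  // pos += 0.9
add(0.0, 2.0)  // neg += 2.0
add(-1.0)      // unweighted path: neg += 1.0
assert(pos == 0.9 && neg == 3.0)
```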
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index 2b0909acf69c3..83b213ab51d43 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -45,23 +45,23 @@ class BinaryClassificationEvaluatorSuite
       .setMetricName("areaUnderPR")
 
     val vectorDF = Seq(
-      (0d, Vectors.dense(12, 2.5)),
-      (1d, Vectors.dense(1, 3)),
-      (0d, Vectors.dense(10, 2))
+      (0.0, Vectors.dense(12, 2.5)),
+      (1.0, Vectors.dense(1, 3)),
+      (0.0, Vectors.dense(10, 2))
     ).toDF("label", "rawPrediction")
     assert(evaluator.evaluate(vectorDF) === 1.0)
 
     val doubleDF = Seq(
-      (0d, 0d),
-      (1d, 1d),
-      (0d, 0d)
+      (0.0, 0.0),
+      (1.0, 1.0),
+      (0.0, 0.0)
     ).toDF("label", "rawPrediction")
     assert(evaluator.evaluate(doubleDF) === 1.0)
 
     val stringDF = Seq(
-      (0d, "0d"),
-      (1d, "1d"),
-      (0d, "0d")
+      (0.0, "0.0"),
+      (1.0, "1.0"),
+      (0.0, "0.0")
     ).toDF("label", "rawPrediction")
     val thrown = intercept[IllegalArgumentException] {
       evaluator.evaluate(stringDF)
@@ -71,6 +71,33 @@ class BinaryClassificationEvaluatorSuite
     assert(thrown.getMessage.replace("\n", "") contains "but was actually of type string.")
   }
 
+  test("should accept weight column") {
+    val weightCol = "weight"
+    // get metric with weight column
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderROC").setWeightCol(weightCol)
+    val vectorDF = Seq(
+      (0.0, Vectors.dense(2.5, 12), 1.0),
+      (1.0, Vectors.dense(1, 3), 1.0),
+      (0.0, Vectors.dense(10, 2), 1.0)
+    ).toDF("label", "rawPrediction", weightCol)
+    val result = evaluator.evaluate(vectorDF)
+    // without weight column
+    val evaluator2 = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderROC")
+    val result2 = evaluator2.evaluate(vectorDF)
+    assert(result === result2)
+    // use different weights, validate metrics change
+    val vectorDF2 = Seq(
+      (0.0, Vectors.dense(2.5, 12), 2.5),
+      (1.0, Vectors.dense(1, 3), 0.1),
+      (0.0, Vectors.dense(10, 2), 2.0)
+    ).toDF("label", "rawPrediction", weightCol)
+    val result3 = evaluator.evaluate(vectorDF2)
+    // Since the wrong results are weighted more heavily, expect the score to be lower
+    assert(result3 < result)
+  }
+
   test("should support all NumericType labels and not support other types") {
     val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction")
     MLTestingUtils.checkNumericTypes(evaluator, spark)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index a08917ac1ebed..06a522f43482c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -82,6 +82,34 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
     validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
   }
 
+  test("binary evaluation metrics with weights") {
+    val w1 = 1.5
+    val w2 = 0.7
+    val w3 = 0.4
+    val scoreAndLabelsWithWeights = sc.parallelize(
+      Seq((0.1, 0.0, w1), (0.1, 1.0, w2), (0.4, 0.0, w1), (0.6, 0.0, w3),
+        (0.6, 1.0, w2), (0.6, 1.0, w2), (0.8, 1.0, w1)), 2)
+    val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, 0)
+    val thresholds = Seq(0.8, 0.6, 0.4, 0.1)
+    val numTruePositives =
+      Seq(1 * w1, 1 * w1 + 2 * w2, 1 * w1 + 2 * w2, 3 * w2 + 1 * w1)
+    val numFalsePositives = Seq(0.0, 1.0 * w3, 1.0 * w1 + 1.0 * w3, 1.0 * w3 + 2.0 * w1)
+    val numPositives = 3 * w2 + 1 * w1
+    val numNegatives = 2 * w1 + w3
+    val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) =>
+      t / (t + f)
+    }
+    val recalls = numTruePositives.map(_ / numPositives)
+    val fpr = numFalsePositives.map(_ / numNegatives)
+    val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
+    val pr = recalls.zip(precisions)
+    val prCurve = Seq((0.0, 1.0)) ++ pr
+    val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
+    val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r) }
+
+    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+  }
+
   test("binary evaluation metrics for RDD where all examples have positive label") {
     val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2)
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
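The expected sequences in the weighted test follow directly from the weighted definitions. As a hand check at threshold 0.6 (predict positive when score >= 0.6), matching the second entries of the sequences above:

```scala
val (w1, w2, w3) = (1.5, 0.7, 0.4)
val tp = 1 * w1 + 2 * w2            // (0.8, 1.0, w1) and two (0.6, 1.0, w2): 2.9
val fp = 1 * w3                     // (0.6, 0.0, w3): 0.4
val precision = tp / (tp + fp)      // 2.9 / 3.3, about 0.879
val recall = tp / (1 * w1 + 3 * w2) // total positive weight is 3.6, about 0.806
```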
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index f563a2d4d283f..0f70860ceaf0f 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -106,7 +106,7 @@ def isLargerBetter(self):
 
 
 @inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol, HasWeightCol,
                                     JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
@@ -130,6 +130,16 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
     >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
     >>> str(evaluator2.getRawPredictionCol())
     'raw'
+    >>> scoreAndLabelsAndWeight = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
+    ...     [(0.1, 0.0, 1.0), (0.1, 1.0, 0.9), (0.4, 0.0, 0.7), (0.6, 0.0, 0.9),
+    ...     (0.6, 1.0, 1.0), (0.6, 1.0, 0.3), (0.8, 1.0, 1.0)])
+    >>> dataset = spark.createDataFrame(scoreAndLabelsAndWeight, ["raw", "label", "weight"])
+    ...
+    >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", weightCol="weight")
+    >>> evaluator.evaluate(dataset)
+    0.70...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
+    0.82...
 
     .. versionadded:: 1.4.0
     """
@@ -140,10 +150,10 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
 
     @keyword_only
     def __init__(self, rawPredictionCol="rawPrediction", labelCol="label",
-                 metricName="areaUnderROC"):
+                 metricName="areaUnderROC", weightCol=None):
         """
         __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \
-                 metricName="areaUnderROC")
+                 metricName="areaUnderROC", weightCol=None)
         """
         super(BinaryClassificationEvaluator, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -169,10 +179,10 @@ def getMetricName(self):
     @keyword_only
     @since("1.4.0")
     def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
-                  metricName="areaUnderROC"):
+                  metricName="areaUnderROC", weightCol=None):
         """
         setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \
-                  metricName="areaUnderROC")
+                  metricName="areaUnderROC", weightCol=None)
         Sets params for binary classification evaluator.
         """
         kwargs = self._input_kwargs
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index b0283941171a7..5d8d20dcfcfcf 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -30,7 +30,7 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     """
     Evaluator for binary classification.
 
-    :param scoreAndLabels: an RDD of (score, label) pairs
+    :param scoreAndLabels: an RDD of score, label and optional weight.
 
     >>> scoreAndLabels = sc.parallelize([
     ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
@@ -40,6 +40,14 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     >>> metrics.areaUnderPR
     0.83...
     >>> metrics.unpersist()
+    >>> scoreAndLabelsWithOptWeight = sc.parallelize([
+    ...     (0.1, 0.0, 1.0), (0.1, 1.0, 0.4), (0.4, 0.0, 0.2), (0.6, 0.0, 0.6), (0.6, 1.0, 0.9),
+    ...     (0.6, 1.0, 0.5), (0.8, 1.0, 0.7)], 2)
+    >>> metrics = BinaryClassificationMetrics(scoreAndLabelsWithOptWeight)
+    >>> metrics.areaUnderROC
+    0.79...
+    >>> metrics.areaUnderPR
+    0.88...
 
     .. versionadded:: 1.4.0
     """
@@ -47,9 +55,13 @@ class BinaryClassificationMetrics(JavaModelWrapper):
     def __init__(self, scoreAndLabels):
         sc = scoreAndLabels.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
-        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
+        numCol = len(scoreAndLabels.first())
+        schema = StructType([
             StructField("score", DoubleType(), nullable=False),
-            StructField("label", DoubleType(), nullable=False)]))
+            StructField("label", DoubleType(), nullable=False)])
+        if numCol == 3:
+            schema.add("weight", DoubleType(), False)
+        df = sql_ctx.createDataFrame(scoreAndLabels, schema=schema)
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
         java_model = java_class(df._jdf)
         super(BinaryClassificationMetrics, self).__init__(java_model)
@@ -162,7 +174,7 @@ class MulticlassMetrics(JavaModelWrapper):
     """
     Evaluator for multiclass classification.
 
-    :param predAndLabelsWithOptWeight: an RDD of prediction, label and optional weight.
+    :param predictionAndLabels: an RDD of prediction, label and optional weight.
 
     >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
@@ -223,16 +235,16 @@ class MulticlassMetrics(JavaModelWrapper):
     .. versionadded:: 1.4.0
     """
 
-    def __init__(self, predAndLabelsWithOptWeight):
-        sc = predAndLabelsWithOptWeight.ctx
+    def __init__(self, predictionAndLabels):
+        sc = predictionAndLabels.ctx
         sql_ctx = SQLContext.getOrCreate(sc)
-        numCol = len(predAndLabelsWithOptWeight.first())
+        numCol = len(predictionAndLabels.first())
         schema = StructType([
             StructField("prediction", DoubleType(), nullable=False),
             StructField("label", DoubleType(), nullable=False)])
-        if (numCol == 3):
+        if numCol == 3:
             schema.add("weight", DoubleType(), False)
-        df = sql_ctx.createDataFrame(predAndLabelsWithOptWeight, schema)
+        df = sql_ctx.createDataFrame(predictionAndLabels, schema)
         java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
         java_model = java_class(df._jdf)
         super(MulticlassMetrics, self).__init__(java_model)