diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index fdd1851ae5508..071d05e1b1de1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -177,6 +177,8 @@ abstract class ProbabilisticClassificationModel[
    * Predict the probability of each class given the features.
    * These predictions are also called class conditional probabilities.
    *
+   * See BinaryClassificationMetrics.calibration to assess calibration.
+   *
    * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
    *
    * @return Estimated class conditional probabilities
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 508fe532b1306..604c6c6258c28 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.DataFrame
  * Evaluator for binary classification.
  *
  * @param scoreAndLabels an RDD of (score, label) pairs.
- * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
- *                will be down-sampled to this many "bins". If 0, no down-sampling will occur.
- *                This is useful because the curve contains a point for each distinct score
- *                in the input, and this could be as large as the input itself -- millions of
- *                points or more, when thousands may be entirely sufficient to summarize
- *                the curve. After down-sampling, the curves will instead be made of approximately
- *                `numBins` points instead. Points are made from bins of equal numbers of
- *                consecutive points. The size of each bin is
+ * @param numBins if greater than 0, then the curves (ROC curve, PR curve, calibration curve)
+ *                computed internally will be down-sampled to this many "bins". If 0, no
+ *                down-sampling will occur. This is useful because the curve contains a point for
+ *                each distinct score in the input, and this could be as large as the input itself
+ *                -- millions of points or more, when thousands may be entirely sufficient to
+ *                summarize the curve. After down-sampling, the curves will instead be made of
+ *                approximately `numBins` points. Points are made from bins of equal numbers of
+ *                consecutive points. The size of each bin is
  *                `floor(scoreAndLabels.count() / numBins)`, which means the resulting number
  *                of bins may not exactly equal numBins. The last bin in each partition may
  *                be smaller as a result, meaning there may be an extra sample at
@@ -226,4 +226,76 @@ class BinaryClassificationMetrics @Since("1.3.0") (
       (x(c), y(c))
     }
   }
+
+  /**
+   * Returns the calibration or reliability curve, which is an RDD of
+   * ((min score in bin, max score in bin), (fraction of positives in bin, bin size)).
+   * @see http://en.wikipedia.org/wiki/Calibration_%28statistics%29#In_classification
+   *
+   * References:
+   *
+   * Mahdi Pakdaman Naeini, Gregory F. Cooper, Milos Hauskrecht.
+   * Binary Classifier Calibration: Non-parametric approach.
+   * http://arxiv.org/abs/1401.3390
+   *
+   * Alexandru Niculescu-Mizil, Rich Caruana.
+   * Predicting Good Probabilities With Supervised Learning.
+   * In Proceedings of the 22nd International Conference on Machine Learning,
+   * Bonn, Germany, 2005.
+   * http://www.cs.cornell.edu/~alexn/papers/calibration.icml05.crc.rev3.pdf
+   *
+   * Ira Cohen, Moises Goldszmidt.
+   * Properties and benefits of calibrated classifiers.
+   * http://www.hpl.hp.com/techreports/2004/HPL-2004-22R1.pdf
+   */
+  def calibration(): RDD[((Double, Double), (Double, Long))] = {
+    assessedCalibration
+  }
+
+  private lazy val assessedCalibration: RDD[((Double, Double), (Double, Long))] = {
+    val distinctScoresAndLabelCounts = scoreAndLabels.combineByKey(
+      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
+      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
+      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
+    ).sortByKey(ascending = true)
+
+    val binnedDistinctScoresAndLabelCounts =
+      if (numBins == 0) {
+        distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) }
+      } else {
+        val distinctScoresCount = distinctScoresAndLabelCounts.count()
+
+        var groupCount =
+          if (distinctScoresCount % numBins == 0) {
+            distinctScoresCount / numBins
+          } else {
+            // prevent the last bin from being very small compared to the others
+            distinctScoresCount / numBins + 1
+          }
+
+        if (groupCount < 2) {
+          logInfo(s"Too few distinct scores ($distinctScoresCount) for $numBins bins to be useful")
+          distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) }
+        } else {
+          if (groupCount >= Int.MaxValue) {
+            val n = distinctScoresCount
+            logWarning(
+              s"Too many distinct scores ($n) for $numBins bins; capping at ${Int.MaxValue}")
+            groupCount = Int.MaxValue
+          }
+          distinctScoresAndLabelCounts.mapPartitions(_.grouped(groupCount.toInt).map { pairs =>
+            val firstScore = pairs.head._1
+            val lastScore = pairs.last._1
+            val agg = new BinaryLabelCounter()
+            pairs.foreach(pair => agg += pair._2)
+            ((firstScore, lastScore), agg)
+          })
+        }
+      }
+
+    binnedDistinctScoresAndLabelCounts.map { pair =>
+      val n = pair._2.numPositives + pair._2.numNegatives
+      (pair._1, (pair._2.numPositives / n.toDouble, n))
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 99d52fabc5309..f4a3c1283a27c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -28,6 +28,11 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
   private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
     (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
 
+  private def pairPairsWithinEpsilon(
+      x: (((Double, Double), (Double, Long)), ((Double, Double), (Double, Long)))): Boolean =
+    (x._1._1._1 ~= x._2._1._1 absTol 1E-5) && (x._1._1._2 ~= x._2._1._2 absTol 1E-5) &&
+      (x._1._2._1 ~= x._2._2._1 absTol 1E-5) && x._1._2._2 == x._2._2._2
+
   private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
     assert(left.zip(right).forall(areWithinEpsilon))
   }
@@ -37,6 +42,11 @@
     assert(left.zip(right).forall(pairsWithinEpsilon))
   }
 
+  private def assertTupleTupleSequencesMatch(left: Seq[((Double, Double), (Double, Long))],
+      right: Seq[((Double, Double), (Double, Long))]): Unit = {
+    assert(left.zip(right).forall(pairPairsWithinEpsilon))
+  }
+
   private def validateMetrics(metrics: BinaryClassificationMetrics,
       expectedThresholds: Seq[Double],
       expectedROCCurve: Seq[(Double, Double)],
@@ -44,7 +54,8 @@
       expectedFMeasures1: Seq[Double],
       expectedFmeasures2: Seq[Double],
       expectedPrecisions: Seq[Double],
-      expectedRecalls: Seq[Double]) = {
+      expectedRecalls: Seq[Double],
+      expectedCalibration: Seq[((Double, Double), (Double, Long))]) = {
 
     assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
     assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
@@ -59,6 +70,7 @@
       expectedThresholds.zip(expectedPrecisions))
     assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
       expectedThresholds.zip(expectedRecalls))
+    assertTupleTupleSequencesMatch(metrics.calibration().collect(), expectedCalibration)
   }
 
   test("binary evaluation metrics") {
@@ -80,8 +92,11 @@
     val prCurve = Seq((0.0, 1.0)) ++ pr
     val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
     val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
+    val calibration = Seq(((0.1, 0.1), (0.5, 2L)), ((0.4, 0.4), (0.0, 1L)),
+      ((0.6, 0.6), (2/3.0, 3L)), ((0.8, 0.8), (1.0, 1L)))
 
-    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
+      calibration)
   }
 
   test("binary evaluation metrics for RDD where all examples have positive label") {
@@ -97,8 +112,10 @@
     val prCurve = Seq((0.0, 1.0)) ++ pr
     val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
     val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
+    val calibration = Seq(((0.5, 0.5), (1.0, 2L)))
 
-    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
+      calibration)
   }
 
   test("binary evaluation metrics for RDD where all examples have negative label") {
@@ -121,7 +138,10 @@
       case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
     }
 
-    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
+    val calibration = Seq(((0.5, 0.5), (0.0, 2L)))
+
+    validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
+      calibration)
   }
 
   test("binary evaluation metrics with downsampling") {
@@ -157,6 +177,10 @@
       (1.0, 1.0),
       (1.0, 1.0)
     ) == downsampledROC)
+
+    val calibration = Seq(((0.1, 0.3), (1/3.0, 3L)), ((0.4, 0.6), (1/3.0, 3L)),
+      ((0.7, 0.9), (2/3.0, 3L)))
+    assertTupleTupleSequencesMatch(downsampled.calibration().collect(), calibration)
   }
 }
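Usage note (not part of the patch): a minimal sketch of how the new metric could be exercised, assuming a live SparkContext `sc`; the scores, labels, and `numBins` value below are illustrative, not taken from the test suite.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// Five distinct scores in a single partition; with numBins = 2 the patch
// computes groupCount = 5 / 2 + 1 = 3, so the curve has one bin of 3 distinct
// scores and one bin of the remaining 2.
val scoreAndLabels = sc.parallelize(
  Seq((0.1, 0.0), (0.2, 0.0), (0.6, 1.0), (0.8, 1.0), (0.9, 1.0)), 1)

val metrics = new BinaryClassificationMetrics(scoreAndLabels, 2)

// Each element is ((min score in bin, max score in bin),
// (fraction of positives in bin, number of examples in bin)):
// here ((0.1, 0.6), (1/3, 3)) and ((0.8, 0.9), (1.0, 2)).
metrics.calibration().collect().foreach { case ((lo, hi), (fracPos, n)) =>
  println(f"scores [$lo%.1f, $hi%.1f]: $fracPos%.2f positive over $n examples")
}

Note the design choice this illustrates: bins are formed from equal counts of consecutive distinct scores rather than fixed-width score intervals, so every point on the reliability curve is backed by roughly the same number of examples. For a well calibrated classifier, the fraction of positives in each bin should track the score range that bin spans.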