-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-6332] [MLlib] compute calibration curve for binary classifier #5025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
23d12d1
1df8619
0769ee6
bf682c0
967e961
4281f55
6cf1e2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,14 +29,14 @@ import org.apache.spark.sql.DataFrame | |
| * Evaluator for binary classification. | ||
| * | ||
| * @param scoreAndLabels an RDD of (score, label) pairs. | ||
| * @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally | ||
| * will be down-sampled to this many "bins". If 0, no down-sampling will occur. | ||
| * This is useful because the curve contains a point for each distinct score | ||
| * in the input, and this could be as large as the input itself -- millions of | ||
| * points or more, when thousands may be entirely sufficient to summarize | ||
| * the curve. After down-sampling, the curves will instead be made of approximately | ||
| * `numBins` points instead. Points are made from bins of equal numbers of | ||
| * consecutive points. The size of each bin is | ||
| * @param numBins if greater than 0, then the curves (ROC curve, PR curve, calibration curve) | ||
| * computed internally will be down-sampled to this many "bins". If 0, no | ||
| * down-sampling will occur. This is useful because the curve contains a point for | ||
| * each distinct score in the input, and this could be as large as the input itself | ||
| * -- millions of points or more, when thousands may be entirely sufficient to | ||
| * summarize the curve. After down-sampling, the curves will instead be made of | ||
| * approximately `numBins` points instead. Points are made from bins of equal | ||
| * numbers of consecutive points. The size of each bin is | ||
| * `floor(scoreAndLabels.count() / numBins)`, which means the resulting number | ||
| * of bins may not exactly equal numBins. The last bin in each partition may | ||
| * be smaller as a result, meaning there may be an extra sample at | ||
|
|
@@ -226,4 +226,76 @@ class BinaryClassificationMetrics @Since("1.3.0") ( | |
| (x(c), y(c)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Returns the calibration or reliability curve, | ||
| * which is an RDD of (average score in bin, fraction of positive examples in bin). | ||
| * @see http://en.wikipedia.org/wiki/Calibration_%28statistics%29#In_classification | ||
| * | ||
| * References: | ||
| * | ||
| * Mahdi Pakdaman Naeini, Gregory F. Cooper, Milos Hauskrecht. | ||
| * Binary Classifier Calibration: Non-parametric approach. | ||
| * http://arxiv.org/abs/1401.3390 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you link the doi if it's available? |
||
| * | ||
| * Alexandru Niculescu-Mizil, Rich Caruana. | ||
| * Predicting Good Probabilities With Supervised Learning. | ||
| * Appearing in Proceedings of the 22nd International Conference on Machine Learning, | ||
| * Bonn, Germany, 2005. | ||
| * http://www.cs.cornell.edu/~alexn/papers/calibration.icml05.crc.rev3.pdf | ||
| * | ||
| * Properties and benefits of calibrated classifiers. | ||
| * Ira Cohen, Moises Goldszmidt. | ||
| * http://www.hpl.hp.com/techreports/2004/HPL-2004-22R1.pdf | ||
| */ | ||
| def calibration(): RDD[((Double, Double), (Double, Long))] = { | ||
| assessedCalibration | ||
| } | ||
|
|
||
| private lazy val assessedCalibration: RDD[((Double, Double), (Double, Long))] = { | ||
| val distinctScoresAndLabelCounts = scoreAndLabels.combineByKey( | ||
| createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label, | ||
| mergeValue = (c: BinaryLabelCounter, label: Double) => c += label, | ||
| mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 | ||
| ).sortByKey(ascending = true) | ||
|
|
||
| val binnedDistinctScoresAndLabelCounts = | ||
| if (numBins == 0) { | ||
| distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: It's a bit hard to keep track what each of the coordinates of the |
||
| } else { | ||
| val distinctScoresCount = distinctScoresAndLabelCounts.count() | ||
|
|
||
| var groupCount = | ||
| if (distinctScoresCount % numBins == 0) { | ||
| distinctScoresCount / numBins | ||
| } else { | ||
| // prevent the last bin from being very small compared to the others | ||
| distinctScoresCount / numBins + 1 | ||
| } | ||
|
|
||
| if (groupCount < 2) { | ||
| logInfo(s"Too few distinct scores ($distinctScoresCount) for $numBins bins to be useful") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Log that you are proceeding with |
||
| distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) } | ||
| } else { | ||
| if (groupCount >= Int.MaxValue) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The overflow will wrap around to be negative so this check needs to be changed (the check on L254 should probably check
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unrelated to this PR but this should probably also be fixed on L147 |
||
| val n = distinctScoresCount | ||
| logWarning( | ||
| s"Too many distinct scores ($n) for $numBins bins; capping at ${Int.MaxValue}") | ||
| groupCount = Int.MaxValue | ||
| } | ||
| distinctScoresAndLabelCounts.mapPartitions(_.grouped(groupCount.toInt).map { pairs => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto about pattern matching |
||
| val firstScore = pairs.head._1 | ||
| val lastScore = pairs.last._1 | ||
| val agg = new BinaryLabelCounter() | ||
| pairs.foreach(pair => agg += pair._2) | ||
| ((firstScore, lastScore), agg) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| binnedDistinctScoresAndLabelCounts.map { pair => | ||
| val n = pair._2.numPositives + pair._2.numNegatives | ||
| (pair._1, (pair._2.numPositives / n.toDouble, n)) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This description of the RDD contents doesn't match with the
((Double, Double), (Double, Long))type signature