@@ -177,6 +177,8 @@ abstract class ProbabilisticClassificationModel[
* Predict the probability of each class given the features.
* These predictions are also called class conditional probabilities.
*
* See BinaryClassificationMetrics.calibration to assess calibration.
*
* This internal method is used to implement [[transform()]] and output [[probabilityCol]].
*
* @return Estimated class conditional probabilities
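The new cross-reference sends spark.ml users from predicted class probabilities to the calibration assessment added below in spark.mllib's BinaryClassificationMetrics. As an illustrative sketch only (this glue code is not part of the patch, and trainingDf, testDf, the choice of LogisticRegression, and the Spark 2.x-style ml.linalg.Vector import are all assumptions), the bridge could look like:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.Row

// Fit any probabilistic classifier; logistic regression is only an example here.
val model = new LogisticRegression().fit(trainingDf)

// probabilityCol holds the class conditional probabilities; keep the positive-class score.
val scoreAndLabels = model.transform(testDf)
  .select("probability", "label")
  .rdd
  .map { case Row(prob: Vector, label: Double) => (prob(1), label) }

// calibration() is the method added by this patch.
val metrics = new BinaryClassificationMetrics(scoreAndLabels, 100)
val calibrationCurve = metrics.calibration().collect()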
@@ -29,14 +29,14 @@ import org.apache.spark.sql.DataFrame
* Evaluator for binary classification.
*
* @param scoreAndLabels an RDD of (score, label) pairs.
* @param numBins if greater than 0, then the curves (ROC curve, PR curve) computed internally
* will be down-sampled to this many "bins". If 0, no down-sampling will occur.
* This is useful because the curve contains a point for each distinct score
* in the input, and this could be as large as the input itself -- millions of
* points or more, when thousands may be entirely sufficient to summarize
* the curve. After down-sampling, the curves will instead be made of approximately
* `numBins` points instead. Points are made from bins of equal numbers of
* consecutive points. The size of each bin is
* @param numBins if greater than 0, then the curves (ROC curve, PR curve, calibration curve)
* computed internally will be down-sampled to this many "bins". If 0, no
* down-sampling will occur. This is useful because the curve contains a point for
* each distinct score in the input, and this could be as large as the input itself
* -- millions of points or more, when thousands may be entirely sufficient to
* summarize the curve. After down-sampling, the curves will instead consist of
* approximately `numBins` points. Points are made from bins of equal numbers of
* consecutive points. The size of each bin is
* `floor(scoreAndLabels.count() / numBins)`, which means the resulting number
* of bins may not exactly equal numBins. The last bin in each partition may
* be smaller as a result, meaning there may be an extra sample at
@@ -226,4 +226,76 @@ class BinaryClassificationMetrics @Since("1.3.0") (
(x(c), y(c))
}
}

/**
* Returns the calibration or reliability curve, which is an RDD of
* ((lowest score in bin, highest score in bin),
* (fraction of positive examples in bin, number of examples in bin)) pairs.
Review comment (Contributor): This description of the RDD contents doesn't match with the ((Double, Double), (Double, Long)) type signature.

* @see http://en.wikipedia.org/wiki/Calibration_%28statistics%29#In_classification
*
* References:
*
* Mahdi Pakdaman Naeini, Gregory F. Cooper, Milos Hauskrecht.
* Binary Classifier Calibration: Non-parametric approach.
* http://arxiv.org/abs/1401.3390
Review comment (Contributor): Can you link the doi if it's available?

*
* Alexandru Niculescu-Mizil, Rich Caruana.
* Predicting Good Probabilities With Supervised Learning.
* Appearing in Proceedings of the 22nd International Conference on Machine Learning,
* Bonn, Germany, 2005.
* http://www.cs.cornell.edu/~alexn/papers/calibration.icml05.crc.rev3.pdf
*
* Ira Cohen, Moises Goldszmidt.
* Properties and benefits of calibrated classifiers.
* http://www.hpl.hp.com/techreports/2004/HPL-2004-22R1.pdf
*/
def calibration(): RDD[((Double, Double), (Double, Long))] = {
assessedCalibration
}

private lazy val assessedCalibration: RDD[((Double, Double), (Double, Long))] = {
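// Count positive and negative labels for each distinct score, then sort the
// distinct scores in ascending order.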
val distinctScoresAndLabelCounts = scoreAndLabels.combineByKey(
createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
).sortByKey(ascending = true)

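// Group consecutive distinct scores into bins. With numBins == 0, each distinct score
// keeps its own bin, represented as ((score, score), label counts).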
val binnedDistinctScoresAndLabelCounts =
if (numBins == 0) {
distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) }
Review comment (Contributor): nit: It's a bit hard to keep track of what each of the coordinates of the pair tuple means throughout your code; can you do a pattern match, i.e. map { case (score, labelCounter) => ...

} else {
val distinctScoresCount = distinctScoresAndLabelCounts.count()

var groupCount =
if (distinctScoresCount % numBins == 0) {
distinctScoresCount / numBins
} else {
// prevent the last bin from being very small compared to the others
distinctScoresCount / numBins + 1
}

if (groupCount < 2) {
logInfo(s"Too few distinct scores ($distinctScoresCount) for $numBins bins to be useful")
Review comment (Contributor): Log that you are proceeding with $numBins = $distinctScoresCount and document this behavior (as well as the numBins == 0 and overflow case) in calibration()'s scala doc.

distinctScoresAndLabelCounts.map { pair => ((pair._1, pair._1), pair._2) }
} else {
if (groupCount >= Int.MaxValue) {
Review comment (Contributor): The overflow will wrap around to be negative so this check needs to be changed (the check on L254 should probably check 0 < groupCount < 2 to avoid catching integer overflow case).

Review comment (Contributor): Unrelated to this PR but this should probably also be fixed on L147.

val n = distinctScoresCount
logWarning(
s"Too many distinct scores ($n) for $numBins bins; capping at ${Int.MaxValue}")
groupCount = Int.MaxValue
}
distinctScoresAndLabelCounts.mapPartitions(_.grouped(groupCount.toInt).map { pairs =>
Review comment (Contributor): Ditto about pattern matching pairs.

val firstScore = pairs.head._1
val lastScore = pairs.last._1
val agg = new BinaryLabelCounter()
pairs.foreach(pair => agg += pair._2)
((firstScore, lastScore), agg)
})
}
}

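// For each bin, emit ((lowest score in bin, highest score in bin),
// (fraction of positive examples, total number of examples)).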
binnedDistinctScoresAndLabelCounts.map { pair =>
val n = pair._2.numPositives + pair._2.numNegatives
(pair._1, (pair._2.numPositives / n.toDouble, n))
}
}
}
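To make the returned shape concrete, here is a minimal usage sketch. The (score, label) pairs are chosen only so that the result lines up with the expected calibration values in the test suite below; sc is assumed to be an existing SparkContext.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// Illustrative (score, label) pairs: two examples at score 0.1 (one positive),
// one at 0.4 (negative), three at 0.6 (two positive), one at 0.8 (positive).
val scoreAndLabels = sc.parallelize(Seq(
  (0.1, 0.0), (0.1, 1.0), (0.4, 0.0),
  (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2)

// numBins = 0: one calibration bin per distinct score.
val metrics = new BinaryClassificationMetrics(scoreAndLabels, 0)

// Prints ((0.1,0.1),(0.5,2)), ((0.4,0.4),(0.0,1)), ((0.6,0.6),(0.666...,3)), ((0.8,0.8),(1.0,1))
metrics.calibration().collect().foreach(println)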
@@ -28,6 +28,10 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)

private def pairPairsWithinEpsilon(x: (((Double, Double), (Double, Long)), ((Double, Double), (Double, Long)))): Boolean =
(x._1._1._1 ~= x._2._1._1 absTol 1E-5) && (x._1._1._2 ~= x._2._1._2 absTol 1E-5) &&
(x._1._2._1 ~= x._2._2._1 absTol 1E-5) && x._1._2._2 == x._2._2._2

private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
assert(left.zip(right).forall(areWithinEpsilon))
}
@@ -37,14 +41,20 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
assert(left.zip(right).forall(pairsWithinEpsilon))
}

private def assertTupleTupleSequencesMatch(left: Seq[((Double, Double), (Double, Long))],
right: Seq[((Double, Double), (Double, Long))]): Unit = {
assert(left.zip(right).forall(pairPairsWithinEpsilon))
}

private def validateMetrics(metrics: BinaryClassificationMetrics,
expectedThresholds: Seq[Double],
expectedROCCurve: Seq[(Double, Double)],
expectedPRCurve: Seq[(Double, Double)],
expectedFMeasures1: Seq[Double],
expectedFmeasures2: Seq[Double],
expectedPrecisions: Seq[Double],
expectedRecalls: Seq[Double]) = {
expectedRecalls: Seq[Double],
expectedCalibration: Seq[((Double, Double), (Double, Long))]) = {

assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
@@ -59,6 +69,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
expectedThresholds.zip(expectedPrecisions))
assertTupleSequencesMatch(metrics.recallByThreshold().collect(),
expectedThresholds.zip(expectedRecalls))
assertTupleTupleSequencesMatch(metrics.calibration().collect(), expectedCalibration)
}

test("binary evaluation metrics") {
@@ -80,8 +91,11 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
val calibration = Seq(((0.1, 0.1), (0.5, 2L)), ((0.4, 0.4), (0.0, 1L)), ((0.6, 0.6), (2/3.0, 3L)),
((0.8, 0.8), (1.0, 1L)))

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
calibration)
}

test("binary evaluation metrics for RDD where all examples have positive label") {
@@ -97,8 +111,10 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
val prCurve = Seq((0.0, 1.0)) ++ pr
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
val calibration = Seq(((0.5, 0.5), (1.0, 2L)))

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
calibration)
}

test("binary evaluation metrics for RDD where all examples have negative label") {
@@ -121,7 +137,10 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
case (r, p) => 5.0 * (p * r) / (4.0 * p + r)
}

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls)
val calibration = Seq(((0.5, 0.5), (0.0, 2L)))

validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls,
calibration)
}

test("binary evaluation metrics with downsampling") {
@@ -157,6 +176,9 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
(1.0, 1.0), (1.0, 1.0)
) ==
downsampledROC)

val calibration = Array(((0.1, 0.3), (1/3.0, 3L)), ((0.4, 0.6), (1/3.0, 3L)), ((0.7, 0.9), (2/3.0, 3L)))
assertTupleTupleSequencesMatch(calibration, downsampled.calibration().collect())
}
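The three expected bins above each span three consecutive distinct scores (0.1 to 0.3, 0.4 to 0.6, 0.7 to 0.9), which follows from the bin-size rule in assessedCalibration. A sketch of that arithmetic, assuming the elided fixture has 9 distinct scores and, hypothetically, numBins = 4:

val distinctScoresCount = 9L  // consistent with the expected bin ranges 0.1 through 0.9
val numBins = 4               // assumed value; the real one is in the elided test setup

// 9 % 4 != 0, so 1 is added to keep the last bin from being much smaller than the others.
val groupCount =
  if (distinctScoresCount % numBins == 0) distinctScoresCount / numBins
  else distinctScoresCount / numBins + 1
// groupCount == 3, giving the bins (0.1, 0.3), (0.4, 0.6) and (0.7, 0.9)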

}