Skip to content

Commit 122c319

Browse files
authored
Add API to get Precision-Recall Curve data (#3039)
* Add GetPrecisionRecallCurve API * fix typo * Change to EvaluateWithPRCurve(...) * Added to mlContext * fix comments * better name for PrecisionRecallDataPointGetters(..) * plus test * Add TPR property * fix order of using
1 parent 6da1493 commit 122c319

File tree

6 files changed

+336
-2
lines changed

6 files changed

+336
-2
lines changed

docs/samples/Microsoft.ML.Samples/Static/SDCABinaryClassification.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML.Data;
24
using Microsoft.ML.StaticPipe;
35

46
namespace Microsoft.ML.Samples.Static
@@ -89,7 +91,7 @@ public static void SdcaBinaryClassification()
8991
// Evaluate how the model is doing on the test data
9092
var dataWithPredictions = model.Transform(testData);
9193

92-
var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, row => row.Label, row => row.Score);
94+
var metrics = mlContext.BinaryClassification.EvaluateWithPRCurve(dataWithPredictions, row => row.Label, row => row.Score, out List<BinaryPrecisionRecallDataPoint> prCurve);
9395

9496
Console.WriteLine($"Accuracy: {metrics.Accuracy}"); // 0.83
9597
Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve}"); // 0.88
@@ -98,7 +100,15 @@ public static void SdcaBinaryClassification()
98100
Console.WriteLine($"Negative Precision: {metrics.NegativePrecision}"); // 0.87
99101
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall}"); // 0.91
100102
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision}"); // 0.65
101-
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
103+
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
104+
105+
foreach(var prData in prCurve)
106+
{
107+
Console.Write($"Threshold: {prData.Threshold} ");
108+
Console.Write($"Precision: {prData.Precision} ");
109+
Console.Write($"Recall: {prData.Recall} ");
110+
Console.WriteLine($"FPR: {prData.FalsePositiveRate}");
111+
}
102112
}
103113
}
104114
}

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,93 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
828828
return result;
829829
}
830830

831+
/// <summary>
832+
/// Evaluates scored binary classification data and generates precision recall curve data.
833+
/// </summary>
834+
/// <param name="data">The scored data.</param>
835+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
836+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
837+
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
838+
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
839+
/// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for P/R curve generation.</param>
840+
/// <returns>The evaluation results for these calibrated outputs.</returns>
841+
public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
842+
IDataView data,
843+
string label,
844+
string score,
845+
string probability,
846+
string predictedLabel,
847+
out List<BinaryPrecisionRecallDataPoint> prCurve)
848+
{
849+
Host.CheckValue(data, nameof(data));
850+
Host.CheckNonEmpty(label, nameof(label));
851+
Host.CheckNonEmpty(score, nameof(score));
852+
Host.CheckNonEmpty(probability, nameof(probability));
853+
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
854+
855+
var roles = new RoleMappedData(data, opt: false,
856+
RoleMappedSchema.ColumnRole.Label.Bind(label),
857+
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
858+
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probability),
859+
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));
860+
861+
var resultDict = ((IEvaluator)this).Evaluate(roles);
862+
Host.Assert(resultDict.ContainsKey(MetricKinds.PrCurve));
863+
var prCurveView = resultDict[MetricKinds.PrCurve];
864+
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
865+
var overall = resultDict[MetricKinds.OverallMetrics];
866+
867+
var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
868+
using (var cursor = prCurveView.GetRowCursorForAllColumns())
869+
{
870+
GetPrecisionRecallDataPointGetters(prCurveView, cursor,
871+
out ValueGetter<float> thresholdGetter,
872+
out ValueGetter<double> precisionGetter,
873+
out ValueGetter<double> recallGetter,
874+
out ValueGetter<double> fprGetter);
875+
876+
while (cursor.MoveNext())
877+
{
878+
prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
879+
}
880+
}
881+
prCurve = prCurveResult;
882+
883+
CalibratedBinaryClassificationMetrics result;
884+
using (var cursor = overall.GetRowCursorForAllColumns())
885+
{
886+
var moved = cursor.MoveNext();
887+
Host.Assert(moved);
888+
result = new CalibratedBinaryClassificationMetrics(Host, cursor);
889+
moved = cursor.MoveNext();
890+
Host.Assert(!moved);
891+
}
892+
893+
return result;
894+
}
895+
896+
private void GetPrecisionRecallDataPointGetters(IDataView prCurveView,
897+
DataViewRowCursor cursor,
898+
out ValueGetter<float> thresholdGetter,
899+
out ValueGetter<double> precisionGetter,
900+
out ValueGetter<double> recallGetter,
901+
out ValueGetter<double> fprGetter)
902+
{
903+
var thresholdColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Threshold);
904+
var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision);
905+
var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall);
906+
var fprColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.FalsePositiveRate);
907+
Host.Assert(thresholdColumn != null);
908+
Host.Assert(precisionColumn != null);
909+
Host.Assert(recallColumn != null);
910+
Host.Assert(fprColumn != null);
911+
912+
thresholdGetter = cursor.GetGetter<float>((DataViewSchema.Column)thresholdColumn);
913+
precisionGetter = cursor.GetGetter<double>((DataViewSchema.Column)precisionColumn);
914+
recallGetter = cursor.GetGetter<double>((DataViewSchema.Column)recallColumn);
915+
fprGetter = cursor.GetGetter<double>((DataViewSchema.Column)fprColumn);
916+
}
917+
831918
/// <summary>
832919
/// Evaluates scored binary classification data, without probability-based metrics.
833920
/// </summary>
@@ -864,6 +951,69 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
864951
}
865952
return result;
866953
}
954+
955+
/// <summary>
956+
/// Evaluates scored binary classification data, without probability-based metrics
957+
/// and generates precision recall curve data.
958+
/// </summary>
959+
/// <param name="data">The scored data.</param>
960+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
961+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
962+
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
963+
/// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for P/R curve generation.</param>
964+
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
965+
/// <seealso cref="Evaluate(IDataView, string, string, string)"/>
966+
public BinaryClassificationMetrics EvaluateWithPRCurve(
967+
IDataView data,
968+
string label,
969+
string score,
970+
string predictedLabel,
971+
out List<BinaryPrecisionRecallDataPoint> prCurve)
972+
{
973+
Host.CheckValue(data, nameof(data));
974+
Host.CheckNonEmpty(label, nameof(label));
975+
Host.CheckNonEmpty(score, nameof(score));
976+
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
977+
978+
var roles = new RoleMappedData(data, opt: false,
979+
RoleMappedSchema.ColumnRole.Label.Bind(label),
980+
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
981+
RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));
982+
983+
var resultDict = ((IEvaluator)this).Evaluate(roles);
984+
Host.Assert(resultDict.ContainsKey(MetricKinds.PrCurve));
985+
var prCurveView = resultDict[MetricKinds.PrCurve];
986+
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
987+
var overall = resultDict[MetricKinds.OverallMetrics];
988+
989+
var prCurveResult = new List<BinaryPrecisionRecallDataPoint>();
990+
using (var cursor = prCurveView.GetRowCursorForAllColumns())
991+
{
992+
GetPrecisionRecallDataPointGetters(prCurveView, cursor,
993+
out ValueGetter<float> thresholdGetter,
994+
out ValueGetter<double> precisionGetter,
995+
out ValueGetter<double> recallGetter,
996+
out ValueGetter<double> fprGetter);
997+
998+
while (cursor.MoveNext())
999+
{
1000+
prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
1001+
}
1002+
}
1003+
prCurve = prCurveResult;
1004+
1005+
BinaryClassificationMetrics result;
1006+
using (var cursor = overall.GetRowCursorForAllColumns())
1007+
{
1008+
var moved = cursor.MoveNext();
1009+
Host.Assert(moved);
1010+
result = new BinaryClassificationMetrics(Host, cursor);
1011+
moved = cursor.MoveNext();
1012+
Host.Assert(!moved);
1013+
}
1014+
1015+
return result;
1016+
}
8671017
}
8681018

8691019
internal sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,11 @@ internal static class MetricKinds
17761776
/// </summary>
17771777
public const string OverallMetrics = "OverallMetrics";
17781778

1779+
/// <summary>
1780+
/// This is a data view with precision recall data in its columns. It has four columns: Threshold, Precision, Recall and Fpr.
1781+
/// </summary>
1782+
public const string PrCurve = "PrCurve";
1783+
17791784
/// <summary>
17801785
/// This data view contains a single text column, with warnings about bad input values encountered by the evaluator during
17811786
/// the aggregation of metrics. Each warning is in a separate row.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Data
6+
{
7+
/// <summary>
8+
/// This class represents one data point on Precision-Recall curve for binary classification.
9+
/// </summary>
10+
public sealed class BinaryPrecisionRecallDataPoint
11+
{
12+
/// <summary>
13+
/// Gets the threshold for this data point.
14+
/// </summary>
15+
public double Threshold { get; }
16+
/// <summary>
17+
/// Gets the precision for the current threshold.
18+
/// </summary>
19+
public double Precision { get; }
20+
/// <summary>
21+
/// Gets the recall for the current threshold.
22+
/// </summary>
23+
public double Recall { get; }
24+
25+
/// <summary>
26+
/// Gets the true positive rate for the current threshold.
27+
/// </summary>
28+
public double TruePositiveRate => Recall;
29+
30+
/// <summary>
31+
/// Gets the false positive rate for the current threshold.
32+
/// </summary>
33+
public double FalsePositiveRate { get; }
34+
35+
internal BinaryPrecisionRecallDataPoint(ValueGetter<float> thresholdGetter, ValueGetter<double> precisionGetter, ValueGetter<double> recallGetter, ValueGetter<double> fprGetter)
36+
{
37+
float threshold = default;
38+
double precision = default;
39+
double recall = default;
40+
double fpr = default;
41+
42+
thresholdGetter(ref threshold);
43+
precisionGetter(ref precision);
44+
recallGetter(ref recall);
45+
fprGetter(ref fpr);
46+
47+
Threshold = threshold;
48+
Precision = precision;
49+
Recall = recall;
50+
FalsePositiveRate = fpr;
51+
}
52+
}
53+
54+
}

src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using Microsoft.ML.Data;
78
using Microsoft.ML.Runtime;
89
using Microsoft.ML.Trainers;
@@ -50,6 +51,44 @@ public static CalibratedBinaryClassificationMetrics Evaluate<T>(
5051
return eval.Evaluate(data.AsDynamic, labelName, scoreName, probName, predName);
5152
}
5253

54+
/// <summary>
55+
/// Evaluates scored binary classification data and generates precision recall curve data.
56+
/// </summary>
57+
/// <typeparam name="T">The shape type for the input data.</typeparam>
58+
/// <param name="catalog">The binary classification catalog.</param>
59+
/// <param name="data">The data to evaluate.</param>
60+
/// <param name="label">The index delegate for the label column.</param>
61+
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
62+
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
63+
/// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for P/R curve generation.</param>
64+
/// <returns>The evaluation results for these calibrated outputs.</returns>
65+
public static CalibratedBinaryClassificationMetrics EvaluateWithPRCurve<T>(
66+
this BinaryClassificationCatalog catalog,
67+
DataView<T> data,
68+
Func<T, Scalar<bool>> label,
69+
Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
70+
out List<BinaryPrecisionRecallDataPoint> prCurve)
71+
{
72+
Contracts.CheckValue(data, nameof(data));
73+
var env = StaticPipeUtils.GetEnvironment(data);
74+
Contracts.AssertValue(env);
75+
env.CheckValue(label, nameof(label));
76+
env.CheckValue(pred, nameof(pred));
77+
78+
var indexer = StaticPipeUtils.GetIndexer(data);
79+
string labelName = indexer.Get(label(indexer.Indices));
80+
(var scoreCol, var probCol, var predCol) = pred(indexer.Indices);
81+
env.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
82+
env.CheckParam(probCol != null, nameof(pred), "Indexing delegate resulted in null probability column.");
83+
env.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
84+
string scoreName = indexer.Get(scoreCol);
85+
string probName = indexer.Get(probCol);
86+
string predName = indexer.Get(predCol);
87+
88+
var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
89+
return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, probName, predName, out prCurve);
90+
}
91+
5392
/// <summary>
5493
/// Evaluates scored binary classification data, if the predictions are not calibrated.
5594
/// </summary>
@@ -84,6 +123,43 @@ public static BinaryClassificationMetrics Evaluate<T>(
84123
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
85124
}
86125

126+
/// <summary>
127+
/// Evaluates scored binary classification data, if the predictions are not calibrated
128+
/// and generates precision recall curve data.
129+
/// </summary>
130+
/// <typeparam name="T">The shape type for the input data.</typeparam>
131+
/// <param name="catalog">The binary classification catalog.</param>
132+
/// <param name="data">The data to evaluate.</param>
133+
/// <param name="label">The index delegate for the label column.</param>
134+
/// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier.
135+
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
136+
/// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for P/R curve generation.</param>
137+
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
138+
public static BinaryClassificationMetrics EvaluateWithPRCurve<T>(
139+
this BinaryClassificationCatalog catalog,
140+
DataView<T> data,
141+
Func<T, Scalar<bool>> label,
142+
Func<T, (Scalar<float> score, Scalar<bool> predictedLabel)> pred,
143+
out List<BinaryPrecisionRecallDataPoint> prCurve)
144+
{
145+
Contracts.CheckValue(data, nameof(data));
146+
var env = StaticPipeUtils.GetEnvironment(data);
147+
Contracts.AssertValue(env);
148+
env.CheckValue(label, nameof(label));
149+
env.CheckValue(pred, nameof(pred));
150+
151+
var indexer = StaticPipeUtils.GetIndexer(data);
152+
string labelName = indexer.Get(label(indexer.Indices));
153+
(var scoreCol, var predCol) = pred(indexer.Indices);
154+
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
155+
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
156+
string scoreName = indexer.Get(scoreCol);
157+
string predName = indexer.Get(predCol);
158+
159+
var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
160+
return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, predName, out prCurve);
161+
}
162+
87163
/// <summary>
88164
/// Evaluates scored clustering prediction data.
89165
/// </summary>

0 commit comments

Comments
 (0)