-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Add API to get Precision-Recall Curve data #3039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
34f9b69
247d93a
19bf06e
4c86ddb
2daff7a
c64a745
a690881
c2d2314
b27e594
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -828,6 +828,93 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab | |
| return result; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data and generates precision recall curve data. | ||
| /// </summary> | ||
| /// <param name="data">The scored data.</param> | ||
| /// <param name="label">The name of the label column in <paramref name="data"/>.</param> | ||
| /// <param name="score">The name of the score column in <paramref name="data"/>.</param> | ||
| /// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param> | ||
| /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
| /// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for p/r curve generation.</param> | ||
| /// <returns>The evaluation results for these calibrated outputs.</returns> | ||
| public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve( | ||
| IDataView data, | ||
| string label, | ||
| string score, | ||
| string probability, | ||
| string predictedLabel, | ||
| out List<BinaryPrecisionRecallDataPoint> prCurve) | ||
| { | ||
| Host.CheckValue(data, nameof(data)); | ||
| Host.CheckNonEmpty(label, nameof(label)); | ||
| Host.CheckNonEmpty(score, nameof(score)); | ||
| Host.CheckNonEmpty(probability, nameof(probability)); | ||
| Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel)); | ||
|
|
||
| var roles = new RoleMappedData(data, opt: false, | ||
| RoleMappedSchema.ColumnRole.Label.Bind(label), | ||
| RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score), | ||
| RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probability), | ||
| RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); | ||
|
|
||
| var resultDict = ((IEvaluator)this).Evaluate(roles); | ||
| Host.Assert(resultDict.ContainsKey(MetricKinds.PrCurve)); | ||
| var prCurveView = resultDict[MetricKinds.PrCurve]; | ||
| Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); | ||
| var overall = resultDict[MetricKinds.OverallMetrics]; | ||
|
|
||
| var prCurveResult = new List<BinaryPrecisionRecallDataPoint>(); | ||
| using (var cursor = prCurveView.GetRowCursorForAllColumns()) | ||
| { | ||
| GetPrecisionRecallDataPointGetters(prCurveView, cursor, | ||
| out ValueGetter<float> thresholdGetter, | ||
| out ValueGetter<double> precisionGetter, | ||
| out ValueGetter<double> recallGetter, | ||
| out ValueGetter<double> fprGetter); | ||
|
|
||
| while (cursor.MoveNext()) | ||
| { | ||
| prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter)); | ||
| } | ||
| } | ||
| prCurve = prCurveResult; | ||
|
|
||
| CalibratedBinaryClassificationMetrics result; | ||
| using (var cursor = overall.GetRowCursorForAllColumns()) | ||
| { | ||
| var moved = cursor.MoveNext(); | ||
| Host.Assert(moved); | ||
| result = new CalibratedBinaryClassificationMetrics(Host, cursor); | ||
| moved = cursor.MoveNext(); | ||
| Host.Assert(!moved); | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
|
|
||
| private void GetPrecisionRecallDataPointGetters(IDataView prCurveView, | ||
| DataViewRowCursor cursor, | ||
| out ValueGetter<float> thresholdGetter, | ||
| out ValueGetter<double> precisionGetter, | ||
| out ValueGetter<double> recallGetter, | ||
| out ValueGetter<double> fprGetter) | ||
| { | ||
| var thresholdColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Threshold); | ||
| var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision); | ||
| var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall); | ||
| var fprColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.FalsePositiveRate); | ||
| Host.Assert(thresholdColumn != null); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In future, consider
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tried it; it didn't work here because DataViewSchema.Column is not a reference type, it's a struct. In reply to: 268380273 [](ancestors = 268380273) |
||
| Host.Assert(precisionColumn != null); | ||
| Host.Assert(recallColumn != null); | ||
| Host.Assert(fprColumn != null); | ||
|
|
||
| thresholdGetter = cursor.GetGetter<float>((DataViewSchema.Column)thresholdColumn); | ||
| precisionGetter = cursor.GetGetter<double>((DataViewSchema.Column)precisionColumn); | ||
| recallGetter = cursor.GetGetter<double>((DataViewSchema.Column)recallColumn); | ||
| fprGetter = cursor.GetGetter<double>((DataViewSchema.Column)fprColumn); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data, without probability-based metrics. | ||
| /// </summary> | ||
|
|
@@ -864,6 +951,69 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string | |
| } | ||
| return result; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data, without probability-based metrics | ||
| /// and generates precision recall curve data. | ||
| /// </summary> | ||
| /// <param name="data">The scored data.</param> | ||
| /// <param name="label">The name of the label column in <paramref name="data"/>.</param> | ||
| /// <param name="score">The name of the score column in <paramref name="data"/>.</param> | ||
| /// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
| /// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for p/r curve generation.</param> | ||
| /// <returns>The evaluation results for these uncalibrated outputs.</returns> | ||
| /// <seealso cref="Evaluate(IDataView, string, string, string)"/> | ||
| public BinaryClassificationMetrics EvaluateWithPRCurve( | ||
| IDataView data, | ||
| string label, | ||
| string score, | ||
| string predictedLabel, | ||
| out List<BinaryPrecisionRecallDataPoint> prCurve) | ||
| { | ||
| Host.CheckValue(data, nameof(data)); | ||
| Host.CheckNonEmpty(label, nameof(label)); | ||
| Host.CheckNonEmpty(score, nameof(score)); | ||
| Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel)); | ||
|
|
||
| var roles = new RoleMappedData(data, opt: false, | ||
| RoleMappedSchema.ColumnRole.Label.Bind(label), | ||
| RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score), | ||
| RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel)); | ||
|
|
||
| var resultDict = ((IEvaluator)this).Evaluate(roles); | ||
| Host.Assert(resultDict.ContainsKey(MetricKinds.PrCurve)); | ||
| var prCurveView = resultDict[MetricKinds.PrCurve]; | ||
| Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics)); | ||
| var overall = resultDict[MetricKinds.OverallMetrics]; | ||
|
|
||
| var prCurveResult = new List<BinaryPrecisionRecallDataPoint>(); | ||
| using (var cursor = prCurveView.GetRowCursorForAllColumns()) | ||
| { | ||
| GetPrecisionRecallDataPointGetters(prCurveView, cursor, | ||
| out ValueGetter<float> thresholdGetter, | ||
| out ValueGetter<double> precisionGetter, | ||
| out ValueGetter<double> recallGetter, | ||
| out ValueGetter<double> fprGetter); | ||
|
|
||
| while (cursor.MoveNext()) | ||
| { | ||
| prCurveResult.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter)); | ||
| } | ||
| } | ||
| prCurve = prCurveResult; | ||
|
|
||
| BinaryClassificationMetrics result; | ||
| using (var cursor = overall.GetRowCursorForAllColumns()) | ||
| { | ||
| var moved = cursor.MoveNext(); | ||
| Host.Assert(moved); | ||
| result = new BinaryClassificationMetrics(Host, cursor); | ||
| moved = cursor.MoveNext(); | ||
| Host.Assert(!moved); | ||
| } | ||
|
|
||
| return result; | ||
| } | ||
| } | ||
|
|
||
| internal sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| namespace Microsoft.ML.Data | ||
| { | ||
| /// <summary> | ||
| /// This class represents one data point on Precision-Recall curve for binary classification. | ||
| /// </summary> | ||
| public sealed class BinaryPrecisionRecallDataPoint | ||
| { | ||
| /// <summary> | ||
| /// Gets the threshold for this data point. | ||
| /// </summary> | ||
| public double Threshold { get; } | ||
| /// <summary> | ||
| /// Gets the precision for the current threshold. | ||
| /// </summary> | ||
| public double Precision { get; } | ||
| /// <summary> | ||
| /// Gets the recall for the current threshold. | ||
| /// </summary> | ||
| public double Recall { get; } | ||
|
|
||
| /// <summary> | ||
| /// Gets the true positive rate for the current threshold. | ||
| /// </summary> | ||
| public double TruePositiveRate => Recall; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
So, what is the point of this? I see a class, one property exists merely to point to another, I think, what is up? #ByDesign
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is to make life easier for users to plot the ROC curve. The ROC curve is plotted as TPR vs. FPR. We do have FPR, but we spent some time with Rogan looking for TPR in order to get ROC data. It appears Recall is synonymous with TPR. So, to help users avoid the same confusion and to reduce the number of potential inquiries in the future, we decided to have the TPR field here. In reply to: 268370748 [](ancestors = 268370748) |
||
|
|
||
| /// <summary> | ||
| /// Gets the false positive rate for the given threshold. | ||
| /// </summary> | ||
| public double FalsePositiveRate { get; } | ||
|
|
||
| internal BinaryPrecisionRecallDataPoint(ValueGetter<float> thresholdGetter, ValueGetter<double> precisionGetter, ValueGetter<double> recallGetter, ValueGetter<double> fprGetter) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I would recommend that all logic persuant to the getters remain in the same block of code as the cursor (anything else is just a pointless increase of complexity), but this at least avoids the main problem so, I guess we can fix it later. #ByDesign
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| { | ||
| float threshold = default; | ||
| double precision = default; | ||
| double recall = default; | ||
| double fpr = default; | ||
|
|
||
| thresholdGetter(ref threshold); | ||
| precisionGetter(ref precision); | ||
| recallGetter(ref recall); | ||
| fprGetter(ref fpr); | ||
|
|
||
| Threshold = threshold; | ||
| Precision = precision; | ||
| Recall = recall; | ||
| FalsePositiveRate = fpr; | ||
| } | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using Microsoft.ML.Data; | ||
| using Microsoft.ML.Runtime; | ||
| using Microsoft.ML.Trainers; | ||
|
|
@@ -50,6 +51,44 @@ public static CalibratedBinaryClassificationMetrics Evaluate<T>( | |
| return eval.Evaluate(data.AsDynamic, labelName, scoreName, probName, predName); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data and generates precision recall curve data. | ||
| /// </summary> | ||
| /// <typeparam name="T">The shape type for the input data.</typeparam> | ||
| /// <param name="catalog">The binary classification catalog.</param> | ||
| /// <param name="data">The data to evaluate.</param> | ||
| /// <param name="label">The index delegate for the label column.</param> | ||
| /// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier. | ||
| /// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param> | ||
| /// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for p/r curve generation.</param> | ||
| /// <returns>The evaluation results for these calibrated outputs.</returns> | ||
| public static CalibratedBinaryClassificationMetrics EvaluateWithPRCurve<T>( | ||
| this BinaryClassificationCatalog catalog, | ||
| DataView<T> data, | ||
| Func<T, Scalar<bool>> label, | ||
| Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
is this always a bool? don't we sometimes predict ints or keys?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a twin method for public static CalibratedBinaryClassificationMetrics Evaluate(..) on line 28. It does everything the same way, plus it emits the PR curve. The scope of this PR is to add emission of the PR curve; changing parameter types or names would mean refactoring the original Evaluate(..) method, which is out of scope for this PR. In reply to: 268282834 [](ancestors = 268282834) |
||
| out List<BinaryPrecisionRecallDataPoint> prCurve) | ||
| { | ||
| Contracts.CheckValue(data, nameof(data)); | ||
| var env = StaticPipeUtils.GetEnvironment(data); | ||
| Contracts.AssertValue(env); | ||
| env.CheckValue(label, nameof(label)); | ||
| env.CheckValue(pred, nameof(pred)); | ||
|
|
||
| var indexer = StaticPipeUtils.GetIndexer(data); | ||
| string labelName = indexer.Get(label(indexer.Indices)); | ||
| (var scoreCol, var probCol, var predCol) = pred(indexer.Indices); | ||
| env.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); | ||
| env.CheckParam(probCol != null, nameof(pred), "Indexing delegate resulted in null probability column."); | ||
| env.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column."); | ||
| string scoreName = indexer.Get(scoreCol); | ||
| string probName = indexer.Get(probCol); | ||
| string predName = indexer.Get(predCol); | ||
|
|
||
| var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 }); | ||
| return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, probName, predName, out prCurve); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Super similar to another new method. Refactor? #Closed
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This would mean refactoring both Evaluate(..) methods on lines 28 & 102, as these new methods are based on them. I don't want to make bigger changes at this point and just want to follow the established pattern here to minimize risk. In reply to: 268285448 [](ancestors = 268285448) |
||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data, if the predictions are not calibrated. | ||
| /// </summary> | ||
|
|
@@ -84,6 +123,43 @@ public static BinaryClassificationMetrics Evaluate<T>( | |
| return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored binary classification data, if the predictions are not calibrated | ||
| /// and generates precision recall curve data. | ||
| /// </summary> | ||
| /// <typeparam name="T">The shape type for the input data.</typeparam> | ||
| /// <param name="catalog">The binary classification catalog.</param> | ||
| /// <param name="data">The data to evaluate.</param> | ||
| /// <param name="label">The index delegate for the label column.</param> | ||
| /// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier. | ||
| /// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param> | ||
| /// <param name="prCurve">The generated precision recall curve data. Up to 100,000 samples are used for p/r curve generation.</param> | ||
| /// <returns>The evaluation results for these uncalibrated outputs.</returns> | ||
| public static BinaryClassificationMetrics EvaluateWithPRCurve<T>( | ||
| this BinaryClassificationCatalog catalog, | ||
| DataView<T> data, | ||
| Func<T, Scalar<bool>> label, | ||
| Func<T, (Scalar<float> score, Scalar<bool> predictedLabel)> pred, | ||
| out List<BinaryPrecisionRecallDataPoint> prCurve) | ||
| { | ||
| Contracts.CheckValue(data, nameof(data)); | ||
| var env = StaticPipeUtils.GetEnvironment(data); | ||
| Contracts.AssertValue(env); | ||
| env.CheckValue(label, nameof(label)); | ||
| env.CheckValue(pred, nameof(pred)); | ||
|
|
||
| var indexer = StaticPipeUtils.GetIndexer(data); | ||
| string labelName = indexer.Get(label(indexer.Indices)); | ||
| (var scoreCol, var predCol) = pred(indexer.Indices); | ||
| Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column."); | ||
| Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column."); | ||
| string scoreName = indexer.Get(scoreCol); | ||
| string predName = indexer.Get(predCol); | ||
|
|
||
| var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 }); | ||
| return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, predName, out prCurve); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Evaluates scored clustering prediction data. | ||
| /// </summary> | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to put in the PR what we had talked about, it is important that we get the getters before the tight loop, and we use the getters inside the tight loop to retrieve the values, since constructing the delegates is extremely expensive but using them is very cheap. (See here for more architectural info on why this is so.) #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thx
In reply to: 268291075 [](ancestors = 268291075)