Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.StaticPipe;

namespace Microsoft.ML.Samples.Static
Expand Down Expand Up @@ -89,7 +91,7 @@ public static void SdcaBinaryClassification()
// Evaluate how the model is doing on the test data
var dataWithPredictions = model.Transform(testData);

var metrics = mlContext.BinaryClassification.Evaluate(dataWithPredictions, row => row.Label, row => row.Score);
var metrics = mlContext.BinaryClassification.EvaluateWithPRCurve(dataWithPredictions, row => row.Label, row => row.Score, out List<BinaryPrecisionRecallDataPoint> prCurve);

Console.WriteLine($"Accuracy: {metrics.Accuracy}"); // 0.83
Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve}"); // 0.88
Expand All @@ -98,7 +100,15 @@ public static void SdcaBinaryClassification()
Console.WriteLine($"Negative Precision: {metrics.NegativePrecision}"); // 0.87
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall}"); // 0.91
Console.WriteLine($"Positive Precision: {metrics.PositivePrecision}"); // 0.65
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall}"); // 0.55

foreach(var prData in prCurve)
{
Console.Write($"Threshold: {prData.Threshold} ");
Console.Write($"Precision: {prData.Precision} ");
Console.Write($"Recall: {prData.Recall} ");
Console.WriteLine($"FPR: {prData.FalsePositiveRate}");
}
}
}
}
150 changes: 150 additions & 0 deletions src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,93 @@ public CalibratedBinaryClassificationMetrics Evaluate(IDataView data, string lab
return result;
}

/// <summary>
/// Evaluates scored binary classification data and additionally produces the data points
/// of the precision-recall curve.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <param name="probability">The name of the probability column in <paramref name="data"/>, the calibrated version of <paramref name="score"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
public CalibratedBinaryClassificationMetrics EvaluateWithPRCurve(
    IDataView data,
    string label,
    string score,
    string probability,
    string predictedLabel,
    out List<BinaryPrecisionRecallDataPoint> prCurve)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(label, nameof(label));
    Host.CheckNonEmpty(score, nameof(score));
    Host.CheckNonEmpty(probability, nameof(probability));
    Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

    // Bind the caller-supplied column names to the roles the evaluator expects.
    var roles = new RoleMappedData(data, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(label),
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Probability, probability),
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));

    var evaluated = ((IEvaluator)this).Evaluate(roles);
    Host.Assert(evaluated.ContainsKey(MetricKinds.PrCurve));
    Host.Assert(evaluated.ContainsKey(MetricKinds.OverallMetrics));
    var curveView = evaluated[MetricKinds.PrCurve];
    var overallView = evaluated[MetricKinds.OverallMetrics];

    var curvePoints = new List<BinaryPrecisionRecallDataPoint>();
    using (var cursor = curveView.GetRowCursorForAllColumns())
    {
        // Construct the getters once, before the loop: creating the delegates is
        // expensive, while invoking them per row is cheap.
        GetPrecisionRecallDataPointGetters(curveView, cursor,
            out ValueGetter<float> thresholdGetter,
            out ValueGetter<double> precisionGetter,
            out ValueGetter<double> recallGetter,
            out ValueGetter<double> fprGetter);

        while (cursor.MoveNext())
            curvePoints.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
    }
    prCurve = curvePoints;

    // The overall-metrics view is expected to contain exactly one row. Keep the
    // MoveNext calls outside the asserts, since asserts may be compiled out.
    CalibratedBinaryClassificationMetrics metrics;
    using (var cursor = overallView.GetRowCursorForAllColumns())
    {
        var moved = cursor.MoveNext();
        Host.Assert(moved);
        metrics = new CalibratedBinaryClassificationMetrics(Host, cursor);
        moved = cursor.MoveNext();
        Host.Assert(!moved);
    }

    return metrics;
}

/// <summary>
/// Fetches, from <paramref name="cursor"/>, the value getters for the four columns
/// (threshold, precision, recall, false positive rate) of the precision-recall curve
/// data view produced by this evaluator.
/// </summary>
/// <param name="prCurveView">The precision-recall curve data view, used to look up the columns by name.</param>
/// <param name="cursor">An active cursor over <paramref name="prCurveView"/>.</param>
/// <param name="thresholdGetter">The getter for the threshold column.</param>
/// <param name="precisionGetter">The getter for the precision column.</param>
/// <param name="recallGetter">The getter for the recall column.</param>
/// <param name="fprGetter">The getter for the false positive rate column.</param>
private void GetPrecisionRecallDataPointGetters(IDataView prCurveView,
    DataViewRowCursor cursor,
    out ValueGetter<float> thresholdGetter,
    out ValueGetter<double> precisionGetter,
    out ValueGetter<double> recallGetter,
    out ValueGetter<double> fprGetter)
{
    var thresholdColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Threshold);
    var precisionColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Precision);
    var recallColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.Recall);
    var fprColumn = prCurveView.Schema.GetColumnOrNull(BinaryClassifierEvaluator.FalsePositiveRate);
    // DataViewSchema.Column is a struct, so GetColumnOrNull returns a Nullable<T>;
    // use HasValue/Value rather than null comparisons and explicit casts.
    Host.Assert(thresholdColumn.HasValue);
    Host.Assert(precisionColumn.HasValue);
    Host.Assert(recallColumn.HasValue);
    Host.Assert(fprColumn.HasValue);

    thresholdGetter = cursor.GetGetter<float>(thresholdColumn.Value);
    precisionGetter = cursor.GetGetter<double>(precisionColumn.Value);
    recallGetter = cursor.GetGetter<double>(recallColumn.Value);
    fprGetter = cursor.GetGetter<double>(fprColumn.Value);
}

/// <summary>
/// Evaluates scored binary classification data, without probability-based metrics.
/// </summary>
Expand Down Expand Up @@ -864,6 +951,69 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string label, string
}
return result;
}

/// <summary>
/// Evaluates scored binary classification data, without probability-based metrics,
/// and additionally produces the data points of the precision-recall curve.
/// </summary>
/// <param name="data">The scored data.</param>
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
/// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
/// <seealso cref="Evaluate(IDataView, string, string, string)"/>
public BinaryClassificationMetrics EvaluateWithPRCurve(
    IDataView data,
    string label,
    string score,
    string predictedLabel,
    out List<BinaryPrecisionRecallDataPoint> prCurve)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(label, nameof(label));
    Host.CheckNonEmpty(score, nameof(score));
    Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));

    // Bind the caller-supplied column names to the roles the evaluator expects.
    // No probability role here: these are uncalibrated outputs.
    var roles = new RoleMappedData(data, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(label),
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.Score, score),
        RoleMappedSchema.CreatePair(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));

    var evaluated = ((IEvaluator)this).Evaluate(roles);
    Host.Assert(evaluated.ContainsKey(MetricKinds.PrCurve));
    Host.Assert(evaluated.ContainsKey(MetricKinds.OverallMetrics));
    var curveView = evaluated[MetricKinds.PrCurve];
    var overallView = evaluated[MetricKinds.OverallMetrics];

    var curvePoints = new List<BinaryPrecisionRecallDataPoint>();
    using (var cursor = curveView.GetRowCursorForAllColumns())
    {
        // Construct the getters once, before the loop: creating the delegates is
        // expensive, while invoking them per row is cheap.
        GetPrecisionRecallDataPointGetters(curveView, cursor,
            out ValueGetter<float> thresholdGetter,
            out ValueGetter<double> precisionGetter,
            out ValueGetter<double> recallGetter,
            out ValueGetter<double> fprGetter);

        while (cursor.MoveNext())
            curvePoints.Add(new BinaryPrecisionRecallDataPoint(thresholdGetter, precisionGetter, recallGetter, fprGetter));
    }
    prCurve = curvePoints;

    // The overall-metrics view is expected to contain exactly one row. Keep the
    // MoveNext calls outside the asserts, since asserts may be compiled out.
    BinaryClassificationMetrics metrics;
    using (var cursor = overallView.GetRowCursorForAllColumns())
    {
        var moved = cursor.MoveNext();
        Host.Assert(moved);
        metrics = new BinaryClassificationMetrics(Host, cursor);
        moved = cursor.MoveNext();
        Host.Assert(!moved);
    }

    return metrics;
}
}

internal sealed class BinaryPerInstanceEvaluator : PerInstanceEvaluatorBase
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,11 @@ internal static class MetricKinds
/// </summary>
public const string OverallMetrics = "OverallMetrics";

/// <summary>
/// This is a data view with precision recall data in its columns. It has four columns:
/// Threshold, Precision, Recall and Fpr (false positive rate). Each row holds the
/// values for one threshold sample.
/// </summary>
public const string PrCurve = "PrCurve";

/// <summary>
/// This data view contains a single text column, with warnings about bad input values encountered by the evaluator during
/// the aggregation of metrics. Each warning is in a separate row.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Data
{
/// <summary>
/// Represents a single data point on the precision-recall curve of a binary classifier.
/// </summary>
public sealed class BinaryPrecisionRecallDataPoint
{
    /// <summary>
    /// Gets the threshold for this data point.
    /// </summary>
    public double Threshold { get; }

    /// <summary>
    /// Gets the precision for the current threshold.
    /// </summary>
    public double Precision { get; }

    /// <summary>
    /// Gets the recall for the current threshold.
    /// </summary>
    public double Recall { get; }

    /// <summary>
    /// Gets the true positive rate for the current threshold. The true positive rate is
    /// synonymous with recall; it is exposed under this name as well so that ROC curves
    /// (TPR vs. FPR) can be plotted directly from these data points.
    /// </summary>
    public double TruePositiveRate => Recall;

    /// <summary>
    /// Gets the false positive rate for the given threshold.
    /// </summary>
    public double FalsePositiveRate { get; }

    internal BinaryPrecisionRecallDataPoint(ValueGetter<float> thresholdGetter, ValueGetter<double> precisionGetter, ValueGetter<double> recallGetter, ValueGetter<double> fprGetter)
    {
        // Pull the values for the cursor's current row out through the pre-built getters.
        float thresholdValue = default;
        double precisionValue = default;
        double recallValue = default;
        double fprValue = default;

        thresholdGetter(ref thresholdValue);
        precisionGetter(ref precisionValue);
        recallGetter(ref recallValue);
        fprGetter(ref fprValue);

        Threshold = thresholdValue;
        Precision = precisionValue;
        Recall = recallValue;
        FalsePositiveRate = fprValue;
    }
}

}
76 changes: 76 additions & 0 deletions src/Microsoft.ML.StaticPipe/EvaluatorStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
Expand Down Expand Up @@ -50,6 +51,44 @@ public static CalibratedBinaryClassificationMetrics Evaluate<T>(
return eval.Evaluate(data.AsDynamic, labelName, scoreName, probName, predName);
}

/// <summary>
/// Evaluates scored binary classification data and additionally produces the data points
/// of the precision-recall curve.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="catalog">The binary classification catalog.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
/// <returns>The evaluation results for these calibrated outputs.</returns>
public static CalibratedBinaryClassificationMetrics EvaluateWithPRCurve<T>(
    this BinaryClassificationCatalog catalog,
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel)> pred,
    out List<BinaryPrecisionRecallDataPoint> prCurve)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Resolve the statically-typed column delegates to their underlying column names.
    var indexer = StaticPipeUtils.GetIndexer(data);
    string labelName = indexer.Get(label(indexer.Indices));
    var (scoreCol, probCol, predCol) = pred(indexer.Indices);
    env.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    env.CheckParam(probCol != null, nameof(pred), "Indexing delegate resulted in null probability column.");
    env.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = indexer.Get(scoreCol);
    string probName = indexer.Get(probCol);
    string predName = indexer.Get(predCol);

    // Cap the number of examples used for ROC / P-R curve generation at 100000.
    var evaluator = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
    return evaluator.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, probName, predName, out prCurve);
}

/// <summary>
/// Evaluates scored binary classification data, if the predictions are not calibrated.
/// </summary>
Expand Down Expand Up @@ -84,6 +123,43 @@ public static BinaryClassificationMetrics Evaluate<T>(
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
}

/// <summary>
/// Evaluates scored binary classification data, if the predictions are not calibrated,
/// and additionally produces the data points of the precision-recall curve.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <param name="catalog">The binary classification catalog.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate for the label column.</param>
/// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier.
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
/// <param name="prCurve">The generated precision recall curve data. Up to 100000 of samples are used for p/r curve generation.</param>
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
public static BinaryClassificationMetrics EvaluateWithPRCurve<T>(
    this BinaryClassificationCatalog catalog,
    DataView<T> data,
    Func<T, Scalar<bool>> label,
    Func<T, (Scalar<float> score, Scalar<bool> predictedLabel)> pred,
    out List<BinaryPrecisionRecallDataPoint> prCurve)
{
    Contracts.CheckValue(data, nameof(data));
    var env = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(env);
    env.CheckValue(label, nameof(label));
    env.CheckValue(pred, nameof(pred));

    // Resolve the statically-typed column delegates to their underlying column names.
    var indexer = StaticPipeUtils.GetIndexer(data);
    string labelName = indexer.Get(label(indexer.Indices));
    (var scoreCol, var predCol) = pred(indexer.Indices);
    // Route these checks through the environment, consistent with the calibrated
    // overload above (the original used Contracts.CheckParam here).
    env.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
    env.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
    string scoreName = indexer.Get(scoreCol);
    string predName = indexer.Get(predCol);

    // Cap the number of examples used for ROC / P-R curve generation at 100000.
    var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { NumRocExamples = 100000 });
    return eval.EvaluateWithPRCurve(data.AsDynamic, labelName, scoreName, predName, out prCurve);
}

/// <summary>
/// Evaluates scored clustering prediction data.
/// </summary>
Expand Down
Loading