Skip to content

Commit e38647c

Browse files
yaeldMSantoniovs1029
authored andcommitted
Fix PFI issue in binary classification (#4587)
This change adds support for running PFI on binary classification models that do not contain a calibrator. Fixes #4517 .
1 parent 5bba7ed commit e38647c

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ public static ImmutableArray<TResult>
171171
int processedCnt = 0;
172172
int nextFeatureIndex = 0;
173173
var shuffleRand = RandomUtils.Create(host.Rand.Next());
174-
using (var pch = host.StartProgressChannel("SDCA preprocessing with lookup"))
174+
using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
175175
{
176176
pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));
177177
foreach (var workingIndx in workingFeatureIndices)

src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,16 @@ public static ImmutableArray<BinaryClassificationMetricsStatistics>
145145
int permutationCount = 1) where TModel : class
146146
{
147147
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
148-
catalog.GetEnvironment(),
149-
predictionTransformer,
150-
data,
151-
() => new BinaryClassificationMetricsStatistics(),
152-
idv => catalog.Evaluate(idv, labelColumnName),
153-
BinaryClassifierDelta,
154-
predictionTransformer.FeatureColumnName,
155-
permutationCount,
156-
useFeatureWeightFilter,
157-
numberOfExamplesToUse);
148+
catalog.GetEnvironment(),
149+
predictionTransformer,
150+
data,
151+
() => new BinaryClassificationMetricsStatistics(),
152+
idv => catalog.EvaluateNonCalibrated(idv, labelColumnName),
153+
BinaryClassifierDelta,
154+
predictionTransformer.FeatureColumnName,
155+
permutationCount,
156+
useFeatureWeightFilter,
157+
numberOfExamplesToUse);
158158
}
159159

160160
private static BinaryClassificationMetrics BinaryClassifierDelta(

test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,36 @@ public void TestPfiBinaryClassificationOnSparseFeatures(bool saveModel)
305305

306306
Done();
307307
}
308+
309+
[Fact]
310+
public void TestBinaryClassificationWithoutCalibrator()
311+
{
312+
var dataPath = GetDataPath("breast-cancer.txt");
313+
var ff = ML.BinaryClassification.Trainers.FastForest();
314+
var data = ML.Data.LoadFromTextFile(dataPath,
315+
new[] { new TextLoader.Column("Label", DataKind.Boolean, 0),
316+
new TextLoader.Column("Features", DataKind.Single, 1, 9) });
317+
var model = ff.Fit(data);
318+
var pfi = ML.BinaryClassification.PermutationFeatureImportance(model, data);
319+
320+
// For the following metrics higher is better, so minimum delta means more important feature, and vice versa
321+
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean));
322+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderRocCurve.Mean));
323+
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.Accuracy.Mean));
324+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.Accuracy.Mean));
325+
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositivePrecision.Mean));
326+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositivePrecision.Mean));
327+
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.PositiveRecall.Mean));
328+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean));
329+
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean));
330+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean));
331+
Assert.Equal(2, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean));
332+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativeRecall.Mean));
333+
Assert.Equal(3, MaxDeltaIndex(pfi, m => m.F1Score.Mean));
334+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.F1Score.Mean));
335+
Assert.Equal(7, MaxDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean));
336+
Assert.Equal(1, MinDeltaIndex(pfi, m => m.AreaUnderPrecisionRecallCurve.Mean));
337+
}
308338
#endregion
309339

310340
#region Multiclass Classification Tests

0 commit comments

Comments
 (0)