Skip to content

Commit 5318cc2

Browse files
Fixed AutoML CrossValSummaryRunner for TopKAccuracyForAllK (#5548)
* Fixed bug
1 parent 26066f7 commit 5318cc2

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,19 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
159159

160160
private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<IEnumerable<double>> results)
161161
{
162+
if (results.All(result => result == null))
163+
{
164+
// If all nested enumerables are null, we say the average is a null enumerable as well.
165+
// This is expected to happen on Multiclass metrics where the TopKAccuracyForAllK
166+
// array can be null if the topKPredictionCount isn't a valid number.
167+
// In that case all of the "results" enumerables will be null anyway, and so
168+
// returning null is the expected solution.
169+
return null;
170+
}
171+
172+
// In case there are only some null elements, we'll ignore them:
173+
results = results.Where(result => result != null);
174+
162175
double[] arr = new double[results.ElementAt(0).Count()];
163176
for (int i = 0; i < arr.Length; i++)
164177
{

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,48 @@ public void AutoFitBinaryTest()
4242
Assert.NotNull(result.BestRun.TrainerName);
4343
}
4444

45-
[Fact]
46-
public void AutoFitMultiTest()
45+
[Theory]
46+
[InlineData(true)]
47+
[InlineData(false)]
48+
public void AutoFitMultiTest(bool useNumberOfCVFolds)
4749
{
4850
var context = new MLContext(0);
4951
var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel);
5052
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
5153
var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath);
52-
var result = context.Auto()
53-
.CreateMulticlassClassificationExperiment(0)
54-
.Execute(trainData, 5, DatasetUtil.TrivialMulticlassDatasetLabel);
55-
Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7);
56-
var scoredData = result.BestRun.Results.First().Model.Transform(trainData);
57-
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);
54+
55+
if (useNumberOfCVFolds)
56+
{
57+
// When setting numberOfCVFolds
58+
// The results object is a CrossValidationExperimentResults<> object
59+
uint numberOfCVFolds = 5;
60+
var result = context.Auto()
61+
.CreateMulticlassClassificationExperiment(0)
62+
.Execute(trainData, numberOfCVFolds, DatasetUtil.TrivialMulticlassDatasetLabel);
63+
64+
Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7);
65+
var scoredData = result.BestRun.Results.First().Model.Transform(trainData);
66+
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);
67+
}
68+
else
69+
{
70+
// When using this other API, if the trainset is under the
71+
// crossValRowCountThreshold, AutoML will also perform CrossValidation
72+
// but through a very different path than the one above,
73+
// through a CrossValSummaryRunner, and will return
74+
// a different type of object as "result" which would now be
75+
// simply an ExperimentResult<> object
76+
77+
int crossValRowCountThreshold = 15000;
78+
trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1);
79+
var result = context.Auto()
80+
.CreateMulticlassClassificationExperiment(0)
81+
.Execute(trainData, DatasetUtil.TrivialMulticlassDatasetLabel);
82+
83+
Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7);
84+
var scoredData = result.BestRun.Model.Transform(trainData);
85+
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);
86+
}
5887
}
5988

6089
[TensorFlowFact]

0 commit comments

Comments
 (0)