diff --git a/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs b/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs index 0079be3ade..9c382468a6 100644 --- a/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs +++ b/src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs @@ -159,6 +159,19 @@ private static TMetrics GetAverageMetrics(IEnumerable metrics, TMetric private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable> results) { + if (results.All(result => result == null)) + { + // If all nested enumerables are null, we say the average is a null enumerable as well. + // This is expected to happen on Multiclass metrics where the TopKAccuracyForAllK + // array can be null if the topKPredictionCount isn't a valid number. + // In that case all of the "results" enumerables will be null anyway, and so + // returning null is the expected solution. + return null; + } + + // In case there are only some null elements, we'll ignore them: + results = results.Where(result => result != null); + double[] arr = new double[results.ElementAt(0).Count()]; for (int i = 0; i < arr.Length; i++) { diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index 14771bb329..86670d5630 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -42,19 +42,48 @@ public void AutoFitBinaryTest() Assert.NotNull(result.BestRun.TrainerName); } - [Fact] - public void AutoFitMultiTest() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void AutoFitMultiTest(bool useNumberOfCVFolds) { var context = new MLContext(0); var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel); var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath); - var result = context.Auto() - .CreateMulticlassClassificationExperiment(0) - .Execute(trainData, 5, DatasetUtil.TrivialMulticlassDatasetLabel); - Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7); - var scoredData = result.BestRun.Results.First().Model.Transform(trainData); - Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + + if (useNumberOfCVFolds) + { + // When setting numberOfCVFolds + // The results object is a CrossValidationExperimentResults<> object + uint numberOfCVFolds = 5; + var result = context.Auto() + .CreateMulticlassClassificationExperiment(0) + .Execute(trainData, numberOfCVFolds, DatasetUtil.TrivialMulticlassDatasetLabel); + + Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7); + var scoredData = result.BestRun.Results.First().Model.Transform(trainData); + Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + } + else + { + // When using this other API, if the trainset is under the + // crossValRowCounThreshold, AutoML will also perform CrossValidation + // but through a very different path that the one above, + // throw a CrossValSummaryRunner and will return + // a different type of object as "result" which would now be + // simply a ExperimentResult<> object + + int crossValRowCountThreshold = 15000; + trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1); + var result = context.Auto() + .CreateMulticlassClassificationExperiment(0) + .Execute(trainData, DatasetUtil.TrivialMulticlassDatasetLabel); + + Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7); + var scoredData = result.BestRun.Model.Transform(trainData); + Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + } } [TensorFlowFact]