-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fixed AutoML CrossValSummaryRunner for TopKAccuracyForAllK #5548
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,19 +42,48 @@ public void AutoFitBinaryTest() | |
| Assert.NotNull(result.BestRun.TrainerName); | ||
| } | ||
|
|
||
| [Fact] | ||
| public void AutoFitMultiTest() | ||
| [Theory] | ||
| [InlineData(true)] | ||
| [InlineData(false)] | ||
| public void AutoFitMultiTest(bool useNumberOfCVFolds) | ||
| { | ||
| var context = new MLContext(0); | ||
| var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel); | ||
| var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); | ||
| var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath); | ||
| var result = context.Auto() | ||
| .CreateMulticlassClassificationExperiment(0) | ||
| .Execute(trainData, 5, DatasetUtil.TrivialMulticlassDatasetLabel); | ||
| Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7); | ||
| var scoredData = result.BestRun.Results.First().Model.Transform(trainData); | ||
| Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); | ||
|
|
||
| if (useNumberOfCVFolds) | ||
| { | ||
| // When setting numberOfCVFolds | ||
| // The results object is a CrossValidationExperimentResults<> object | ||
| uint numberOfCVFolds = 5; | ||
| var result = context.Auto() | ||
| .CreateMulticlassClassificationExperiment(0) | ||
| .Execute(trainData, numberOfCVFolds, DatasetUtil.TrivialMulticlassDatasetLabel); | ||
|
|
||
| Assert.True(result.BestRun.Results.First().ValidationMetrics.MicroAccuracy >= 0.7); | ||
| var scoredData = result.BestRun.Results.First().Model.Transform(trainData); | ||
|
Comment on lines
+64
to
+65
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So here it's enough to check the accuracy of only the first
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before my modifications, the test simply checked the accuracy of the first result, so I've pretty much left that test untouched. I've just added another test inside this one, to reuse code. |
||
| Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); | ||
| } | ||
| else | ||
| { | ||
| // When using this other API, if the trainset is under the | ||
| // crossValRowCounThreshold, AutoML will also perform CrossValidation | ||
| // but through a very different path that the one above, | ||
| // throw a CrossValSummaryRunner and will return | ||
| // a different type of object as "result" which would now be | ||
| // simply a ExperimentResult<> object | ||
|
|
||
| int crossValRowCountThreshold = 15000; | ||
| trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1); | ||
| var result = context.Auto() | ||
| .CreateMulticlassClassificationExperiment(0) | ||
| .Execute(trainData, DatasetUtil.TrivialMulticlassDatasetLabel); | ||
|
|
||
| Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7); | ||
| var scoredData = result.BestRun.Model.Transform(trainData); | ||
| Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); | ||
| } | ||
| } | ||
|
|
||
| [TensorFlowFact] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd recommend taking the average of the non-null elements. In the TopKAccuracyForAllK case, since all are expected to be null, we would check for all values being null, and return null.
That would be a modification of
GetAverageOfNonNaNScores()below.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, by only modifying
GetAverageOfNonNaNScoresInNestedEnumerable