diff --git a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs index 761aac2f4c..70db9035e5 100644 --- a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs @@ -447,7 +447,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) _context?.CancelExecution(); })) { - return Task.Run(() => Run(settings)); + return Task.FromResult(Run(settings)); } } catch (Exception ex) when (ct.IsCancellationRequested) diff --git a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs index b7344e7b0f..df96c28873 100644 --- a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs @@ -445,7 +445,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) _context?.CancelExecution(); })) { - return Task.Run(() => Run(settings)); + return Task.FromResult(Run(settings)); } } catch (Exception ex) when (ct.IsCancellationRequested) diff --git a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs index 3d93a63b0a..e8ac5e405e 100644 --- a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs @@ -159,7 +159,7 @@ public override ExperimentResult Execute(IDataView trainData, int numCrossValFolds = 10; _experiment.SetDataset(trainData, numCrossValFolds); _pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer); - + _experiment.SetPipeline(_pipeline); TrialResultMonitor monitor = null; _experiment.SetMonitor((provider) => { diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs index ab89bfcf13..5836fc2c8b 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs @@ -51,15 +51,7 @@ private void InitializeServiceCollection() _serviceCollection.TryAddTransient((provider) => { var contextManager = provider.GetRequiredService(); - var trainingStopManager = provider.GetRequiredService(); var context = contextManager.CreateMLContext(); - trainingStopManager.OnStopTraining += (s, e) => - { - // only force-canceling running trials when there's completed trials. - // otherwise, wait for the current running trial to be completed. - if (_bestTrialResult != null) - context.CancelExecution(); - }; return context; }); diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs index d42363bde0..ff3298c291 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs @@ -96,7 +96,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) _mLContext?.CancelExecution(); })) { - return Task.Run(() => Run(settings)); + return Task.FromResult(Run(settings)); } } catch (Exception ex) when (ct.IsCancellationRequested) diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index 432895c441..a62b2dfc4b 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -48,7 +48,7 @@ public void AutoFit_UCI_Adult_Test() var trainData = textLoader.Load(dataPath); var settings = new BinaryExperimentSettings { - MaxExperimentTimeInSeconds = 1, + MaxModels = 1, }; settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm); @@ -75,7 +75,7 @@ public void AutoFit_UCI_Adult_Train_Test_Split_Test() var dataTrainTest = context.Data.TrainTestSplit(trainData); var settings = new BinaryExperimentSettings { - MaxExperimentTimeInSeconds = 1, + MaxModels = 1, }; settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm); @@ -101,7 +101,7 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test() var trainData = textLoader.Load(dataPath); var settings = new BinaryExperimentSettings { - MaxExperimentTimeInSeconds = 1, + MaxModels = 1, }; settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm); @@ -197,6 +197,7 @@ public void AutoFit_Taxi_Fare_Test() settings.Trainers.Remove(RegressionTrainer.StochasticDualCoordinateAscent); settings.Trainers.Remove(RegressionTrainer.LbfgsPoissonRegression); + // verify for dataset > 15000L var result = context.Auto() .CreateRegressionExperiment(settings) .Execute(dataset, label); @@ -204,6 +205,15 @@ public void AutoFit_Taxi_Fare_Test() Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70); Assert.NotNull(result.BestRun.Estimator); Assert.NotNull(result.BestRun.TrainerName); + + // verify for dataset < 15000L + result = context.Auto() + .CreateRegressionExperiment(settings) + .Execute(context.Data.TakeRows(dataset, 1000), label); + + Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70); + Assert.NotNull(result.BestRun.Estimator); + Assert.NotNull(result.BestRun.TrainerName); } [Theory] @@ -229,7 +239,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds) uint numberOfCVFolds = 5; var settings = new MulticlassExperimentSettings { - MaxExperimentTimeInSeconds = 1, + MaxModels = 1, }; settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm); @@ -257,7 +267,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds) trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1); var settings = new MulticlassExperimentSettings { - MaxExperimentTimeInSeconds = 1, + MaxModels = 1, }; settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm); @@ -286,8 +296,13 @@ public void AutoFitMultiClassification_Image_TrainTest() TrainTestData trainTestData = context.Data.TrainTestSplit(trainData, testFraction: 0.2, seed: 1); IDataView trainDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TrainSet, originalColumnNames); IDataView testDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TestSet, originalColumnNames); + var settings = new MulticlassExperimentSettings + { + MaxModels = 1, + }; + var result = context.Auto() - .CreateMulticlassClassificationExperiment(20) + .CreateMulticlassClassificationExperiment(settings) .Execute(trainDataset, testDataset, columnInference.ColumnInformation); result.BestRun.ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.1); @@ -305,8 +320,12 @@ public void AutoFitMultiClassification_Image_CV() var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); var trainData = context.Data.ShuffleRows(textLoader.Load(datasetPath), seed: 1); var originalColumnNames = trainData.Schema.Select(c => c.Name); + var settings = new MulticlassExperimentSettings + { + MaxModels = 1, + }; var result = context.Auto() - .CreateMulticlassClassificationExperiment(100) + .CreateMulticlassClassificationExperiment(settings) .Execute(trainData, 5, columnInference.ColumnInformation); result.BestRun.Results.Select(x => x.ValidationMetrics.MicroAccuracy).Max().Should().BeGreaterThan(0.1); @@ -330,8 +349,12 @@ public void AutoFitMultiClassification_Image() var columnInference = context.Auto().InferColumns(datasetPath, "Label"); var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); var trainData = textLoader.Load(datasetPath); + var settings = new MulticlassExperimentSettings + { + MaxModels = 1, + }; var result = context.Auto() - .CreateMulticlassClassificationExperiment(100) + .CreateMulticlassClassificationExperiment(settings) .Execute(trainData, columnInference.ColumnInformation); Assert.InRange(result.BestRun.ValidationMetrics.MicroAccuracy, 0.1, 0.9); @@ -358,7 +381,7 @@ public void AutoFitRankingTest() // STEP 2: Run AutoML experiment var settings = new RankingExperimentSettings() { - MaxExperimentTimeInSeconds = 5, + MaxModels = 5, OptimizationMetricTruncationLevel = 3 }; var experiment = mlContext.Auto() diff --git a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs index 8e073394c6..415528ba9c 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs @@ -151,6 +151,13 @@ public async Task AutoMLExperiment_return_current_best_trial_when_ct_is_canceled public async Task AutoMLExperiment_finish_training_when_time_is_up_Async() { var context = new MLContext(1); + context.Log += (o, e) => + { + if (e.Source.StartsWith("AutoMLExperiment")) + { + this.Output.WriteLine(e.RawMessage); + } + }; var experiment = context.Auto().CreateExperiment(); experiment.SetTrainingTimeInSeconds(5) @@ -158,7 +165,7 @@ public async Task AutoMLExperiment_finish_training_when_time_is_up_Async() { var channel = serviceProvider.GetService(); var settings = serviceProvider.GetService(); - return new DummyTrialRunner(settings, 1, channel); + return new DummyTrialRunner(settings, 0, channel); }) .SetTuner();