Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
_context?.CancelExecution();
}))
{
return Task.Run(() => Run(settings));
return Task.FromResult(Run(settings));
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
_context?.CancelExecution();
}))
{
return Task.Run(() => Run(settings));
return Task.FromResult(Run(settings));
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData,
int numCrossValFolds = 10;
_experiment.SetDataset(trainData, numCrossValFolds);
_pipeline = CreateRegressionPipeline(trainData, columnInformation, preFeaturizer);

_experiment.SetPipeline(_pipeline);
TrialResultMonitor<RegressionMetrics> monitor = null;
_experiment.SetMonitor((provider) =>
{
Expand Down
8 changes: 0 additions & 8 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,7 @@ private void InitializeServiceCollection()
_serviceCollection.TryAddTransient((provider) =>
{
var contextManager = provider.GetRequiredService<IMLContextManager>();
var trainingStopManager = provider.GetRequiredService<AggregateTrainingStopManager>();
var context = contextManager.CreateMLContext();
trainingStopManager.OnStopTraining += (s, e) =>
{
// only force-canceling running trials when there's completed trials.
// otherwise, wait for the current running trial to be completed.
if (_bestTrialResult != null)
context.CancelExecution();
};

return context;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
_mLContext?.CancelExecution();
}))
{
return Task.Run(() => Run(settings));
return Task.FromResult(Run(settings));
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
Expand Down
41 changes: 32 additions & 9 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void AutoFit_UCI_Adult_Test()
var trainData = textLoader.Load(dataPath);
var settings = new BinaryExperimentSettings
{
MaxExperimentTimeInSeconds = 1,
MaxModels = 1,
};

settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
Expand All @@ -75,7 +75,7 @@ public void AutoFit_UCI_Adult_Train_Test_Split_Test()
var dataTrainTest = context.Data.TrainTestSplit(trainData);
var settings = new BinaryExperimentSettings
{
MaxExperimentTimeInSeconds = 1,
MaxModels = 1,
};

settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
Expand All @@ -101,7 +101,7 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test()
var trainData = textLoader.Load(dataPath);
var settings = new BinaryExperimentSettings
{
MaxExperimentTimeInSeconds = 1,
MaxModels = 1,
};

settings.Trainers.Remove(BinaryClassificationTrainer.LightGbm);
Expand Down Expand Up @@ -197,13 +197,23 @@ public void AutoFit_Taxi_Fare_Test()
settings.Trainers.Remove(RegressionTrainer.StochasticDualCoordinateAscent);
settings.Trainers.Remove(RegressionTrainer.LbfgsPoissonRegression);

// verify for dataset > 15000L
var result = context.Auto()
.CreateRegressionExperiment(settings)
.Execute(dataset, label);

Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
Assert.NotNull(result.BestRun.Estimator);
Assert.NotNull(result.BestRun.TrainerName);

// verify for dataset < 15000L
result = context.Auto()
.CreateRegressionExperiment(settings)
.Execute(context.Data.TakeRows(dataset, 1000), label);

Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
Assert.NotNull(result.BestRun.Estimator);
Assert.NotNull(result.BestRun.TrainerName);
}

[Theory]
Expand All @@ -229,7 +239,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
uint numberOfCVFolds = 5;
var settings = new MulticlassExperimentSettings
{
MaxExperimentTimeInSeconds = 1,
MaxModels = 1,
};

settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm);
Expand Down Expand Up @@ -257,7 +267,7 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
trainData = context.Data.TakeRows(trainData, crossValRowCountThreshold - 1);
var settings = new MulticlassExperimentSettings
{
MaxExperimentTimeInSeconds = 1,
MaxModels = 1,
};

settings.Trainers.Remove(MulticlassClassificationTrainer.LightGbm);
Expand Down Expand Up @@ -286,8 +296,13 @@ public void AutoFitMultiClassification_Image_TrainTest()
TrainTestData trainTestData = context.Data.TrainTestSplit(trainData, testFraction: 0.2, seed: 1);
IDataView trainDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TrainSet, originalColumnNames);
IDataView testDataset = SplitUtil.DropAllColumnsExcept(context, trainTestData.TestSet, originalColumnNames);
var settings = new MulticlassExperimentSettings
{
MaxModels = 1,
};

var result = context.Auto()
.CreateMulticlassClassificationExperiment(20)
.CreateMulticlassClassificationExperiment(settings)
.Execute(trainDataset, testDataset, columnInference.ColumnInformation);

result.BestRun.ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.1);
Expand All @@ -305,8 +320,12 @@ public void AutoFitMultiClassification_Image_CV()
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
var trainData = context.Data.ShuffleRows(textLoader.Load(datasetPath), seed: 1);
var originalColumnNames = trainData.Schema.Select(c => c.Name);
var settings = new MulticlassExperimentSettings
{
MaxModels = 1,
};
var result = context.Auto()
.CreateMulticlassClassificationExperiment(100)
.CreateMulticlassClassificationExperiment(settings)
.Execute(trainData, 5, columnInference.ColumnInformation);

result.BestRun.Results.Select(x => x.ValidationMetrics.MicroAccuracy).Max().Should().BeGreaterThan(0.1);
Expand All @@ -330,8 +349,12 @@ public void AutoFitMultiClassification_Image()
var columnInference = context.Auto().InferColumns(datasetPath, "Label");
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
var trainData = textLoader.Load(datasetPath);
var settings = new MulticlassExperimentSettings
{
MaxModels = 1,
};
var result = context.Auto()
.CreateMulticlassClassificationExperiment(100)
.CreateMulticlassClassificationExperiment(settings)
.Execute(trainData, columnInference.ColumnInformation);

Assert.InRange(result.BestRun.ValidationMetrics.MicroAccuracy, 0.1, 0.9);
Expand All @@ -358,7 +381,7 @@ public void AutoFitRankingTest()
// STEP 2: Run AutoML experiment
var settings = new RankingExperimentSettings()
{
MaxExperimentTimeInSeconds = 5,
MaxModels = 5,
OptimizationMetricTruncationLevel = 3
};
var experiment = mlContext.Auto()
Expand Down
9 changes: 8 additions & 1 deletion test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,21 @@ public async Task AutoMLExperiment_return_current_best_trial_when_ct_is_canceled
public async Task AutoMLExperiment_finish_training_when_time_is_up_Async()
{
var context = new MLContext(1);
context.Log += (o, e) =>
{
if (e.Source.StartsWith("AutoMLExperiment"))
{
this.Output.WriteLine(e.RawMessage);
}
};

var experiment = context.Auto().CreateExperiment();
experiment.SetTrainingTimeInSeconds(5)
.SetTrialRunner((serviceProvider) =>
{
var channel = serviceProvider.GetService<IChannel>();
var settings = serviceProvider.GetService<AutoMLExperiment.AutoMLExperimentSettings>();
return new DummyTrialRunner(settings, 1, channel);
return new DummyTrialRunner(settings, 0, channel);
})
.SetTuner<RandomSearchTuner>();

Expand Down