diff --git a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs index 70db9035e5..ea7e12c769 100644 --- a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs @@ -341,12 +341,14 @@ internal class BinaryClassificationRunner : ITrialRunner { private MLContext _context; private readonly IDatasetManager _datasetManager; + private readonly IMLContextManager _contextManager; private readonly IMetricManager _metricManager; private readonly SweepablePipeline _pipeline; private readonly Random _rnd; - public BinaryClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) + public BinaryClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) { - _context = context; + _context = contextManager.CreateMLContext(); + _contextManager = contextManager; _datasetManager = datasetManager; _metricManager = metricManager; _pipeline = pipeline; @@ -365,6 +367,10 @@ public TrialResult Run(TrialSettings settings) { var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName]; var pipeline = _pipeline.BuildFromOption(_context, parameter); + // _context will be cancelled after training, so the returned pipeline needs to be created on a + // new MLContext.
+ var refitContext = _contextManager.CreateMLContext(); + var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter); if (_datasetManager is ICrossValidateDatasetManager datasetManager) { var stopWatch = new Stopwatch(); @@ -396,7 +402,7 @@ public TrialResult Run(TrialSettings settings) DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = res.Metrics, CrossValidationMetrics = metrics, - Pipeline = pipeline, + Pipeline = refitPipeline, }; } @@ -430,7 +436,7 @@ public TrialResult Run(TrialSettings settings) TrialSettings = settings, DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = metrics, - Pipeline = pipeline, + Pipeline = refitPipeline, }; } } diff --git a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs index df96c28873..dfddaaa78e 100644 --- a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs @@ -343,12 +343,14 @@ internal class MulticlassClassificationRunner : ITrialRunner private MLContext _context; private readonly IDatasetManager _datasetManager; private readonly IMetricManager _metricManager; + private readonly IMLContextManager _contextManager; private readonly SweepablePipeline _pipeline; private readonly Random _rnd; - public MulticlassClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) + public MulticlassClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) { - _context = context; + _context = contextManager.CreateMLContext(); + _contextManager = contextManager; _datasetManager = datasetManager; _metricManager = metricManager; _pipeline = pipeline; @@ -361,6 +363,8 @@ public 
TrialResult Run(TrialSettings settings) { var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName]; var pipeline = _pipeline.BuildFromOption(_context, parameter); + var refitContext = _contextManager.CreateMLContext(); + var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter); if (_datasetManager is ICrossValidateDatasetManager datasetManager) { var stopWatch = new Stopwatch(); @@ -394,7 +398,7 @@ public TrialResult Run(TrialSettings settings) DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = res.Metrics, CrossValidationMetrics = metrics, - Pipeline = pipeline, + Pipeline = refitPipeline, }; } @@ -428,7 +432,7 @@ public TrialResult Run(TrialSettings settings) TrialSettings = settings, DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = metrics, - Pipeline = pipeline, + Pipeline = refitPipeline, }; } } diff --git a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs index e8ac5e405e..b044ad664c 100644 --- a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs @@ -363,12 +363,14 @@ internal class RegressionTrialRunner : ITrialRunner private MLContext _context; private readonly IDatasetManager _datasetManager; private readonly IMetricManager _metricManager; + private readonly IMLContextManager _contextManager; private readonly SweepablePipeline _pipeline; private readonly Random _rnd; - public RegressionTrialRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) + public RegressionTrialRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings) { - _context = context; + _context = contextManager.CreateMLContext(); + _contextManager = contextManager; 
_datasetManager = datasetManager; _metricManager = metricManager; _pipeline = pipeline; @@ -388,6 +390,8 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) { var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName]; var pipeline = _pipeline.BuildFromOption(_context, parameter); + var refitContext = _contextManager.CreateMLContext(); + var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter); if (_datasetManager is ICrossValidateDatasetManager datasetManager) { var stopWatch = new Stopwatch(); @@ -420,7 +424,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = res.Metrics, CrossValidationMetrics = metrics, - Pipeline = pipeline, + Pipeline = refitPipeline, } as TrialResult); } @@ -453,7 +457,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) TrialSettings = settings, DurationInMilliseconds = stopWatch.ElapsedMilliseconds, Metrics = res, - Pipeline = pipeline, + Pipeline = refitPipeline, } as TrialResult); } } diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs index a62b2dfc4b..8899312dbb 100644 --- a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -113,6 +113,10 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test() Assert.True(result.BestRun.Results.Select(x => x.ValidationMetrics.Accuracy).Min() > 0.70); Assert.NotNull(result.BestRun.Estimator); Assert.NotNull(result.BestRun.TrainerName); + + // test refit + var model = result.BestRun.Estimator.Fit(trainData); + Assert.NotNull(model); } [Fact] @@ -214,6 +218,10 @@ public void AutoFit_Taxi_Fare_Test() Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70); Assert.NotNull(result.BestRun.Estimator); Assert.NotNull(result.BestRun.TrainerName); + + // verify refit + var model = result.BestRun.Estimator.Fit(context.Data.TakeRows(dataset, 
1000)); + Assert.NotNull(model); } [Theory] @@ -253,6 +261,10 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds) result.BestRun.Results.First().ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.7); var scoredData = result.BestRun.Results.First().Model.Transform(trainData); Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + + // test refit + var model = result.BestRun.Estimator.Fit(trainData); + Assert.NotNull(model); } else { @@ -281,6 +293,9 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds) Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7); var scoredData = result.BestRun.Model.Transform(trainData); Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + + var model = result.BestRun.Estimator.Fit(trainData); + Assert.NotNull(model); } }