Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,14 @@ internal class BinaryClassificationRunner : ITrialRunner
{
private MLContext _context;
private readonly IDatasetManager _datasetManager;
private readonly IMLContextManager _contextManager;
private readonly IMetricManager _metricManager;
private readonly SweepablePipeline _pipeline;
private readonly Random _rnd;
public BinaryClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
public BinaryClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
{
_context = context;
_context = contextManager.CreateMLContext();
_contextManager = contextManager;
_datasetManager = datasetManager;
_metricManager = metricManager;
_pipeline = pipeline;
Expand All @@ -365,6 +367,10 @@ public TrialResult Run(TrialSettings settings)
{
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
var pipeline = _pipeline.BuildFromOption(_context, parameter);
// _context will be cancelled after training. So returned pipeline need to be created on a
// new MLContext.
var refitContext = _contextManager.CreateMLContext();
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
{
var stopWatch = new Stopwatch();
Expand Down Expand Up @@ -396,7 +402,7 @@ public TrialResult Run(TrialSettings settings)
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = res.Metrics,
CrossValidationMetrics = metrics,
Pipeline = pipeline,
Pipeline = refitPipeline,
};
}

Expand Down Expand Up @@ -430,7 +436,7 @@ public TrialResult Run(TrialSettings settings)
TrialSettings = settings,
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = metrics,
Pipeline = pipeline,
Pipeline = refitPipeline,
};
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,14 @@ internal class MulticlassClassificationRunner : ITrialRunner
private MLContext _context;
private readonly IDatasetManager _datasetManager;
private readonly IMetricManager _metricManager;
private readonly IMLContextManager _contextManager;
private readonly SweepablePipeline _pipeline;
private readonly Random _rnd;

public MulticlassClassificationRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
public MulticlassClassificationRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
{
_context = context;
_context = contextManager.CreateMLContext();
_contextManager = contextManager;
_datasetManager = datasetManager;
_metricManager = metricManager;
_pipeline = pipeline;
Expand All @@ -361,6 +363,8 @@ public TrialResult Run(TrialSettings settings)
{
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
var pipeline = _pipeline.BuildFromOption(_context, parameter);
var refitContext = _contextManager.CreateMLContext();
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
{
var stopWatch = new Stopwatch();
Expand Down Expand Up @@ -394,7 +398,7 @@ public TrialResult Run(TrialSettings settings)
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = res.Metrics,
CrossValidationMetrics = metrics,
Pipeline = pipeline,
Pipeline = refitPipeline,
};
}

Expand Down Expand Up @@ -428,7 +432,7 @@ public TrialResult Run(TrialSettings settings)
TrialSettings = settings,
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = metrics,
Pipeline = pipeline,
Pipeline = refitPipeline,
};
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,14 @@ internal class RegressionTrialRunner : ITrialRunner
private MLContext _context;
private readonly IDatasetManager _datasetManager;
private readonly IMetricManager _metricManager;
private readonly IMLContextManager _contextManager;
private readonly SweepablePipeline _pipeline;
private readonly Random _rnd;

public RegressionTrialRunner(MLContext context, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
public RegressionTrialRunner(IMLContextManager contextManager, IDatasetManager datasetManager, IMetricManager metricManager, SweepablePipeline pipeline, AutoMLExperiment.AutoMLExperimentSettings settings)
{
_context = context;
_context = contextManager.CreateMLContext();
_contextManager = contextManager;
_datasetManager = datasetManager;
_metricManager = metricManager;
_pipeline = pipeline;
Expand All @@ -388,6 +390,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
{
var parameter = settings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
var pipeline = _pipeline.BuildFromOption(_context, parameter);
var refitContext = _contextManager.CreateMLContext();
var refitPipeline = _pipeline.BuildFromOption(refitContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager datasetManager)
{
var stopWatch = new Stopwatch();
Expand Down Expand Up @@ -420,7 +424,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = res.Metrics,
CrossValidationMetrics = metrics,
Pipeline = pipeline,
Pipeline = refitPipeline,
} as TrialResult);
}

Expand Down Expand Up @@ -453,7 +457,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
TrialSettings = settings,
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
Metrics = res,
Pipeline = pipeline,
Pipeline = refitPipeline,
} as TrialResult);
}
}
Expand Down
15 changes: 15 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ public void AutoFit_UCI_Adult_CrossValidation_10_Test()
Assert.True(result.BestRun.Results.Select(x => x.ValidationMetrics.Accuracy).Min() > 0.70);
Assert.NotNull(result.BestRun.Estimator);
Assert.NotNull(result.BestRun.TrainerName);

// test refit
var model = result.BestRun.Estimator.Fit(trainData);
Assert.NotNull(model);
}

[Fact]
Expand Down Expand Up @@ -214,6 +218,10 @@ public void AutoFit_Taxi_Fare_Test()
Assert.True(result.BestRun.ValidationMetrics.RSquared > 0.70);
Assert.NotNull(result.BestRun.Estimator);
Assert.NotNull(result.BestRun.TrainerName);

// verify refit
var model = result.BestRun.Estimator.Fit(context.Data.TakeRows(dataset, 1000));
Assert.NotNull(model);
}

[Theory]
Expand Down Expand Up @@ -253,6 +261,10 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
result.BestRun.Results.First().ValidationMetrics.MicroAccuracy.Should().BeGreaterThan(0.7);
var scoredData = result.BestRun.Results.First().Model.Transform(trainData);
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);

// test refit
var model = result.BestRun.Estimator.Fit(trainData);
Assert.NotNull(model);
}
else
{
Expand Down Expand Up @@ -281,6 +293,9 @@ public void AutoFitMultiTest(bool useNumberOfCVFolds)
Assert.True(result.BestRun.ValidationMetrics.MicroAccuracy >= 0.7);
var scoredData = result.BestRun.Model.Transform(trainData);
Assert.Equal(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type);

var model = result.BestRun.Estimator.Fit(trainData);
Assert.NotNull(model);
}
}

Expand Down