Skip to content

Commit 600d48d

Browse files
authored
Use Timer and ctx.CancelExecution() to fix AutoML max-time experiment bug (#5445)
* Use ctx.CancelExecution() to fix AutoML max-time experiment bug * Added unit test for checking canceled experiment * Nit fix * Different run time on Linux * Review * Testing four output * Used reflection to test for contexts being canceled * Reviews * Reviews * Added main MLContext listener-timer * Added PRNG on _context, held onto timers for avoiding GC * Addressed reviews * Unit test edits * Increase run time of experiment to guarantee probabilities * Edited unit test to check produced schema of next run model's predictions * Remove schema check as different CI builds result in varying schemas * Decrease max experiment time unit test time * Added Timers * Increase second timer time, edit unit test * Added try catch for OperationCanceledException in Execute() * Add AggregateException try catch to slow unit tests for parallel testing * Reviews * Final reviews * Added LightGBMFact to binary classification test * Removed extra Operation Stopped exception try catch * Add back OperationCanceledException to Experiment.cs
1 parent a0e959c commit 600d48d

File tree

5 files changed

+215
-74
lines changed

5 files changed

+215
-74
lines changed

src/Microsoft.ML.AutoML/Experiment/Experiment.cs

Lines changed: 112 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
using System.Diagnostics;
88
using System.IO;
99
using System.Linq;
10+
using System.Threading;
11+
using Microsoft.ML.Data;
1012
using Microsoft.ML.Runtime;
1113

1214
namespace Microsoft.ML.AutoML
@@ -25,6 +27,11 @@ internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail
2527
private readonly IRunner<TRunDetail> _runner;
2628
private readonly IList<SuggestedPipelineRunDetail> _history;
2729
private readonly IChannel _logger;
30+
private Timer _maxExperimentTimeTimer;
31+
private Timer _mainContextCanceledTimer;
32+
private bool _experimentTimerExpired;
33+
private MLContext _currentModelMLContext;
34+
private Random _newContextSeedGenerator;
2835

2936
public Experiment(MLContext context,
3037
TaskKind task,
@@ -49,60 +56,125 @@ public Experiment(MLContext context,
4956
_datasetColumnInfo = datasetColumnInfo;
5057
_runner = runner;
5158
_logger = logger;
59+
_experimentTimerExpired = false;
60+
}
61+
62+
private void MaxExperimentTimeExpiredEvent(object state)
63+
{
64+
// If at least one model was run, end experiment immediately.
65+
// Else, wait for first model to run before experiment is concluded.
66+
_experimentTimerExpired = true;
67+
if (_history.Any(r => r.RunSucceeded))
68+
{
69+
_logger.Warning("Allocated time for Experiment of {0} seconds has elapsed with {1} models run. Ending experiment...",
70+
_experimentSettings.MaxExperimentTimeInSeconds, _history.Count());
71+
_currentModelMLContext.CancelExecution();
72+
}
73+
}
74+
75+
private void MainContextCanceledEvent(object state)
76+
{
77+
// If the main MLContext is canceled, cancel the ongoing model training and MLContext.
78+
if ((_context.Model.GetEnvironment() as ICancelable).IsCanceled)
79+
{
80+
_logger.Warning("Main MLContext has been canceled. Ending experiment...");
81+
// Stop timer to prevent restarting and prevent continuous calls to
82+
// MainContextCanceledEvent
83+
_mainContextCanceledTimer.Change(Timeout.Infinite, Timeout.Infinite);
84+
_currentModelMLContext.CancelExecution();
85+
}
5286
}
5387

5488
public IList<TRunDetail> Execute()
5589
{
56-
var stopwatch = Stopwatch.StartNew();
5790
var iterationResults = new List<TRunDetail>();
91+
// Create a timer for the max duration of experiment. When given time has
92+
// elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training
93+
// of current model. Timer is not used if no experiment time is given, or
94+
// is not a positive number.
95+
if (_experimentSettings.MaxExperimentTimeInSeconds > 0)
96+
{
97+
_maxExperimentTimeTimer = new Timer(
98+
new TimerCallback(MaxExperimentTimeExpiredEvent), null,
99+
_experimentSettings.MaxExperimentTimeInSeconds * 1000, Timeout.Infinite
100+
);
101+
}
102+
// If given max duration of experiment is 0, only 1 model will be trained.
103+
// _experimentSettings.MaxExperimentTimeInSeconds is of type uint, it is
104+
// either 0 or >0.
105+
else
106+
_experimentTimerExpired = true;
107+
108+
// Add second timer to check for the cancelation signal from the main MLContext
109+
// to the active child MLContext. This timer will propagate the cancelation
110+
// signal from the main to the child MLContexts if the main MLContext is
111+
// canceled.
112+
_mainContextCanceledTimer = new Timer(new TimerCallback(MainContextCanceledEvent), null, 1000, 1000);
113+
114+
// Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to
115+
// maintain variability between training iterations.
116+
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed;
117+
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null;
58118

59119
do
60120
{
61-
var iterationStopwatch = Stopwatch.StartNew();
62-
63-
// get next pipeline
64-
var getPipelineStopwatch = Stopwatch.StartNew();
65-
var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task,
66-
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);
67-
68-
var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;
69-
70-
// break if no candidates returned, means no valid pipeline available
71-
if (pipeline == null)
72-
{
73-
break;
74-
}
75-
76-
// evaluate pipeline
77-
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}");
78-
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
79-
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1);
80-
81-
_history.Add(suggestedPipelineRunDetail);
82-
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);
83-
84-
runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
85-
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;
86-
87-
ReportProgress(runDetail);
88-
iterationResults.Add(runDetail);
89-
90-
// if model is perfect, break
91-
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
121+
try
92122
{
93-
break;
123+
var iterationStopwatch = Stopwatch.StartNew();
124+
125+
// get next pipeline
126+
var getPipelineStopwatch = Stopwatch.StartNew();
127+
128+
// A new MLContext is needed per model run. When max experiment time is reached, each used
129+
// context is canceled to stop further model training. The cancellation of the main MLContext
130+
// a user has instantiated is not desirable, thus additional MLContexts are used.
131+
_currentModelMLContext = _newContextSeedGenerator == null ? new MLContext() : new MLContext(_newContextSeedGenerator.Next());
132+
var pipeline = PipelineSuggester.GetNextInferredPipeline(_currentModelMLContext, _history, _datasetColumnInfo, _task,
133+
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);
134+
// break if no candidates returned, means no valid pipeline available
135+
if (pipeline == null)
136+
{
137+
break;
138+
}
139+
140+
// evaluate pipeline
141+
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}");
142+
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
143+
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1);
144+
145+
_history.Add(suggestedPipelineRunDetail);
146+
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);
147+
148+
runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
149+
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;
150+
151+
ReportProgress(runDetail);
152+
iterationResults.Add(runDetail);
153+
154+
// if model is perfect, break
155+
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
156+
{
157+
break;
158+
}
159+
160+
// If after third run, all runs have failed so far, throw exception
161+
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded))
162+
{
163+
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}");
164+
}
94165
}
95-
96-
// If after third run, all runs have failed so far, throw exception
97-
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded))
166+
catch (OperationCanceledException e)
98167
{
99-
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}");
168+
// This exception is thrown when the IHost/MLContext of the trainer is canceled due to
169+
// reaching maximum experiment time. Simply catch this exception and return finished
170+
// iteration results.
171+
_logger.Warning("OperationCanceledException has been caught after maximum experiment time" +
172+
"was reached, and the running MLContext was stopped. Details: {0}", e.Message);
173+
return iterationResults;
100174
}
101-
102175
} while (_history.Count < _experimentSettings.MaxModels &&
103176
!_experimentSettings.CancellationToken.IsCancellationRequested &&
104-
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);
105-
177+
!_experimentTimerExpired);
106178
return iterationResults;
107179
}
108180

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public CrossValSummaryRunner(MLContext context,
5858
for (var i = 0; i < _trainDatasets.Length; i++)
5959
{
6060
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
61-
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
61+
var trainResult = RunnerUtil.TrainAndScorePipeline(pipeline.GetContext(), pipeline, _trainDatasets[i], _validDatasets[i],
6262
_groupIdColumn, _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema,
6363
_logger);
6464
trainResults.Add(trainResult);

src/Microsoft.ML.AutoML/Experiment/SuggestedPipeline.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ public override int GetHashCode()
5252
return ToString().GetHashCode();
5353
}
5454

55+
public MLContext GetContext()
56+
{
57+
return _context;
58+
}
59+
5560
public Pipeline ToPipeline()
5661
{
5762
var pipelineElements = new List<PipelineNode>();

src/Microsoft.ML.Core/Data/IHostEnvironment.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public interface IHostEnvironment : IChannelProvider, IProgressChannelProvider
7272
internal interface ICancelable
7373
{
7474
/// <summary>
75-
/// Signal to stop exection in all the hosts.
75+
/// Signal to stop execution in all the hosts.
7676
/// </summary>
7777
void CancelExecution();
7878

0 commit comments

Comments
 (0)