Skip to content

Commit e15ff5f

Browse files
Add MaxModelToExplore exit strategy to AutoMLExperiment. (#6402)
1 parent 73beb28 commit e15ff5f

File tree

7 files changed

+406
-76
lines changed

7 files changed

+406
-76
lines changed

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSetti
150150
{
151151
_experiment.SetMaximumMemoryUsageInMegaByte(d);
152152
}
153+
_experiment.SetMaxModelToExplore(settings.MaxModels);
153154
}
154155

155156
public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ internal MulticlassClassificationExperiment(MLContext context, MulticlassExperim
142142
{
143143
_experiment.SetMaximumMemoryUsageInMegaByte(d);
144144
}
145+
_experiment.SetMaxModelToExplore(settings.MaxModels);
145146
}
146147

147148
public override ExperimentResult<MulticlassClassificationMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<MulticlassClassificationMetrics>> progressHandler = null)

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ internal RegressionExperiment(MLContext context, RegressionExperimentSettings se
139139
}
140140

141141
_experiment.SetTrainingTimeInSeconds(Settings.MaxExperimentTimeInSeconds);
142+
_experiment.SetMaxModelToExplore(Settings.MaxModels);
142143
}
143144

144145
public override ExperimentResult<RegressionMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<RegressionMetrics>> progressHandler = null)

src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs

Lines changed: 109 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ public class AutoMLExperiment
2525
private double _bestLoss = double.MaxValue;
2626
private TrialResult _bestTrialResult = null;
2727
private readonly IServiceCollection _serviceCollection;
28-
private CancellationTokenSource _globalCancellationTokenSource;
2928

3029
public AutoMLExperiment(MLContext context, AutoMLExperimentSettings settings)
3130
{
@@ -51,14 +50,15 @@ private void InitializeServiceCollection()
5150
_serviceCollection.TryAddTransient((provider) =>
5251
{
5352
var contextManager = provider.GetRequiredService<IMLContextManager>();
53+
var trainingStopManager = provider.GetRequiredService<AggregateTrainingStopManager>();
5454
var context = contextManager.CreateMLContext();
55-
_globalCancellationTokenSource.Token.Register(() =>
55+
trainingStopManager.OnStopTraining += (s, e) =>
5656
{
5757
// only force-cancel running trials when there are completed trials.
5858
// otherwise, wait for the current running trial to be completed.
5959
if (_bestTrialResult != null)
6060
context.CancelExecution();
61-
});
61+
};
6262

6363
return context;
6464
});
@@ -74,6 +74,29 @@ private void InitializeServiceCollection()
7474
public AutoMLExperiment SetTrainingTimeInSeconds(uint trainingTimeInSeconds)
7575
{
7676
_settings.MaxExperimentTimeInSeconds = trainingTimeInSeconds;
77+
_serviceCollection.AddScoped<IStopTrainingManager>((provider) =>
78+
{
79+
var channel = provider.GetRequiredService<IChannel>();
80+
var timeoutManager = new TimeoutTrainingStopManager(TimeSpan.FromSeconds(trainingTimeInSeconds), channel);
81+
82+
return timeoutManager;
83+
});
84+
85+
return this;
86+
}
87+
88+
public AutoMLExperiment SetMaxModelToExplore(int maxModel)
89+
{
90+
_context.Assert(maxModel > 0, "maxModel has to be greater than 0");
91+
_settings.MaxModels = maxModel;
92+
_serviceCollection.AddScoped<IStopTrainingManager>((provider) =>
93+
{
94+
var channel = provider.GetRequiredService<IChannel>();
95+
var maxModelManager = new MaxModelStopManager(maxModel, channel);
96+
97+
return maxModelManager;
98+
});
99+
77100
return this;
78101
}
79102

@@ -204,19 +227,29 @@ public TrialResult Run()
204227
public async Task<TrialResult> RunAsync(CancellationToken ct = default)
205228
{
206229
ValidateSettings();
207-
_globalCancellationTokenSource = new CancellationTokenSource();
208-
_settings.CancellationToken = ct;
209-
// use TimeSpan to avoid overflow.
210-
_globalCancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(_settings.MaxExperimentTimeInSeconds));
211-
_settings.CancellationToken.Register(() => _globalCancellationTokenSource.Cancel());
230+
_serviceCollection.AddScoped((serviceProvider) =>
231+
{
232+
var logger = serviceProvider.GetRequiredService<IChannel>();
233+
var stopServices = serviceProvider.GetServices<IStopTrainingManager>();
234+
var cancellationTrainingStopManager = new CancellationTokenStopTrainingManager(ct, logger);
235+
236+
// always use the most recently added stop service of each type.
237+
var mostRecentAddedStopServices = stopServices.GroupBy(s => s.GetType()).Select(g => g.Last()).ToList();
238+
mostRecentAddedStopServices.Add(cancellationTrainingStopManager);
239+
return new AggregateTrainingStopManager(logger, mostRecentAddedStopServices.ToArray());
240+
});
241+
212242
var serviceProvider = _serviceCollection.BuildServiceProvider();
213-
var monitor = serviceProvider.GetService<IMonitor>();
243+
244+
_settings.CancellationToken = ct;
214245
var logger = serviceProvider.GetRequiredService<IChannel>();
246+
var aggregateTrainingStopManager = serviceProvider.GetRequiredService<AggregateTrainingStopManager>();
247+
var monitor = serviceProvider.GetService<IMonitor>();
215248
var trialResultManager = serviceProvider.GetService<ITrialResultManager>();
216249
var trialNum = trialResultManager?.GetAllTrialResults().Max(t => t.TrialSettings?.TrialId) + 1 ?? 0;
217250
var tuner = serviceProvider.GetService<ITuner>();
218251
Contracts.Assert(tuner != null, "tuner can't be null");
219-
while (!_globalCancellationTokenSource.Token.IsCancellationRequested)
252+
while (!aggregateTrainingStopManager.IsStopTrainingRequested())
220253
{
221254
var setting = new TrialSettings()
222255
{
@@ -227,86 +260,95 @@ public async Task<TrialResult> RunAsync(CancellationToken ct = default)
227260
setting.Parameter = parameter;
228261

229262
monitor?.ReportRunningTrial(setting);
230-
try
263+
using (var trialCancellationTokenSource = new CancellationTokenSource())
231264
{
232-
using (var trialCancellationTokenSource = new CancellationTokenSource())
233-
using (var deregisterCallback = _globalCancellationTokenSource.Token.Register(() =>
265+
void handler(object o, EventArgs e)
234266
{
235267
// only force-cancel running trials when there are completed trials.
236268
// otherwise, wait for the current running trial to be completed.
237269
if (_bestTrialResult != null)
238270
trialCancellationTokenSource.Cancel();
239-
}))
240-
using (var performanceMonitor = serviceProvider.GetService<IPerformanceMonitor>())
241-
using (var runner = serviceProvider.GetRequiredService<ITrialRunner>())
271+
}
272+
try
242273
{
243-
performanceMonitor.MemoryUsageInMegaByte += (o, m) =>
274+
using (var performanceMonitor = serviceProvider.GetService<IPerformanceMonitor>())
275+
using (var runner = serviceProvider.GetRequiredService<ITrialRunner>())
244276
{
245-
if (_settings.MaximumMemoryUsageInMegaByte is double d && m > d && !trialCancellationTokenSource.IsCancellationRequested)
246-
{
247-
logger.Trace($"cancel current trial {setting.TrialId} because it uses {m} mb memory and the maximum memory usage is {d}");
248-
trialCancellationTokenSource.Cancel();
277+
aggregateTrainingStopManager.OnStopTraining += handler;
249278

250-
GC.AddMemoryPressure(Convert.ToInt64(m) * 1024 * 1024);
251-
GC.Collect();
279+
performanceMonitor.MemoryUsageInMegaByte += (o, m) =>
280+
{
281+
if (_settings.MaximumMemoryUsageInMegaByte is double d && m > d && !trialCancellationTokenSource.IsCancellationRequested)
282+
{
283+
logger.Trace($"cancel current trial {setting.TrialId} because it uses {m} mb memory and the maximum memory usage is {d}");
284+
trialCancellationTokenSource.Cancel();
285+
286+
GC.AddMemoryPressure(Convert.ToInt64(m) * 1024 * 1024);
287+
GC.Collect();
288+
}
289+
};
290+
291+
performanceMonitor.Start();
292+
logger.Trace($"trial setting - {JsonSerializer.Serialize(setting)}");
293+
var trialResult = await runner.RunAsync(setting, trialCancellationTokenSource.Token);
294+
295+
var peakCpu = performanceMonitor?.GetPeakCpuUsage();
296+
var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte();
297+
trialResult.PeakCpu = peakCpu;
298+
trialResult.PeakMemoryInMegaByte = peakMemoryInMB;
299+
300+
monitor?.ReportCompletedTrial(trialResult);
301+
tuner.Update(trialResult);
302+
trialResultManager?.AddOrUpdateTrialResult(trialResult);
303+
aggregateTrainingStopManager.Update(trialResult);
304+
305+
var loss = trialResult.Loss;
306+
if (loss < _bestLoss)
307+
{
308+
_bestTrialResult = trialResult;
309+
_bestLoss = loss;
310+
monitor?.ReportBestTrial(trialResult);
252311
}
312+
}
313+
}
314+
catch (OperationCanceledException ex) when (aggregateTrainingStopManager.IsStopTrainingRequested() == false)
315+
{
316+
monitor?.ReportFailTrial(setting, ex);
317+
var result = new TrialResult
318+
{
319+
TrialSettings = setting,
320+
Loss = double.MaxValue,
253321
};
254322

255-
performanceMonitor.Start();
256-
logger.Trace($"trial setting - {JsonSerializer.Serialize(setting)}");
257-
var trialResult = await runner.RunAsync(setting, trialCancellationTokenSource.Token);
258-
259-
var peakCpu = performanceMonitor?.GetPeakCpuUsage();
260-
var peakMemoryInMB = performanceMonitor?.GetPeakMemoryUsageInMegaByte();
261-
trialResult.PeakCpu = peakCpu;
262-
trialResult.PeakMemoryInMegaByte = peakMemoryInMB;
263-
264-
monitor?.ReportCompletedTrial(trialResult);
265-
tuner.Update(trialResult);
266-
trialResultManager?.AddOrUpdateTrialResult(trialResult);
323+
tuner.Update(result);
324+
continue;
325+
}
326+
catch (OperationCanceledException) when (aggregateTrainingStopManager.IsStopTrainingRequested())
327+
{
328+
break;
329+
}
330+
catch (Exception ex)
331+
{
332+
monitor?.ReportFailTrial(setting, ex);
267333

268-
var loss = trialResult.Loss;
269-
if (loss < _bestLoss)
334+
if (!aggregateTrainingStopManager.IsStopTrainingRequested() && _bestTrialResult == null)
270335
{
271-
_bestTrialResult = trialResult;
272-
_bestLoss = loss;
273-
monitor?.ReportBestTrial(trialResult);
336+
// TODO
337+
// it's questionable whether to abort the entire training process
338+
// for a single failed trial. We should make it an option and only exit
339+
// when error is fatal (like schema mismatch).
340+
throw;
274341
}
275342
}
276-
}
277-
catch (OperationCanceledException ex) when (_globalCancellationTokenSource.IsCancellationRequested == false)
278-
{
279-
monitor?.ReportFailTrial(setting, ex);
280-
var result = new TrialResult
343+
finally
281344
{
282-
TrialSettings = setting,
283-
Loss = double.MaxValue,
284-
};
345+
aggregateTrainingStopManager.OnStopTraining -= handler;
285346

286-
tuner.Update(result);
287-
continue;
288-
}
289-
catch (OperationCanceledException) when (_globalCancellationTokenSource.IsCancellationRequested)
290-
{
291-
break;
292-
}
293-
catch (Exception ex)
294-
{
295-
monitor?.ReportFailTrial(setting, ex);
296-
297-
if (!_globalCancellationTokenSource.IsCancellationRequested && _bestTrialResult == null)
298-
{
299-
// TODO
300-
// it's questionable whether to abort the entire training process
301-
// for a single failed trial. We should make it an option and only exit
302-
// when error is fatal (like schema mismatch).
303-
throw;
304347
}
305348
}
306349
}
307350

308351
trialResultManager?.Save();
309-
310352
if (_bestTrialResult == null)
311353
{
312354
throw new TimeoutException("Training time finished without completing a trial run");

0 commit comments

Comments
 (0)