@@ -25,7 +25,6 @@ public class AutoMLExperiment
2525 private double _bestLoss = double . MaxValue ;
2626 private TrialResult _bestTrialResult = null ;
2727 private readonly IServiceCollection _serviceCollection ;
28- private CancellationTokenSource _globalCancellationTokenSource ;
2928
3029 public AutoMLExperiment ( MLContext context , AutoMLExperimentSettings settings )
3130 {
@@ -51,14 +50,15 @@ private void InitializeServiceCollection()
5150 _serviceCollection . TryAddTransient ( ( provider ) =>
5251 {
5352 var contextManager = provider . GetRequiredService < IMLContextManager > ( ) ;
53+ var trainingStopManager = provider . GetRequiredService < AggregateTrainingStopManager > ( ) ;
5454 var context = contextManager . CreateMLContext ( ) ;
55- _globalCancellationTokenSource . Token . Register ( ( ) =>
55+ trainingStopManager . OnStopTraining += ( s , e ) =>
5656 {
5757 // only force-canceling running trials when there's completed trials.
5858 // otherwise, wait for the current running trial to be completed.
5959 if ( _bestTrialResult != null )
6060 context . CancelExecution ( ) ;
61- } ) ;
61+ } ;
6262
6363 return context ;
6464 } ) ;
@@ -74,6 +74,29 @@ private void InitializeServiceCollection()
7474 public AutoMLExperiment SetTrainingTimeInSeconds ( uint trainingTimeInSeconds )
7575 {
7676 _settings . MaxExperimentTimeInSeconds = trainingTimeInSeconds ;
77+ _serviceCollection . AddScoped < IStopTrainingManager > ( ( provider ) =>
78+ {
79+ var channel = provider . GetRequiredService < IChannel > ( ) ;
80+ var timeoutManager = new TimeoutTrainingStopManager ( TimeSpan . FromSeconds ( trainingTimeInSeconds ) , channel ) ;
81+
82+ return timeoutManager ;
83+ } ) ;
84+
85+ return this ;
86+ }
87+
88+ public AutoMLExperiment SetMaxModelToExplore ( int maxModel )
89+ {
90+ _context . Assert ( maxModel > 0 , "maxModel has to be greater than 0" ) ;
91+ _settings . MaxModels = maxModel ;
92+ _serviceCollection . AddScoped < IStopTrainingManager > ( ( provider ) =>
93+ {
94+ var channel = provider . GetRequiredService < IChannel > ( ) ;
95+ var maxModelManager = new MaxModelStopManager ( maxModel , channel ) ;
96+
97+ return maxModelManager ;
98+ } ) ;
99+
77100 return this ;
78101 }
79102
@@ -204,19 +227,29 @@ public TrialResult Run()
204227 public async Task < TrialResult > RunAsync ( CancellationToken ct = default )
205228 {
206229 ValidateSettings ( ) ;
207- _globalCancellationTokenSource = new CancellationTokenSource ( ) ;
208- _settings . CancellationToken = ct ;
209- // use TimeSpan to avoid overflow.
210- _globalCancellationTokenSource . CancelAfter ( TimeSpan . FromSeconds ( _settings . MaxExperimentTimeInSeconds ) ) ;
211- _settings . CancellationToken . Register ( ( ) => _globalCancellationTokenSource . Cancel ( ) ) ;
230+ _serviceCollection . AddScoped ( ( serviceProvider ) =>
231+ {
232+ var logger = serviceProvider . GetRequiredService < IChannel > ( ) ;
233+ var stopServices = serviceProvider . GetServices < IStopTrainingManager > ( ) ;
234+ var cancellationTrainingStopManager = new CancellationTokenStopTrainingManager ( ct , logger ) ;
235+
236+ // always get the most recent added stop service for each type.
237+ var mostRecentAddedStopServices = stopServices . GroupBy ( s => s . GetType ( ) ) . Select ( g => g . Last ( ) ) . ToList ( ) ;
238+ mostRecentAddedStopServices . Add ( cancellationTrainingStopManager ) ;
239+ return new AggregateTrainingStopManager ( logger , mostRecentAddedStopServices . ToArray ( ) ) ;
240+ } ) ;
241+
212242 var serviceProvider = _serviceCollection . BuildServiceProvider ( ) ;
213- var monitor = serviceProvider . GetService < IMonitor > ( ) ;
243+
244+ _settings . CancellationToken = ct ;
214245 var logger = serviceProvider . GetRequiredService < IChannel > ( ) ;
246+ var aggregateTrainingStopManager = serviceProvider . GetRequiredService < AggregateTrainingStopManager > ( ) ;
247+ var monitor = serviceProvider . GetService < IMonitor > ( ) ;
215248 var trialResultManager = serviceProvider . GetService < ITrialResultManager > ( ) ;
216249 var trialNum = trialResultManager ? . GetAllTrialResults ( ) . Max ( t => t . TrialSettings ? . TrialId ) + 1 ?? 0 ;
217250 var tuner = serviceProvider . GetService < ITuner > ( ) ;
218251 Contracts . Assert ( tuner != null , "tuner can't be null" ) ;
219- while ( ! _globalCancellationTokenSource . Token . IsCancellationRequested )
252+ while ( ! aggregateTrainingStopManager . IsStopTrainingRequested ( ) )
220253 {
221254 var setting = new TrialSettings ( )
222255 {
@@ -227,86 +260,95 @@ public async Task<TrialResult> RunAsync(CancellationToken ct = default)
227260 setting . Parameter = parameter ;
228261
229262 monitor ? . ReportRunningTrial ( setting ) ;
230- try
263+ using ( var trialCancellationTokenSource = new CancellationTokenSource ( ) )
231264 {
232- using ( var trialCancellationTokenSource = new CancellationTokenSource ( ) )
233- using ( var deregisterCallback = _globalCancellationTokenSource . Token . Register ( ( ) =>
265+ void handler ( object o , EventArgs e )
234266 {
235267 // only force-canceling running trials when there's completed trials.
236268 // otherwise, wait for the current running trial to be completed.
237269 if ( _bestTrialResult != null )
238270 trialCancellationTokenSource . Cancel ( ) ;
239- } ) )
240- using ( var performanceMonitor = serviceProvider . GetService < IPerformanceMonitor > ( ) )
241- using ( var runner = serviceProvider . GetRequiredService < ITrialRunner > ( ) )
271+ }
272+ try
242273 {
243- performanceMonitor . MemoryUsageInMegaByte += ( o , m ) =>
274+ using ( var performanceMonitor = serviceProvider . GetService < IPerformanceMonitor > ( ) )
275+ using ( var runner = serviceProvider . GetRequiredService < ITrialRunner > ( ) )
244276 {
245- if ( _settings . MaximumMemoryUsageInMegaByte is double d && m > d && ! trialCancellationTokenSource . IsCancellationRequested )
246- {
247- logger . Trace ( $ "cancel current trial { setting . TrialId } because it uses { m } mb memory and the maximum memory usage is { d } ") ;
248- trialCancellationTokenSource . Cancel ( ) ;
277+ aggregateTrainingStopManager . OnStopTraining += handler ;
249278
250- GC . AddMemoryPressure ( Convert . ToInt64 ( m ) * 1024 * 1024 ) ;
251- GC . Collect ( ) ;
279+ performanceMonitor . MemoryUsageInMegaByte += ( o , m ) =>
280+ {
281+ if ( _settings . MaximumMemoryUsageInMegaByte is double d && m > d && ! trialCancellationTokenSource . IsCancellationRequested )
282+ {
283+ logger . Trace ( $ "cancel current trial { setting . TrialId } because it uses { m } mb memory and the maximum memory usage is { d } ") ;
284+ trialCancellationTokenSource . Cancel ( ) ;
285+
286+ GC . AddMemoryPressure ( Convert . ToInt64 ( m ) * 1024 * 1024 ) ;
287+ GC . Collect ( ) ;
288+ }
289+ } ;
290+
291+ performanceMonitor . Start ( ) ;
292+ logger . Trace ( $ "trial setting - { JsonSerializer . Serialize ( setting ) } ") ;
293+ var trialResult = await runner . RunAsync ( setting , trialCancellationTokenSource . Token ) ;
294+
295+ var peakCpu = performanceMonitor ? . GetPeakCpuUsage ( ) ;
296+ var peakMemoryInMB = performanceMonitor ? . GetPeakMemoryUsageInMegaByte ( ) ;
297+ trialResult . PeakCpu = peakCpu ;
298+ trialResult . PeakMemoryInMegaByte = peakMemoryInMB ;
299+
300+ monitor ? . ReportCompletedTrial ( trialResult ) ;
301+ tuner . Update ( trialResult ) ;
302+ trialResultManager ? . AddOrUpdateTrialResult ( trialResult ) ;
303+ aggregateTrainingStopManager . Update ( trialResult ) ;
304+
305+ var loss = trialResult . Loss ;
306+ if ( loss < _bestLoss )
307+ {
308+ _bestTrialResult = trialResult ;
309+ _bestLoss = loss ;
310+ monitor ? . ReportBestTrial ( trialResult ) ;
252311 }
312+ }
313+ }
314+ catch ( OperationCanceledException ex ) when ( aggregateTrainingStopManager . IsStopTrainingRequested ( ) == false )
315+ {
316+ monitor ? . ReportFailTrial ( setting , ex ) ;
317+ var result = new TrialResult
318+ {
319+ TrialSettings = setting ,
320+ Loss = double . MaxValue ,
253321 } ;
254322
255- performanceMonitor . Start ( ) ;
256- logger . Trace ( $ "trial setting - { JsonSerializer . Serialize ( setting ) } ") ;
257- var trialResult = await runner . RunAsync ( setting , trialCancellationTokenSource . Token ) ;
258-
259- var peakCpu = performanceMonitor ? . GetPeakCpuUsage ( ) ;
260- var peakMemoryInMB = performanceMonitor ? . GetPeakMemoryUsageInMegaByte ( ) ;
261- trialResult . PeakCpu = peakCpu ;
262- trialResult . PeakMemoryInMegaByte = peakMemoryInMB ;
263-
264- monitor ? . ReportCompletedTrial ( trialResult ) ;
265- tuner . Update ( trialResult ) ;
266- trialResultManager ? . AddOrUpdateTrialResult ( trialResult ) ;
323+ tuner . Update ( result ) ;
324+ continue ;
325+ }
326+ catch ( OperationCanceledException ) when ( aggregateTrainingStopManager . IsStopTrainingRequested ( ) )
327+ {
328+ break ;
329+ }
330+ catch ( Exception ex )
331+ {
332+ monitor ? . ReportFailTrial ( setting , ex ) ;
267333
268- var loss = trialResult . Loss ;
269- if ( loss < _bestLoss )
334+ if ( ! aggregateTrainingStopManager . IsStopTrainingRequested ( ) && _bestTrialResult == null )
270335 {
271- _bestTrialResult = trialResult ;
272- _bestLoss = loss ;
273- monitor ? . ReportBestTrial ( trialResult ) ;
336+ // TODO
337+ // it's questionable on whether to abort the entire training process
338+ // for a single fail trial. We should make it an option and only exit
339+ // when error is fatal (like schema mismatch).
340+ throw ;
274341 }
275342 }
276- }
277- catch ( OperationCanceledException ex ) when ( _globalCancellationTokenSource . IsCancellationRequested == false )
278- {
279- monitor ? . ReportFailTrial ( setting , ex ) ;
280- var result = new TrialResult
343+ finally
281344 {
282- TrialSettings = setting ,
283- Loss = double . MaxValue ,
284- } ;
345+ aggregateTrainingStopManager . OnStopTraining -= handler ;
285346
286- tuner . Update ( result ) ;
287- continue ;
288- }
289- catch ( OperationCanceledException ) when ( _globalCancellationTokenSource . IsCancellationRequested )
290- {
291- break ;
292- }
293- catch ( Exception ex )
294- {
295- monitor ? . ReportFailTrial ( setting , ex ) ;
296-
297- if ( ! _globalCancellationTokenSource . IsCancellationRequested && _bestTrialResult == null )
298- {
299- // TODO
300- // it's questionable on whether to abort the entire training process
301- // for a single fail trial. We should make it an option and only exit
302- // when error is fatal (like schema mismatch).
303- throw ;
304347 }
305348 }
306349 }
307350
308351 trialResultManager ? . Save ( ) ;
309-
310352 if ( _bestTrialResult == null )
311353 {
312354 throw new TimeoutException ( "Training time finished without completing a trial run" ) ;
0 commit comments