77using System . Diagnostics ;
88using System . IO ;
99using System . Linq ;
10+ using System . Threading ;
11+ using Microsoft . ML . Data ;
1012using Microsoft . ML . Runtime ;
1113
1214namespace Microsoft . ML . AutoML
@@ -25,6 +27,11 @@ internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail
2527 private readonly IRunner < TRunDetail > _runner ;
2628 private readonly IList < SuggestedPipelineRunDetail > _history ;
2729 private readonly IChannel _logger ;
30+ private Timer _maxExperimentTimeTimer ;
31+ private Timer _mainContextCanceledTimer ;
32+ private bool _experimentTimerExpired ;
33+ private MLContext _currentModelMLContext ;
34+ private Random _newContextSeedGenerator ;
2835
2936 public Experiment ( MLContext context ,
3037 TaskKind task ,
@@ -49,60 +56,125 @@ public Experiment(MLContext context,
4956 _datasetColumnInfo = datasetColumnInfo ;
5057 _runner = runner ;
5158 _logger = logger ;
59+ _experimentTimerExpired = false ;
60+ }
61+
62+ private void MaxExperimentTimeExpiredEvent ( object state )
63+ {
64+ // If at least one model was run, end experiment immediately.
65+ // Else, wait for first model to run before experiment is concluded.
66+ _experimentTimerExpired = true ;
67+ if ( _history . Any ( r => r . RunSucceeded ) )
68+ {
69+ _logger . Warning ( "Allocated time for Experiment of {0} seconds has elapsed with {1} models run. Ending experiment..." ,
70+ _experimentSettings . MaxExperimentTimeInSeconds , _history . Count ( ) ) ;
71+ _currentModelMLContext . CancelExecution ( ) ;
72+ }
73+ }
74+
75+ private void MainContextCanceledEvent ( object state )
76+ {
77+ // If the main MLContext is canceled, cancel the ongoing model training and MLContext.
78+ if ( ( _context . Model . GetEnvironment ( ) as ICancelable ) . IsCanceled )
79+ {
80+ _logger . Warning ( "Main MLContext has been canceled. Ending experiment..." ) ;
81+ // Stop timer to prevent restarting and prevent continuous calls to
82+ // MainContextCanceledEvent
83+ _mainContextCanceledTimer . Change ( Timeout . Infinite , Timeout . Infinite ) ;
84+ _currentModelMLContext . CancelExecution ( ) ;
85+ }
5286 }
5387
5488 public IList < TRunDetail > Execute ( )
5589 {
56- var stopwatch = Stopwatch . StartNew ( ) ;
5790 var iterationResults = new List < TRunDetail > ( ) ;
91+ // Create a timer for the max duration of experiment. When given time has
92+ // elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training
93+ // of current model. Timer is not used if no experiment time is given, or
94+ // is not a positive number.
95+ if ( _experimentSettings . MaxExperimentTimeInSeconds > 0 )
96+ {
97+ _maxExperimentTimeTimer = new Timer (
98+ new TimerCallback ( MaxExperimentTimeExpiredEvent ) , null ,
99+ _experimentSettings . MaxExperimentTimeInSeconds * 1000 , Timeout . Infinite
100+ ) ;
101+ }
102+ // If given max duration of experiment is 0, only 1 model will be trained.
103+ // _experimentSettings.MaxExperimentTimeInSeconds is of type uint, it is
104+ // either 0 or >0.
105+ else
106+ _experimentTimerExpired = true ;
107+
108+ // Add second timer to check for the cancelation signal from the main MLContext
109+ // to the active child MLContext. This timer will propagate the cancelation
110+ // signal from the main to the child MLContexs if the main MLContext is
111+ // canceled.
112+ _mainContextCanceledTimer = new Timer ( new TimerCallback ( MainContextCanceledEvent ) , null , 1000 , 1000 ) ;
113+
114+ // Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to
115+ // maintain variability between training iterations.
116+ int ? mainContextSeed = ( ( ISeededEnvironment ) _context . Model . GetEnvironment ( ) ) . Seed ;
117+ _newContextSeedGenerator = ( mainContextSeed . HasValue ) ? RandomUtils . Create ( mainContextSeed . Value ) : null ;
58118
59119 do
60120 {
61- var iterationStopwatch = Stopwatch . StartNew ( ) ;
62-
63- // get next pipeline
64- var getPipelineStopwatch = Stopwatch . StartNew ( ) ;
65- var pipeline = PipelineSuggester . GetNextInferredPipeline ( _context , _history , _datasetColumnInfo , _task ,
66- _optimizingMetricInfo . IsMaximizing , _experimentSettings . CacheBeforeTrainer , _logger , _trainerAllowList ) ;
67-
68- var pipelineInferenceTimeInSeconds = getPipelineStopwatch . Elapsed . TotalSeconds ;
69-
70- // break if no candidates returned, means no valid pipeline available
71- if ( pipeline == null )
72- {
73- break ;
74- }
75-
76- // evaluate pipeline
77- _logger . Trace ( $ "Evaluating pipeline { pipeline . ToString ( ) } ") ;
78- ( SuggestedPipelineRunDetail suggestedPipelineRunDetail , TRunDetail runDetail )
79- = _runner . Run ( pipeline , _modelDirectory , _history . Count + 1 ) ;
80-
81- _history . Add ( suggestedPipelineRunDetail ) ;
82- WriteIterationLog ( pipeline , suggestedPipelineRunDetail , iterationStopwatch ) ;
83-
84- runDetail . RuntimeInSeconds = iterationStopwatch . Elapsed . TotalSeconds ;
85- runDetail . PipelineInferenceTimeInSeconds = getPipelineStopwatch . Elapsed . TotalSeconds ;
86-
87- ReportProgress ( runDetail ) ;
88- iterationResults . Add ( runDetail ) ;
89-
90- // if model is perfect, break
91- if ( _metricsAgent . IsModelPerfect ( suggestedPipelineRunDetail . Score ) )
121+ try
92122 {
93- break ;
123+ var iterationStopwatch = Stopwatch . StartNew ( ) ;
124+
125+ // get next pipeline
126+ var getPipelineStopwatch = Stopwatch . StartNew ( ) ;
127+
128+ // A new MLContext is needed per model run. When max experiment time is reached, each used
129+ // context is canceled to stop further model training. The cancellation of the main MLContext
130+ // a user has instantiated is not desirable, thus additional MLContexts are used.
131+ _currentModelMLContext = _newContextSeedGenerator == null ? new MLContext ( ) : new MLContext ( _newContextSeedGenerator . Next ( ) ) ;
132+ var pipeline = PipelineSuggester . GetNextInferredPipeline ( _currentModelMLContext , _history , _datasetColumnInfo , _task ,
133+ _optimizingMetricInfo . IsMaximizing , _experimentSettings . CacheBeforeTrainer , _logger , _trainerAllowList ) ;
134+ // break if no candidates returned, means no valid pipeline available
135+ if ( pipeline == null )
136+ {
137+ break ;
138+ }
139+
140+ // evaluate pipeline
141+ _logger . Trace ( $ "Evaluating pipeline { pipeline . ToString ( ) } ") ;
142+ ( SuggestedPipelineRunDetail suggestedPipelineRunDetail , TRunDetail runDetail )
143+ = _runner . Run ( pipeline , _modelDirectory , _history . Count + 1 ) ;
144+
145+ _history . Add ( suggestedPipelineRunDetail ) ;
146+ WriteIterationLog ( pipeline , suggestedPipelineRunDetail , iterationStopwatch ) ;
147+
148+ runDetail . RuntimeInSeconds = iterationStopwatch . Elapsed . TotalSeconds ;
149+ runDetail . PipelineInferenceTimeInSeconds = getPipelineStopwatch . Elapsed . TotalSeconds ;
150+
151+ ReportProgress ( runDetail ) ;
152+ iterationResults . Add ( runDetail ) ;
153+
154+ // if model is perfect, break
155+ if ( _metricsAgent . IsModelPerfect ( suggestedPipelineRunDetail . Score ) )
156+ {
157+ break ;
158+ }
159+
160+ // If after third run, all runs have failed so far, throw exception
161+ if ( _history . Count ( ) == 3 && _history . All ( r => ! r . RunSucceeded ) )
162+ {
163+ throw new InvalidOperationException ( $ "Training failed with the exception: { _history . Last ( ) . Exception } ") ;
164+ }
94165 }
95-
96- // If after third run, all runs have failed so far, throw exception
97- if ( _history . Count ( ) == 3 && _history . All ( r => ! r . RunSucceeded ) )
166+ catch ( OperationCanceledException e )
98167 {
99- throw new InvalidOperationException ( $ "Training failed with the exception: { _history . Last ( ) . Exception } ") ;
168+ // This exception is thrown when the IHost/MLContext of the trainer is canceled due to
169+ // reaching maximum experiment time. Simply catch this exception and return finished
170+ // iteration results.
171+ _logger . Warning ( "OperationCanceledException has been caught after maximum experiment time" +
172+ "was reached, and the running MLContext was stopped. Details: {0}" , e . Message ) ;
173+ return iterationResults ;
100174 }
101-
102175 } while ( _history . Count < _experimentSettings . MaxModels &&
103176 ! _experimentSettings . CancellationToken . IsCancellationRequested &&
104- stopwatch . Elapsed . TotalSeconds < _experimentSettings . MaxExperimentTimeInSeconds ) ;
105-
177+ ! _experimentTimerExpired ) ;
106178 return iterationResults ;
107179 }
108180
0 commit comments