@@ -249,7 +249,8 @@ public override NasBertTransformer Fit(IDataView input)
249249 for ( int i = 0 ; i < Option . MaxEpoch ; i ++ )
250250 {
251251 ch . Trace ( $ "Starting epoch { i } ") ;
252- trainer . Train ( input ) ;
252+ Host . CheckAlive ( ) ;
253+ trainer . Train ( Host , input ) ;
253254 ch . Trace ( $ "Finished epoch { i } ") ;
254255 if ( Option . ValidationSet != null )
255256 trainer . Validate ( pch , ch , i ) ;
@@ -423,7 +424,7 @@ private bool ValidateStep(DataViewRowCursor cursor,
423424 return cursorValid ;
424425 }
425426
426- public void Train ( IDataView input )
427+ public void Train ( IHost host , IDataView input )
427428 {
428429 // Get the cursor and the correct columns based on the inputs
429430 DataViewRowCursor cursor = default ;
@@ -443,14 +444,15 @@ public void Train(IDataView input)
443444 var cursorValid = true ;
444445 while ( cursorValid )
445446 {
446- cursorValid = TrainStep ( cursor , sentence1Getter , sentence2Getter , labelGetter , ref inputTensors , ref targets ) ;
447+ cursorValid = TrainStep ( host , cursor , sentence1Getter , sentence2Getter , labelGetter , ref inputTensors , ref targets ) ;
447448 }
448449 }
449450
450- private bool TrainStep ( DataViewRowCursor cursor ,
451- ValueGetter < ReadOnlyMemory < char > > sentence1Getter ,
452- ValueGetter < ReadOnlyMemory < char > > sentence2Getter ,
453- ValueGetter < TLabelCol > labelGetter ,
451+ private bool TrainStep ( IHost host ,
452+ DataViewRowCursor cursor ,
453+ ValueGetter < ReadOnlyMemory < char > > sentence1Getter ,
454+ ValueGetter < ReadOnlyMemory < char > > sentence2Getter ,
455+ ValueGetter < TLabelCol > labelGetter ,
454456 ref List < Tensor > inputTensors ,
455457 ref List < TTargetsCol > targets )
456458 {
@@ -461,6 +463,7 @@ private bool TrainStep(DataViewRowCursor cursor,
461463 var cursorValid = true ;
462464 for ( int i = 0 ; i < Parent . Option . BatchSize && cursorValid ; i ++ )
463465 {
466+ host . CheckAlive ( ) ;
464467 cursorValid = cursor . MoveNext ( ) ;
465468 if ( cursorValid )
466469 {
@@ -479,7 +482,7 @@ private bool TrainStep(DataViewRowCursor cursor,
479482 }
480483
481484 Updates ++ ;
482-
485+ host . CheckAlive ( ) ;
483486 torch . random . manual_seed ( 1 + Updates ) ;
484487 torch . cuda . manual_seed ( 1 + Updates ) ;
485488 Model . train ( ) ;
@@ -497,8 +500,10 @@ private bool TrainStep(DataViewRowCursor cursor,
497500 loss = torch . nn . MSELoss ( reduction : Parent . Option . Reduction ) . forward ( logits , targetsTensor ) ;
498501 logits = logits . squeeze ( ) ;
499502 }
500-
503+ host . CheckAlive ( ) ;
501504 loss . backward ( ) ;
505+
506+ host . CheckAlive ( ) ;
502507 OptimizeStep ( ) ;
503508
504509 return cursorValid ;
0 commit comments