@@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1


-@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
-    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
-    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
-    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
-])
+@pytest.mark.parametrize(
+    "stopping_threshold,divergence_theshold,losses,expected_epoch", [
+        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+    ]
+)
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):

     class CurrentModel(BoringModel):
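A minimal sketch (not part of the commit) of what the three parametrized cases above exercise; it assumes the `stopping_threshold` and `divergence_threshold` arguments of `pytorch_lightning.callbacks.EarlyStopping`:

from pytorch_lightning.callbacks import EarlyStopping

# Case 1: no thresholds -- the default patience=3 stops training three epochs
# after the best value (loss 2 at epoch 2), i.e. at epoch 5.
es_patience = EarlyStopping(monitor='abc')

# Case 2: stopping_threshold=2.9 -- training stops as soon as the monitored
# value crosses the threshold (loss 2 at epoch 8 in the second loss sequence).
es_threshold = EarlyStopping(monitor='abc', stopping_threshold=2.9)

# Case 3: divergence_threshold=15.9 -- training stops as soon as the monitored
# value becomes worse than the threshold (loss 16 at epoch 3 in the third sequence).
es_divergence = EarlyStopping(monitor='abc', divergence_threshold=15.9)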
@@ -338,7 +340,7 @@ def validation_epoch_end(self, outputs):
         limit_train_batches=limit_train_batches,
         limit_val_batches=2,
         min_steps=min_steps,
-        min_epochs=min_epochs
+        min_epochs=min_epochs,
     )
     trainer.fit(model)

@@ -359,8 +361,13 @@ def validation_epoch_end(self, outputs):
     by_min_epochs = min_epochs * limit_train_batches

     # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
-        (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
+    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
+        trainer.global_step,
+        max(min_steps, by_early_stopping, by_min_epochs),
+        step_freeze,
+        min_steps,
+        min_epochs,
+    )

     _logger.disabled = False

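The reformatted assertion above keeps the same check but replaces the backslash continuation with a parenthesized tuple, which pytest prints as the failure message. A small standalone sketch of the same pattern (the values are made up for illustration):

min_steps, by_early_stopping, by_min_epochs = 10, 12, 8
global_step = 12  # hypothetical value reported by the trainer

# The trainer must run until the largest of all minimum requirements is met;
# the tuple after the comma becomes the failure message if the check ever fails.
assert global_step == max(min_steps, by_early_stopping, by_min_epochs), (
    global_step,
    max(min_steps, by_early_stopping, by_min_epochs),
    min_steps,
    by_min_epochs,
)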
@@ -372,53 +379,69 @@ def test_early_stopping_mode_options():

 class EarlyStoppingModel(BoringModel):

-    def __init__(self, expected_end_epoch):
+    def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
         super().__init__()
         self.expected_end_epoch = expected_end_epoch
+        self.early_stop_on_train = early_stop_on_train

-    def validation_epoch_end(self, outputs):
+    def _epoch_end(self) -> None:
         losses = [8, 4, 2, 3, 4, 5, 8, 10]
-        val_loss = losses[self.current_epoch]
-        self.log('abc', torch.tensor(val_loss))
+        loss = losses[self.current_epoch]
+        self.log('abc', torch.tensor(loss))
         self.log('cba', torch.tensor(0))

+    def training_epoch_end(self, outputs):
+        if not self.early_stop_on_train:
+            return
+        self._epoch_end()
+
+    def validation_epoch_end(self, outputs):
+        if self.early_stop_on_train:
+            return
+        self._epoch_end()
+
     def on_train_end(self) -> None:
         assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'


+_ES_CHECK = dict(check_on_train_epoch_end=True)
+_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
+_NO_WIN = dict(marks=RunIf(skip_windows=True))
+
+
 @pytest.mark.parametrize(
-    "callbacks, expected_stop_epoch, accelerator, num_processes",
+    "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes",
     [
-        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
-        ([EarlyStopping(monitor='cba', patience=3),
-          EarlyStopping(monitor='abc')], 3, None, 1),
-        pytest.param([EarlyStopping(monitor='abc'),
-                      EarlyStopping(monitor='cba', patience=3)],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
-        pytest.param([EarlyStopping(monitor='cba', patience=3),
-                      EarlyStopping(monitor='abc')],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
+        ([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, None, 1),
+        ([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, None, 1),
+        pytest.param([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        ([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1),
+        ([EarlyStopping('cba', **_ES_CHECK_P3), EarlyStopping('abc', **_ES_CHECK)], 3, True, None, 1),
+        pytest.param([EarlyStopping('abc', **_ES_CHECK),
+                      EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', **_ES_CHECK_P3),
+                      EarlyStopping('abc', **_ES_CHECK)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
     ],
 )
 def test_multiple_early_stopping_callbacks(
-    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
+    tmpdir,
+    callbacks: List[EarlyStopping],
+    expected_stop_epoch: int,
+    check_on_train_epoch_end: bool,
+    accelerator: Optional[str],
+    num_processes: int,
 ):
     """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""

-    model = EarlyStoppingModel(expected_stop_epoch)
+    model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)

     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=callbacks,
         overfit_batches=0.20,
         max_epochs=20,
         accelerator=accelerator,
-        num_processes=num_processes
+        num_processes=num_processes,
     )
     trainer.fit(model)
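The new parametrizations exercise `check_on_train_epoch_end=True`, which makes `EarlyStopping` evaluate its monitored metric at the end of each training epoch instead of after validation. A minimal sketch (not from the commit; `model` stands for any LightningModule that logs 'abc' and 'cba' during training, e.g. EarlyStoppingModel with early_stop_on_train=True):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

callbacks = [
    EarlyStopping(monitor='abc', check_on_train_epoch_end=True),
    EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True),
]
trainer = Trainer(callbacks=callbacks, max_epochs=20)
# trainer.fit(model)  # fitting stops as soon as either callback requests a stop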