@@ -333,46 +333,77 @@ def test_early_stopping_mode_options():
333333
334334class EarlyStoppingModel (BoringModel ):
335335
336- def __init__ (self , expected_end_epoch ):
336+ def __init__ (self , expected_end_epoch : int , during_training : bool ):
337337 super ().__init__ ()
338338 self .expected_end_epoch = expected_end_epoch
339+ self .during_training = during_training
340+
341+ def training_epoch_end (self , outputs ):
342+ if not self .during_training :
343+ return
344+ losses = [8 , 4 , 2 , 3 , 4 , 5 , 8 , 10 ]
345+ loss = losses [self .current_epoch ]
346+ self .log ('abc' , torch .tensor (loss ))
347+ self .log ('cba' , torch .tensor (0 ))
339348
340349 def validation_epoch_end (self , outputs ):
350+ if self .during_training :
351+ return
341352 losses = [8 , 4 , 2 , 3 , 4 , 5 , 8 , 10 ]
342- val_loss = losses [self .current_epoch ]
343- self .log ('abc' , torch .tensor (val_loss ))
353+ loss = losses [self .current_epoch ]
354+ self .log ('abc' , torch .tensor (loss ))
344355 self .log ('cba' , torch .tensor (0 ))
345356
346357 def on_train_end (self ) -> None :
347358 assert self .trainer .current_epoch == self .expected_end_epoch , 'Early Stopping Failed'
348359
349360
350361@pytest .mark .parametrize (
351- "callbacks, expected_stop_epoch, accelerator, num_processes" ,
362+ "callbacks, expected_stop_epoch, during_training, accelerator, num_processes" ,
352363 [
353- ([EarlyStopping (monitor = 'abc' ), EarlyStopping (monitor = 'cba' , patience = 3 )], 3 , None , 1 ),
364+ ([EarlyStopping (monitor = 'abc' ), EarlyStopping (monitor = 'cba' , patience = 3 )], 3 , False , None , 1 ),
354365 ([EarlyStopping (monitor = 'cba' , patience = 3 ),
355- EarlyStopping (monitor = 'abc' )], 3 , None , 1 ),
366+ EarlyStopping (monitor = 'abc' )], 3 , False , None , 1 ),
356367 pytest .param ([EarlyStopping (monitor = 'abc' ),
357368 EarlyStopping (monitor = 'cba' , patience = 3 )],
358369 3 ,
370+ False ,
359371 'ddp_cpu' ,
360372 2 ,
361373 marks = RunIf (skip_windows = True )),
362374 pytest .param ([EarlyStopping (monitor = 'cba' , patience = 3 ),
363375 EarlyStopping (monitor = 'abc' )],
364376 3 ,
377+ False ,
378+ 'ddp_cpu' ,
379+ 2 ,
380+ marks = RunIf (skip_windows = True )),
381+ ([EarlyStopping (monitor = 'abc' , during_training = True ), EarlyStopping (monitor = 'cba' , patience = 3 , during_training = True )], 3 , True , None , 1 ),
382+ ([EarlyStopping (monitor = 'cba' , patience = 3 , during_training = True ),
383+ EarlyStopping (monitor = 'abc' , during_training = True )], 3 , True , None , 1 ),
384+ pytest .param ([EarlyStopping (monitor = 'abc' , during_training = True ),
385+ EarlyStopping (monitor = 'cba' , patience = 3 , during_training = True )],
386+ 3 ,
387+ True ,
388+ 'ddp_cpu' ,
389+ 2 ,
390+ marks = RunIf (skip_windows = True )),
391+ pytest .param ([EarlyStopping (monitor = 'cba' , patience = 3 , during_training = True ),
392+ EarlyStopping (monitor = 'abc' , during_training = True )],
393+ 3 ,
394+ True ,
365395 'ddp_cpu' ,
366396 2 ,
367397 marks = RunIf (skip_windows = True )),
398+
368399 ],
369400)
370401def test_multiple_early_stopping_callbacks (
371- tmpdir , callbacks : List [EarlyStopping ], expected_stop_epoch : int , accelerator : Optional [str ], num_processes : int
402+ tmpdir , callbacks : List [EarlyStopping ], expected_stop_epoch : int , during_training : bool , accelerator : Optional [str ], num_processes : int
372403):
373404 """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
374405
375- model = EarlyStoppingModel (expected_stop_epoch )
406+ model = EarlyStoppingModel (expected_stop_epoch , during_training )
376407
377408 trainer = Trainer (
378409 default_root_dir = tmpdir ,
0 commit comments