@@ -404,71 +404,24 @@ def on_train_end(self) -> None:
404404 assert self .trainer .current_epoch == self .expected_end_epoch , 'Early Stopping Failed'
405405
406406
407+ _ES_CHECK = dict (check_on_train_epoch_end = True )
408+ _ES_CHECK_P3 = dict (patience = 3 , check_on_train_epoch_end = True )
409+ _NO_WIN = dict (marks = RunIf (skip_windows = True ))
410+
411+
407412@pytest .mark .parametrize (
408413 "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes" ,
409414 [
410- ([EarlyStopping (monitor = 'abc' ), EarlyStopping (monitor = 'cba' , patience = 3 )], 3 , False , None , 1 ),
411- ([EarlyStopping (monitor = 'cba' , patience = 3 ),
412- EarlyStopping (monitor = 'abc' )], 3 , False , None , 1 ),
413- pytest .param (
414- [EarlyStopping (monitor = 'abc' ), EarlyStopping (monitor = 'cba' , patience = 3 )],
415- 3 ,
416- False ,
417- 'ddp_cpu' ,
418- 2 ,
419- marks = RunIf (skip_windows = True ),
420- ),
421- pytest .param (
422- [EarlyStopping (monitor = 'cba' , patience = 3 ),
423- EarlyStopping (monitor = 'abc' )],
424- 3 ,
425- False ,
426- 'ddp_cpu' ,
427- 2 ,
428- marks = RunIf (skip_windows = True ),
429- ),
430- (
431- [
432- EarlyStopping (monitor = 'abc' , check_on_train_epoch_end = True ),
433- EarlyStopping (monitor = 'cba' , patience = 3 , check_on_train_epoch_end = True ),
434- ],
435- 3 ,
436- True ,
437- None ,
438- 1 ,
439- ),
440- (
441- [
442- EarlyStopping (monitor = 'cba' , patience = 3 , check_on_train_epoch_end = True ),
443- EarlyStopping (monitor = 'abc' , check_on_train_epoch_end = True ),
444- ],
445- 3 ,
446- True ,
447- None ,
448- 1 ,
449- ),
450- pytest .param (
451- [
452- EarlyStopping (monitor = 'abc' , check_on_train_epoch_end = True ),
453- EarlyStopping (monitor = 'cba' , patience = 3 , check_on_train_epoch_end = True ),
454- ],
455- 3 ,
456- True ,
457- 'ddp_cpu' ,
458- 2 ,
459- marks = RunIf (skip_windows = True ),
460- ),
461- pytest .param (
462- [
463- EarlyStopping (monitor = 'cba' , patience = 3 , check_on_train_epoch_end = True ),
464- EarlyStopping (monitor = 'abc' , check_on_train_epoch_end = True ),
465- ],
466- 3 ,
467- True ,
468- 'ddp_cpu' ,
469- 2 ,
470- marks = RunIf (skip_windows = True ),
471- ),
415+ ([EarlyStopping ('abc' ), EarlyStopping ('cba' , patience = 3 )], 3 , False , None , 1 ),
416+ ([EarlyStopping ('cba' , patience = 3 ), EarlyStopping ('abc' )], 3 , False , None , 1 ),
417+ pytest .param ([EarlyStopping ('abc' ), EarlyStopping ('cba' , patience = 3 )], 3 , False , 'ddp_cpu' , 2 , ** _NO_WIN ),
418+ pytest .param ([EarlyStopping ('cba' , patience = 3 ), EarlyStopping ('abc' )], 3 , False , 'ddp_cpu' , 2 , ** _NO_WIN ),
419+ ([EarlyStopping ('abc' , ** _ES_CHECK ), EarlyStopping ('cba' , ** _ES_CHECK_P3 )], 3 , True , None , 1 ),
420+ ([EarlyStopping ('cba' , ** _ES_CHECK_P3 ), EarlyStopping ('abc' , ** _ES_CHECK )], 3 , True , None , 1 ),
421+ pytest .param ([EarlyStopping ('abc' , ** _ES_CHECK ),
422+ EarlyStopping ('cba' , ** _ES_CHECK_P3 )], 3 , True , 'ddp_cpu' , 2 , ** _NO_WIN ),
423+ pytest .param ([EarlyStopping ('cba' , ** _ES_CHECK_P3 ),
424+ EarlyStopping ('abc' , ** _ES_CHECK )], 3 , True , 'ddp_cpu' , 2 , ** _NO_WIN ),
472425 ],
473426)
474427def test_multiple_early_stopping_callbacks (
0 commit comments