@@ -778,24 +778,29 @@ def test_metric_are_properly_reduced(tmpdir):
778778 class TestingModel (BoringModel ):
779779 def __init__ (self , * args , ** kwargs ):
780780 super ().__init__ ()
781- self .acc = pl .metrics .Accuracy ()
781+ self .train_acc = pl .metrics .Accuracy ()
782+ self .val_acc = pl .metrics .Accuracy ()
782783
783784 def training_step (self , batch , batch_idx ):
784- self .acc (torch .rand (1 , 3 , device = self .device ), torch .randint (0 , 2 , (1 ,), device = self .device ))
785- self .log ('train_acc' , self .acc , on_step = True , on_epoch = True )
785+ self .train_acc (torch .rand (1 , 3 , device = self .device ), torch .randint (0 , 2 , (1 ,), device = self .device ))
786+ self .log ('train_acc' , self .train_acc , on_step = True , on_epoch = True )
786787 return super ().training_step (batch , batch_idx )
787788
788789 def validation_step (self , batch , batch_idx ):
789- self .acc (torch .rand (1 , 3 , device = self .device ), torch .randint (0 , 2 , (1 ,), device = self .device ))
790- self .log ('val_acc' , self .acc , on_step = True , on_epoch = True )
790+ preds = torch .tensor (0 , device = self .device )
791+ targets = torch .tensor (1 , device = self .device )
792+ if batch_idx < 8 :
793+ targets = preds
794+ self .val_acc (preds , targets )
795+ self .log ('val_acc' , self .val_acc , on_step = True , on_epoch = True )
791796 return super ().validation_step (batch , batch_idx )
792797
793798 early_stop = EarlyStopping (monitor = 'val_acc' , mode = 'max' )
794799
795800 checkpoint = ModelCheckpoint (
796801 monitor = 'val_acc' ,
797802 save_last = True ,
798- save_top_k = 5 ,
803+ save_top_k = 2 ,
799804 mode = 'max' ,
800805 )
801806
@@ -804,8 +809,10 @@ def validation_step(self, batch, batch_idx):
804809 default_root_dir = tmpdir ,
805810 gpus = 1 ,
806811 max_epochs = 2 ,
812+ limit_train_batches = 5 ,
813+ limit_val_batches = 32 ,
807814 callbacks = [early_stop , checkpoint ])
808815 trainer .fit (model )
809816
810- assert "val_acc" in trainer . callback_metrics
817+ assert trainer . callback_metrics [ "val_acc" ] == 8 / 32.
811818 assert "train_acc" in trainer .callback_metrics
0 commit comments