@@ -863,19 +863,18 @@ def test_metric_are_properly_reduced(tmpdir):
863863 class TestingModel (BoringModel ):
864864 def __init__ (self , * args , ** kwargs ):
865865 super ().__init__ ()
866- self .train_acc = pl .metrics .Accuracy ()
867866 self .val_acc = pl .metrics .Accuracy ()
868867
869868 def training_step (self , batch , batch_idx ):
870- self . train_acc ( torch . rand ( 1 , 3 , device = self . device ), torch . randint ( 0 , 2 , ( 1 ,), device = self . device ) )
871- self .log ('train_acc' , self . train_acc , on_step = True , on_epoch = True )
872- return super (). training_step ( batch , batch_idx )
869+ output = super (). training_step ( batch , batch_idx )
870+ self .log ("train_loss" , output [ "loss" ] )
871+ return output
873872
874873 def validation_step (self , batch , batch_idx ):
875- preds = torch .tensor (0 , device = self .device )
876- targets = torch .tensor (1 , device = self .device )
874+ preds = torch .tensor ([[ 0.9 , 0.1 ]] , device = self .device )
875+ targets = torch .tensor ([ 1 ] , device = self .device )
877876 if batch_idx < 8 :
878- targets = preds
877+ preds = torch . tensor ([[ 0.1 , 0.9 ]], device = self . device )
879878 self .val_acc (preds , targets )
880879 self .log ('val_acc' , self .val_acc , on_step = True , on_epoch = True )
881880 return super ().validation_step (batch , batch_idx )
@@ -900,4 +899,4 @@ def validation_step(self, batch, batch_idx):
900899 trainer .fit (model )
901900
902901 assert trainer .callback_metrics ["val_acc" ] == 8 / 32.
903- assert "train_acc " in trainer .callback_metrics
902+ assert "train_loss " in trainer .callback_metrics
0 commit comments