@@ -94,11 +94,9 @@ def training_epoch_end(self, outputs):
 
 def test_training_epoch_end_metrics_collection_on_override(tmpdir):
     """Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch."""
-    num_epochs = 1
 
     class LoggingCallback(pl.Callback):
-
-        def on_train_epoch_end(self, trainer, pl_module):
+        def on_train_epoch_start(self, trainer, pl_module):
             self.len_outputs = 0
 
         def on_train_epoch_end(self, trainer, pl_module, outputs):
@@ -110,7 +108,6 @@ def on_train_epoch_start(self):
             self.num_train_batches = 0
 
         def training_epoch_end(self, outputs):  # Overridden
-            pass
             return
 
         def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
@@ -129,19 +126,20 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 
     callback = LoggingCallback()
     trainer = Trainer(
-        max_epochs=num_epochs,
+        max_epochs=1,
         default_root_dir=tmpdir,
         overfit_batches=2,
         callbacks=[callback],
     )
 
-    result = trainer.fit(overridden_model)
+    trainer.fit(overridden_model)
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
+    # if training_epoch_end is overridden
     assert callback.len_outputs == overridden_model.num_train_batches
-    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
 
-    result = trainer.fit(not_overridden_model)
-    assert callback.len_outputs == 0
+    trainer.fit(not_overridden_model)
     # outputs from on_train_batch_end should be empty
+    assert callback.len_outputs == 0
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
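For reference, a minimal sketch of the callback pattern this test exercises, assuming the Lightning 1.x-era hook signatures shown in the hunks above (where Callback.on_train_epoch_end still receives an outputs argument); the probe class name and the exact nesting of outputs are illustrative assumptions, not code from the test file:

import pytorch_lightning as pl


class EpochOutputsProbe(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset at epoch start so back-to-back trainer.fit() calls start clean,
        # mirroring the on_train_epoch_start change in the diff above.
        self.len_outputs = 0

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # outputs is only populated when the LightningModule overrides
        # training_epoch_end; otherwise it arrives empty. The [0] indexing
        # assumes the per-optimizer nesting used by this Lightning version.
        self.len_outputs = len(outputs[0]) if outputs else 0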