@@ -241,13 +241,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")

-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
         self.trainer.call_hook('on_batch_end')

         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -259,12 +259,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)

-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not (hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)

     def get_optimizers_iterable(self):
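For reference, the gate added above can be read on its own. This is a minimal sketch, not the library code verbatim: the helper name should_track_for_epoch_end is hypothetical, and the import paths (Result from pytorch_lightning.core.step_result, is_overridden from pytorch_lightning.utilities.model_utils) are assumed to match the Lightning version this diff targets.

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.model_utils import is_overridden

# Hypothetical helper mirroring the tracking gate above (illustration only).
def should_track_for_epoch_end(sample_output, model) -> bool:
    # a) structured Result outputs can request automatic epoch-end reduction
    auto_reduce = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
    # b) the user asked for the outputs by overriding an epoch-end hook;
    #    on_train_epoch_end is newly honored alongside training_epoch_end
    hook_overridden = (
        is_overridden("training_epoch_end", model=model)
        or is_overridden("on_train_epoch_end", model=model)
    )
    return auto_reduce or hook_overridden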
@@ -548,17 +563,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break

-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
-
             # hook
             # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -896,7 +908,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -911,14 +923,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)

-        return epoch_end_outputs
+        return batch_end_outputs

     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
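In sum, process_train_step_outputs now returns every optimizer's outputs unconditionally (renamed batch_end_outputs, since they feed on_train_batch_end), and the decision about keeping them for epoch-end reduction moves into track_epoch_end_reduce_metrics, which now also honors on_train_epoch_end. A minimal sketch of the user-facing trigger, assuming an ordinary LightningModule; the module, data shapes, and metric name here are illustrative, not from this commit:

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        return {"loss": loss}

    # Overriding this hook (or on_train_epoch_end) makes the loop keep each
    # batch's outputs and pass the accumulated list here once per epoch;
    # without an override (and without a Result requesting auto-reduction),
    # outputs are dropped after on_train_batch_end to save memory.
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([out["loss"] for out in outputs]).mean()
        self.log("train_loss_epoch", avg_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)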