from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

@@ -227,7 +226,7 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
        self.trainer.fit_loop.epoch_progress.increment_processed()

        # call train epoch end hooks
-        self._on_train_epoch_end_hook(processed_outputs)
+        self.trainer.call_hook("on_train_epoch_end")
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

@@ -250,47 +249,6 @@ def _run_validation(self):
        with torch.no_grad():
            self.val_loop.run()

-    def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
-        """Runs the ``on_train_epoch_end`` hook."""
-        # We cannot rely on Trainer.call_hook because the signatures might be different across
-        # the lightning module and callbacks.
-        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`
-
-        # This implementation is copied from Trainer.call_hook
-        hook_name = "on_train_epoch_end"
-        prev_fx_name = self.trainer.lightning_module._current_fx_name
-        self.trainer.lightning_module._current_fx_name = hook_name
-
-        # always profile hooks
-        with self.trainer.profiler.profile(hook_name):
-
-            # first call trainer hook
-            if hasattr(self.trainer, hook_name):
-                trainer_hook = getattr(self.trainer, hook_name)
-                trainer_hook(processed_epoch_output)
-
-            # next call hook in lightning module
-            model_ref = self.trainer.lightning_module
-            if is_overridden(hook_name, model_ref):
-                hook_fx = getattr(model_ref, hook_name)
-                if is_param_in_hook_signature(hook_fx, "outputs"):
-                    self._warning_cache.deprecation(
-                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
-                        " `outputs` parameter has been deprecated."
-                        " Support for the old signature will be removed in v1.5"
-                    )
-                    model_ref.on_train_epoch_end(processed_epoch_output)
-                else:
-                    model_ref.on_train_epoch_end()
-
-            # call the accelerator hook
-            if hasattr(self.trainer.accelerator, hook_name):
-                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
-                accelerator_hook()
-
-        # restore current_fx when nested context
-        self.trainer.lightning_module._current_fx_name = prev_fx_name
-
    def _accumulated_batches_reached(self) -> bool:
        """Determine if accumulation will be finished by the end of the current batch."""
        return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0
@@ -313,7 +271,7 @@ def _track_epoch_end_reduce_metrics(
        self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
    ) -> None:
        """Adds the batch outputs to the epoch outputs and prepares reduction"""
-        hook_overridden = self._should_add_batch_output_to_epoch_output()
+        hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module)
        if not hook_overridden:
            return

@@ -329,24 +287,6 @@ def _track_epoch_end_reduce_metrics(

        epoch_output[opt_idx].append(opt_outputs)

-    def _should_add_batch_output_to_epoch_output(self) -> bool:
-        """
-        We add to the epoch outputs if
-        1. The model defines training_epoch_end OR
-        2. The model overrides on_train_epoch_end which has `outputs` in the signature
-        """
-        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
-        lightning_module = self.trainer.lightning_module
-        if is_overridden("training_epoch_end", lightning_module):
-            return True
-
-        if is_overridden("on_train_epoch_end", lightning_module):
-            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
-            if is_param_in_hook_signature(model_hook_fx, "outputs"):
-                return True
-
-        return False
-
    @staticmethod
    def _prepare_outputs(
        outputs: List[List[List["ResultCollection"]]], batch_mode: bool
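For context, the removed `_on_train_epoch_end_hook` existed only because `Trainer.call_hook` could not selectively pass the deprecated `outputs` argument, so the loop had to inspect the user hook's signature itself. Below is a minimal sketch of what such a signature check can look like, assuming it only needs to find an explicitly named parameter; the real `is_param_in_hook_signature` in `pytorch_lightning.utilities.signature_utils` may handle more cases (e.g. `*args`):

```python
import inspect
from typing import Callable


def has_param(hook_fx: Callable, name: str) -> bool:
    # Look for an explicitly named parameter in the hook's signature.
    # Sketch only: the real helper may also treat *args as a match.
    return name in inspect.signature(hook_fx).parameters


class _Example:
    def on_train_epoch_end(self, outputs):  # old-style signature
        ...


assert has_param(_Example().on_train_epoch_end, "outputs")
```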
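From the user's side, the migration this change assumes looks roughly like the following: `on_train_epoch_end` no longer receives the per-batch outputs, and code that needs them overrides `training_epoch_end` instead, which is now the sole trigger for caching batch outputs across the epoch. A hypothetical module illustrating both hooks:

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # Old style, deprecated in v1.3 and removed in v1.5:
    #     def on_train_epoch_end(self, outputs): ...

    def on_train_epoch_end(self) -> None:
        # New signature: no `outputs` argument.
        ...

    def training_epoch_end(self, outputs) -> None:
        # Overriding this hook is what makes the loop cache the
        # `training_step` outputs (see the
        # `is_overridden("training_epoch_end", ...)` check above).
        ...
```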