From cec0f2526229f8eba7bc755158f795ecdb8ad124 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 21:46:36 -0700 Subject: [PATCH 01/26] Remove outputs from on_train_epoch_end --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index fcdd8f55f6a6e..8c922a842f965 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -89,7 +89,7 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): + def on_train_epoch_end(self, outputs: Optional[EPOCH_OUTPUT] = None): """Called when the epoch ends. Args: From 3ebc857fb01e46f21a665bbeda8e3084870f0e81 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:27:50 -0700 Subject: [PATCH 02/26] iterate --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c93232a9b4d42..6a82c6b2b60fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) +- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339)) + + - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292)) From 425f68136078620812a1fb60ed26aaaa2ff3597b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:29:25 -0700 Subject: [PATCH 03/26] Update callback_hook.py --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8c922a842f965..fcdd8f55f6a6e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -89,7 +89,7 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: Optional[EPOCH_OUTPUT] = None): + def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): """Called when the epoch ends. Args: From acd2821d2b2c392e198af7c4b06987c09f35cb4d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:58:07 -0700 Subject: [PATCH 04/26] update --- pytorch_lightning/trainer/training_loop.py | 47 ++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 790dc4c70bdeb..e840f6edbabd5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -607,9 +607,56 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: # capture logging self.trainer.logger_connector.cache_logged_metrics() +<<<<<<< HEAD # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') +======= + # call train epoch end hooks + self._on_train_epoch_end_hook(processed_epoch_output) + self.trainer.call_hook('on_epoch_end') + + def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: + # Cannot rely on Trainer.call_hook because the signatures might be different across + # lightning module and callback + # Here we need to inspect if the module accepts `outputs` in `on_train_epoch_end` + + # This implementation is copied from Trainer.call_hook + hook_name = "on_train_epoch_end" + + # set hook_name to model + reset Result obj + skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) + + # always profile hooks + with self.trainer.profiler.profile(hook_name): + + # first call trainer hook + if hasattr(self.trainer, hook_name): + trainer_hook = getattr(self.trainer, hook_name) + trainer_hook(processed_epoch_output) + + # next call hook in lightningModule + model_ref = self.trainer.lightning_module + if is_overridden(hook_name, model_ref): + hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(hook_fx, "outputs"): + self.warning_cache.warn( + f"`ModelHooks.on_train_epoch_end` signature has changed in v1.3. `outputs` parameter has been deprecated." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_ref.on_train_epoch_end(processed_epoch_output) + else: + model_ref.on_train_epoch_end() + + # if the PL module doesn't have the hook then call the accelerator + # used to auto-reduce things for the user with Results obj + elif hasattr(self.trainer.accelerator, hook_name): + accelerator_hook = getattr(self.trainer.accelerator, hook_name) + accelerator_hook() + + if not skip: + self.trainer._cache_logged_metrics() +>>>>>>> update def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because the signatures might be different across From fe6c4475623c23eb14f442ebe35999b78ce84b27 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:30:26 -0700 Subject: [PATCH 05/26] early stop? --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/trainer/training_loop.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 242eeed808f34..f1a1789856642 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module) -> None: + def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e840f6edbabd5..3cf93d87a917c 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -641,7 +641,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): self.warning_cache.warn( - f"`ModelHooks.on_train_epoch_end` signature has changed in v1.3. `outputs` parameter has been deprecated." + "`ModelHooks.on_train_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been deprecated." " Support for the old signature will be removed in v1.5", DeprecationWarning ) model_ref.on_train_epoch_end(processed_epoch_output) From 39987adff270a0dc309108b2e96352375d27e80a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:38:45 -0700 Subject: [PATCH 06/26] fix --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/trainer/training_loop.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f1a1789856642..242eeed808f34 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: + def on_train_epoch_end(self, trainer, pl_module) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3cf93d87a917c..2d7a9a75c4f9f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -607,19 +607,14 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: # capture logging self.trainer.logger_connector.cache_logged_metrics() -<<<<<<< HEAD # call train epoch end hooks self._on_train_epoch_end_hook(processed_epoch_output) self.trainer.call_hook('on_epoch_end') -======= - # call train epoch end hooks - self._on_train_epoch_end_hook(processed_epoch_output) - self.trainer.call_hook('on_epoch_end') def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: - # Cannot rely on Trainer.call_hook because the signatures might be different across + # We cannot rely on Trainer.call_hook because the signatures might be different across # lightning module and callback - # Here we need to inspect if the module accepts `outputs` in `on_train_epoch_end` + # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" From ccceb008397250c7168c6cc8414a282bbd563ce5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 07:53:24 -0700 Subject: [PATCH 07/26] Update pytorch_lightning/trainer/training_loop.py Co-authored-by: Ethan Harris --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 2d7a9a75c4f9f..7395ca13783ea 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -636,7 +636,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): self.warning_cache.warn( - "`ModelHooks.on_train_epoch_end` signature has changed in v1.3." + "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been deprecated." " Support for the old signature will be removed in v1.5", DeprecationWarning ) From dca3d2519b03c842d0499198f247c6cb8011be4a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 08:03:31 -0700 Subject: [PATCH 08/26] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e014671d9cc78..7c8f3ae3d853e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1213,7 +1213,7 @@ def _cache_logged_metrics(self): def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook # This was done to manage the deprecation of an argument to on_train_epoch_end - # If making chnages to this function, ensure that those changes are also made to + # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook # set hook_name to model + reset Result obj From 44d57b91eb80d00a10b209fc357b4dd6861eb89a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:58:07 -0700 Subject: [PATCH 09/26] update --- pytorch_lightning/trainer/training_loop.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7395ca13783ea..a3c8761dfa7fe 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -220,7 +220,12 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if # 1. The model defines training_epoch_end OR +<<<<<<< HEAD # 2. The model overrides on_train_epoch_end which has `outputs` in the signature +======= + # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR + # 3. The trainer has any callback which overrides `on_train_epoch_end` that includes `outputs` in the signature +>>>>>>> update # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): From b09bd59d7d181f69d6cce42b7314a59f9b2ddcdc Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:59:53 -0700 Subject: [PATCH 10/26] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a3c8761dfa7fe..7395ca13783ea 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -220,12 +220,7 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if # 1. The model defines training_epoch_end OR -<<<<<<< HEAD # 2. The model overrides on_train_epoch_end which has `outputs` in the signature -======= - # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR - # 3. The trainer has any callback which overrides `on_train_epoch_end` that includes `outputs` in the signature ->>>>>>> update # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): From 0ba1ca4ba89ff6f01c37ba2177653b6b81e5dd97 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:30:26 -0700 Subject: [PATCH 11/26] early stop? --- pytorch_lightning/callbacks/early_stopping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 242eeed808f34..f1a1789856642 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module) -> None: + def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) From 561171c295a18b5e2ccbfb999fad07f7f8b8055d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:38:45 -0700 Subject: [PATCH 12/26] fix --- pytorch_lightning/callbacks/early_stopping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f1a1789856642..242eeed808f34 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: + def on_train_epoch_end(self, trainer, pl_module) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) From c1e4ab7cc89d1199298e048f668e43c0e95d5b47 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 21:11:16 -0700 Subject: [PATCH 13/26] Remove outputs from evaluation epoch end hooks --- pytorch_lightning/callbacks/base.py | 6 +- pytorch_lightning/core/hooks.py | 4 +- pytorch_lightning/trainer/callback_hook.py | 36 ++---------- pytorch_lightning/trainer/evaluation_loop.py | 9 +-- tests/callbacks/test_callback_hook_outputs.py | 42 -------------- tests/callbacks/test_callbacks.py | 8 +-- tests/core/test_hooks.py | 56 ------------------- .../trainer/logging_/test_logger_connector.py | 4 +- 8 files changed, 17 insertions(+), 148 deletions(-) delete mode 100644 tests/core/test_hooks.py diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 3e8a77cbfdb0a..f5922fd4f3b51 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -108,9 +108,7 @@ def on_validation_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightn """Called when the val epoch begins.""" pass - def on_validation_epoch_end( - self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT - ) -> None: + def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the val epoch ends.""" pass @@ -118,7 +116,7 @@ def on_test_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMod """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: + def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the test epoch ends.""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index bebd1edd8e685..f7b0e82ee25ae 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -245,7 +245,7 @@ def on_validation_epoch_start(self) -> None: Called in the validation loop at the very beginning of the epoch. """ - def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def on_validation_epoch_end(self) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -255,7 +255,7 @@ def on_test_epoch_start(self) -> None: Called in the test loop at the very beginning of the epoch. """ - def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def on_test_epoch_end(self) -> None: """ Called in the test loop at the very end of the epoch. """ diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index fcdd8f55f6a6e..23df26b410a03 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -111,44 +111,20 @@ def on_validation_epoch_start(self): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``validation`` epoch - """ + def on_validation_epoch_end(self): + """Called when the validation epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): - callback.on_validation_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_validation_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_validation_epoch_end(self, self.lightning_module) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self, outputs: EPOCH_OUTPUT): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``test`` epoch - """ + def on_test_epoch_end(self): + """Called when the test epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): - callback.on_test_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_test_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_test_epoch_end(self, self.lightning_module) + callback.on_test_epoch_end(self, self.lightning_module) def on_predict_epoch_start(self) -> None: """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8201d700d39bd..f66cbfdd477d1 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -255,14 +255,7 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) - if is_param_in_hook_signature(model_hook_fx, "outputs"): - model_hook_fx(outputs) - else: - self.warning_cache.warn( - f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - model_hook_fx() + model_hook_fx() self.trainer._cache_logged_metrics() diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index b2aa20af57a94..36322482c5eba 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -65,48 +65,6 @@ def training_epoch_end(self, outputs) -> None: trainer.fit(model) -def test_on_val_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - if trainer.running_sanity_check: - assert len(outputs) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs) == trainer.num_val_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - assert len(outputs) == trainer.num_test_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - weights_summary=None, - ) - - trainer.test(model) - - def test_free_memory_on_eval_outputs(tmpdir): class CB(Callback): diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index a30b4fe0f609b..9b048e022c45b 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), @@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC @@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model, ANY), + call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.teardown(trainer, model, 'test'), @@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir): call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_batch_start(trainer, model, ANY, 1, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.teardown(trainer, model, 'validate'), diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py deleted file mode 100644 index 087f884d96feb..0000000000000 --- a/tests/core/test_hooks.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel - - -def test_on_val_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - if trainer.running_sanity_check: - assert len(outputs) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs) == trainer.num_val_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_test_epoch_end(self, outputs): - assert len(outputs) == trainer.num_test_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=2, - weights_summary=None, - ) - - trainer.test(model) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 06eaca6d61f2c..ab1ce3367ca37 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -681,10 +681,10 @@ def _assert_epoch_end(self, stage): def on_train_epoch_end(self): self._assert_epoch_end('train') - def on_validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self._assert_epoch_end('val') - def on_test_epoch_end(self, outputs): + def on_test_epoch_end(self): self._assert_epoch_end('test') def _assert_called(model, stage): From b25b1ff6999e76b6b5a0f9f34dcd3e3d83a3700c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 21:17:49 -0700 Subject: [PATCH 14/26] update --- CHANGELOG.md | 3 --- pytorch_lightning/trainer/callback_hook.py | 1 - pytorch_lightning/trainer/evaluation_loop.py | 1 - 3 files changed, 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a82c6b2b60fb..477bcb8f5f89d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) -- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) - - - Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 23df26b410a03..08b898ab2be60 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -20,7 +20,6 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f66cbfdd477d1..3a06c5e753fdd 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -20,7 +20,6 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache From 16899ad7a61de4a3d24b0f5c041d83aaec0cca2d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:54:59 -0700 Subject: [PATCH 15/26] Update test_remove_1-5.py --- tests/deprecated_api/test_remove_1-5.py | 90 ------------------------- 1 file changed, 90 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index d49e191e69c8a..91b93a88f0055 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -263,96 +263,6 @@ def on_train_epoch_end(self): trainer.fit(model) -def test_v1_5_0_old_on_validation_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - class OldSignatureModel(BoringModel): - - def on_validation_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - class NewSignatureModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - ... - - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - -def test_v1_5_0_old_on_test_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - class OldSignatureModel(BoringModel): - - def on_test_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - class NewSignatureModel(BoringModel): - - def on_test_epoch_end(self, outputs): - ... - - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - @pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) def test_v1_5_0_profiler_output_filename(tmpdir, cls): filepath = str(tmpdir / "test.txt") From b490e0e3a5f268912a8226f9f0f825cd0fc4e79d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:19:12 -0700 Subject: [PATCH 16/26] fix lints --- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index f7b0e82ee25ae..d311bd4f58f06 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class ModelHooks: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 08b898ab2be60..23df26b410a03 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -20,6 +20,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache From bd79ee2ce8c0ef634a5daf077f513de3cd5a1f30 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:21:06 -0700 Subject: [PATCH 17/26] Update base.py --- pytorch_lightning/callbacks/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index f5922fd4f3b51..8283c2ddd71ec 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -22,7 +22,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class Callback(abc.ABC): From af55130812c2f065ada0ba138ee655db21b5d90c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:52:20 -0700 Subject: [PATCH 18/26] rm-outputs --- pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 3a06c5e753fdd..1e02dd5ad8c96 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -240,7 +240,7 @@ def store_predictions(self, output: Optional[STEP_OUTPUT], batch_idx: int, datal # track debug metrics self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None: + def on_evaluation_epoch_end(self) -> None: model_ref = self.trainer.lightning_module hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" @@ -250,7 +250,7 @@ def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) if hasattr(self.trainer, hook_name): on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) - on_evaluation_epoch_end_hook(outputs) + on_evaluation_epoch_end_hook() if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7c8f3ae3d853e..4114954cfd846 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -987,7 +987,7 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: self.evaluation_loop.evaluation_epoch_end(outputs) # hook - self.evaluation_loop.on_evaluation_epoch_end(outputs) + self.evaluation_loop.on_evaluation_epoch_end() # update epoch-level lr_schedulers if on_epoch: From 9b5a2af391adc7df9bf150435ea890abf57feed3 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:54:36 -0700 Subject: [PATCH 19/26] Update evaluation_loop.py --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1e02dd5ad8c96..2b5d66994547d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from torch.utils.data import DataLoader From 60bff2996b5ea2ade3cae039903f42aa4f488b6c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 10:44:54 -0700 Subject: [PATCH 20/26] try-save-more-memory --- pytorch_lightning/trainer/evaluation_loop.py | 7 +++++++ pytorch_lightning/trainer/trainer.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2b5d66994547d..ea623c948faf2 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -187,6 +187,13 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output + def should_track_batch_outputs_for_epoch_end(self) -> bool: + model = self.trainer.lightning_module + if self.trainer.testing: + return is_overridden('test_epoch_end', model=model) + else: + return is_overridden('validation_epoch_end', model=model) + def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4114954cfd846..706ec414c1b80 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -972,7 +972,8 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: dl_outputs = self.track_output_for_epoch_end(dl_outputs, output) # store batch level output per dataloader - self.evaluation_loop.outputs.append(dl_outputs) + if self.evaluation_loop.should_track_batch_outputs_for_epoch_end(): + self.evaluation_loop.outputs.append(dl_outputs) outputs = self.evaluation_loop.outputs From 611df3cc9b04b542081325c57170466c08c67459 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 14:50:54 -0700 Subject: [PATCH 21/26] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 706ec414c1b80..651a3d3bbd4a8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -981,7 +981,7 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: self.evaluation_loop.outputs = [] # with a single dataloader don't pass a 2D list - if self.evaluation_loop.num_dataloaders == 1: + if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1: outputs = outputs[0] # lightning module method From c0065d1550af0fdfdaf819cdf1830c5acfdb76d4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 16:23:19 -0700 Subject: [PATCH 22/26] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 651a3d3bbd4a8..1e168512a8730 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1213,7 +1213,7 @@ def _cache_logged_metrics(self): def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook - # This was done to manage the deprecation of an argument to on_train_epoch_end + # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook From 19867c0905f438f57233ad6885e44cc0ad97f78b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 18:41:46 -0700 Subject: [PATCH 23/26] cache-at-start --- pytorch_lightning/trainer/evaluation_loop.py | 3 ++- pytorch_lightning/trainer/trainer.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index ea623c948faf2..c42eebe2848cf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -33,6 +33,7 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None + self.should_track_batch_outputs_for_epoch_end = self._should_track_batch_outputs_for_epoch_end() def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] @@ -187,7 +188,7 @@ def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output - def should_track_batch_outputs_for_epoch_end(self) -> bool: + def _should_track_batch_outputs_for_epoch_end(self) -> bool: model = self.trainer.lightning_module if self.trainer.testing: return is_overridden('test_epoch_end', model=model) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1e168512a8730..2a6a53a7c192c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -972,7 +972,7 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: dl_outputs = self.track_output_for_epoch_end(dl_outputs, output) # store batch level output per dataloader - if self.evaluation_loop.should_track_batch_outputs_for_epoch_end(): + if self.evaluation_loop.should_track_batch_outputs_for_epoch_end: self.evaluation_loop.outputs.append(dl_outputs) outputs = self.evaluation_loop.outputs From 90db4c66d99be00a498eb3121d3b8172b800760a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 18:43:23 -0700 Subject: [PATCH 24/26] Update evaluation_loop.py --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c42eebe2848cf..add4a0cbc8a75 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -33,7 +33,6 @@ def __init__(self, trainer: 'pl.Trainer'): self.max_batches: Optional[List[Union[int, float]]] = None self.warning_cache = WarningCache() self.num_dataloaders: Optional[int] = None - self.should_track_batch_outputs_for_epoch_end = self._should_track_batch_outputs_for_epoch_end() def on_trainer_init(self) -> None: self.trainer.num_sanity_val_batches = [] @@ -76,6 +75,7 @@ def should_skip_evaluation(self, max_batches: List[Union[int, float]]) -> bool: return sum(max_batches) == 0 def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: + self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: From 102aec1ab81fdcc1f02f9d396afee00494dcdf4b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 5 May 2021 08:26:52 -0700 Subject: [PATCH 25/26] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7395ca13783ea..b0c0d68fd0077 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -652,7 +652,6 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: if not skip: self.trainer._cache_logged_metrics() ->>>>>>> update def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: # We cannot rely on Trainer.call_hook because the signatures might be different across From c489e8b0b1b6670031a178101a1939daf813eff6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 5 May 2021 08:28:53 -0700 Subject: [PATCH 26/26] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 42 ---------------------- 1 file changed, 42 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b0c0d68fd0077..790dc4c70bdeb 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -653,48 +653,6 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: if not skip: self.trainer._cache_logged_metrics() - def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: - # We cannot rely on Trainer.call_hook because the signatures might be different across - # lightning module and callback - # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` - - # This implementation is copied from Trainer.call_hook - hook_name = "on_train_epoch_end" - - # set hook_name to model + reset Result obj - skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) - - # always profile hooks - with self.trainer.profiler.profile(hook_name): - - # first call trainer hook - if hasattr(self.trainer, hook_name): - trainer_hook = getattr(self.trainer, hook_name) - trainer_hook(processed_epoch_output) - - # next call hook in lightningModule - model_ref = self.trainer.lightning_module - if is_overridden(hook_name, model_ref): - hook_fx = getattr(model_ref, hook_name) - if is_param_in_hook_signature(hook_fx, "outputs"): - self.warning_cache.warn( - "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." - " `outputs` parameter has been deprecated." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - model_ref.on_train_epoch_end(processed_epoch_output) - else: - model_ref.on_train_epoch_end() - - # if the PL module doesn't have the hook then call the accelerator - # used to auto-reduce things for the user with Results obj - elif hasattr(self.trainer.accelerator, hook_name): - accelerator_hook = getattr(self.trainer.accelerator, hook_name) - accelerator_hook() - - if not skip: - self.trainer._cache_logged_metrics() - def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {}