diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7872af715d68a..a80153191416a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
--
+- Removed the `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))
 
 -
 
diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst
index 0d541d6f26300..dd81419f63b60 100644
--- a/docs/source/starter/new-project.rst
+++ b/docs/source/starter/new-project.rst
@@ -602,7 +602,7 @@ Here's an example adding a not-so-fancy learning rate decay rule:
                 group = [param_group['lr'] for param_group in optimizer.param_groups]
                 self.old_lrs.append(group)
 
-        def on_train_epoch_end(self, trainer, pl_module, outputs):
+        def on_train_epoch_end(self, trainer, pl_module):
             for opt_idx, optimizer in enumerate(trainer.optimizers):
                 old_lr_group = self.old_lrs[opt_idx]
                 new_lr_group = []
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a492be314df26..38f5e42703ab7 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -94,9 +94,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
         """Called when the train epoch begins."""
         pass
 
-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the train epoch ends.
 
         To access all batch outputs at the end of the epoch, either:
diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index e69c9e1765eb3..21e0d4c54c632 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -399,7 +399,7 @@ def _run_pruning(self, current_epoch: int) -> None:
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:  # type: ignore
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
         if self._prune_on_train_epoch_end:
             rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
             self._run_pruning(pl_module.current_epoch)
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index a30f699c70cfd..1724d29eba99a 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -234,7 +234,7 @@ def on_train_epoch_start(self) -> None:
         Called in the training loop at the very beginning of the epoch.
         """
 
-    def on_train_epoch_end(self, unused: Optional = None) -> None:
+    def on_train_epoch_end(self) -> None:
         """
         Called in the training loop at the very end of the epoch.
 
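With the `outputs` argument gone, `LightningModule.on_train_epoch_end` takes no arguments and `Callback.on_train_epoch_end` receives only `trainer` and `pl_module`. The docstring touched above points users who still need per-batch results at epoch end toward caching them themselves. Below is a minimal migration sketch, not part of this diff: the module class, the `_epoch_outputs` attribute, and the logged metric name are illustrative assumptions; only the hook signatures come from the change itself.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # hypothetical per-batch cache, replacing the removed `outputs` argument
        self._epoch_outputs = []

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self._epoch_outputs.append(loss.detach())  # keep whatever you need per batch
        return loss

    def on_train_epoch_end(self) -> None:  # new signature: no `outputs` parameter
        epoch_mean = torch.stack(self._epoch_outputs).mean()
        self.log("train_epoch_mean", epoch_mean)
        self._epoch_outputs.clear()  # reset the cache for the next epoch

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

Alternatively, keeping `training_epoch_end(outputs)` overridden still receives the collected `training_step` results, as the loop change below relies on.
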
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 1869c5fd43167..9463e04ba1a03 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -22,7 +22,6 @@
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -227,7 +226,7 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         self.trainer.fit_loop.epoch_progress.increment_processed()
 
         # call train epoch end hooks
-        self._on_train_epoch_end_hook(processed_outputs)
+        self.trainer.call_hook("on_train_epoch_end")
         self.trainer.call_hook("on_epoch_end")
         self.trainer.logger_connector.on_epoch_end()
 
@@ -250,47 +249,6 @@ def _run_validation(self):
         with torch.no_grad():
             self.val_loop.run()
 
-    def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
-        """Runs ``on_train_epoch_end hook``."""
-        # We cannot rely on Trainer.call_hook because the signatures might be different across
-        # lightning module and callback
-        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`
-
-        # This implementation is copied from Trainer.call_hook
-        hook_name = "on_train_epoch_end"
-        prev_fx_name = self.trainer.lightning_module._current_fx_name
-        self.trainer.lightning_module._current_fx_name = hook_name
-
-        # always profile hooks
-        with self.trainer.profiler.profile(hook_name):
-
-            # first call trainer hook
-            if hasattr(self.trainer, hook_name):
-                trainer_hook = getattr(self.trainer, hook_name)
-                trainer_hook(processed_epoch_output)
-
-            # next call hook in lightningModule
-            model_ref = self.trainer.lightning_module
-            if is_overridden(hook_name, model_ref):
-                hook_fx = getattr(model_ref, hook_name)
-                if is_param_in_hook_signature(hook_fx, "outputs"):
-                    self._warning_cache.deprecation(
-                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
-                        " `outputs` parameter has been deprecated."
- " Support for the old signature will be removed in v1.5" - ) - model_ref.on_train_epoch_end(processed_epoch_output) - else: - model_ref.on_train_epoch_end() - - # call the accelerator hook - if hasattr(self.trainer.accelerator, hook_name): - accelerator_hook = getattr(self.trainer.accelerator, hook_name) - accelerator_hook() - - # restore current_fx when nested context - self.trainer.lightning_module._current_fx_name = prev_fx_name - def _accumulated_batches_reached(self) -> bool: """Determine if accumulation will be finished by the end of the current batch.""" return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0 @@ -313,7 +271,7 @@ def _track_epoch_end_reduce_metrics( self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT ) -> None: """Adds the batch outputs to the epoch outputs and prepares reduction""" - hook_overridden = self._should_add_batch_output_to_epoch_output() + hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module) if not hook_overridden: return @@ -329,24 +287,6 @@ def _track_epoch_end_reduce_metrics( epoch_output[opt_idx].append(opt_outputs) - def _should_add_batch_output_to_epoch_output(self) -> bool: - """ - We add to the epoch outputs if - 1. The model defines training_epoch_end OR - 2. The model overrides on_train_epoch_end which has `outputs` in the signature - """ - # TODO: in v1.5 this only needs to check if training_epoch_end is overridden - lightning_module = self.trainer.lightning_module - if is_overridden("training_epoch_end", lightning_module): - return True - - if is_overridden("on_train_epoch_end", lightning_module): - model_hook_fx = getattr(lightning_module, "on_train_epoch_end") - if is_param_in_hook_signature(model_hook_fx, "outputs"): - return True - - return False - @staticmethod def _prepare_outputs( outputs: List[List[List["ResultCollection"]]], batch_mode: bool diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index ffcac8f9073f6..18e22371b1c1a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -22,11 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() +from pytorch_lightning.utilities.types import STEP_OUTPUT class TrainerCallbackHookMixin(ABC): @@ -91,22 +87,10 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``train`` epoch - """ + def on_train_epoch_end(self): + """Called when the epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): - warning_cache.deprecation( - "The signature of `Callback.on_train_epoch_end` has changed in v1.3." - " `outputs` parameter has been removed." 
- " Support for the old signature will be removed in v1.5" - ) - callback.on_train_epoch_end(self, self.lightning_module, outputs) - else: - callback.on_train_epoch_end(self, self.lightning_module) + callback.on_train_epoch_end(self, self.lightning_module) def on_validation_epoch_start(self): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5f3d18ebc4e66..3bd374b86e512 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1200,10 +1200,6 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None: model._metric_attributes = None def call_hook(self, hook_name: str, *args, **kwargs) -> Any: - # Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook - # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end - # If making changes to this function, ensure that those changes are also made to - # TrainingEpochLoop._on_train_epoch_end_hook if self.lightning_module: prev_fx_name = self.lightning_module._current_fx_name self.lightning_module._current_fx_name = hook_name diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 4647ee993b484..70b87d2b26ca7 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -27,7 +27,6 @@ from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.plugins import DeepSpeedPlugin from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler -from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache from pytorch_lightning.utilities import device_parser from pytorch_lightning.utilities.imports import _compare_version from tests.deprecated_api import no_deprecated_call @@ -194,49 +193,6 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir, period=1) -def test_v1_5_0_old_on_train_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - def on_train_epoch_end(self, trainer, pl_module, outputs): # noqa - ... - - class OldSignatureModel(BoringModel): - def on_train_epoch_end(self, outputs): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - callback_warning_cache.clear() - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - trainer.fit_loop.epoch_loop._warning_cache.clear() - - class NewSignature(Callback): - def on_train_epoch_end(self, trainer, pl_module): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - class NewSignatureModel(BoringModel): - def on_train_epoch_end(self): - ... 
-
-    model = NewSignatureModel()
-    with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
 def test_v1_5_0_profiler_output_filename(tmpdir, cls):
     filepath = str(tmpdir / "test.txt")
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 608639d2459df..9b565e27aac62 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -534,11 +534,11 @@ def training_step(self, batch, batch_idx):
         dict(name="train", args=(True,)),
         dict(name="on_validation_model_train"),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
-        dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
+        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_epoch_end", args=(trainer, model)),
         dict(name="on_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
@@ -635,10 +635,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         # TODO: wrong current epoch after reload
         *model._train_batch(trainer, model, train_batches, current_epoch=1),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
-        dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
+        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),
-        dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
+        dict(name="on_train_epoch_end"),
         dict(name="Callback.on_epoch_end", args=(trainer, model)),
         dict(name="on_epoch_end"),
         dict(name="Callback.on_train_end", args=(trainer, model)),
diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py
index 61aff4a0e1eab..eee5cfe422a94 100644
--- a/tests/trainer/logging_/test_train_loop_logging.py
+++ b/tests/trainer/logging_/test_train_loop_logging.py
@@ -551,7 +551,7 @@ def on_batch_end(self, trainer, pl_module):
         def on_epoch_end(self, trainer, pl_module):
             self.log("on_epoch_end", 5)
 
-        def on_train_epoch_end(self, trainer, pl_module, outputs):
+        def on_train_epoch_end(self, trainer, pl_module):
             self.log("on_train_epoch_end", 6)
 
     model = BoringModel()
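On the callback side, the removal of the compatibility path in `TrainerCallbackHookMixin.on_train_epoch_end` means `on_train_epoch_end(self, trainer, pl_module)` is the only accepted signature, as the updated tests above reflect. A short sketch of a callback written against the new signature follows; the callback class and the `train_loss` metric key are assumptions for illustration. Rather than receiving `outputs`, it reads metrics the module logged during the epoch from `trainer.callback_metrics`.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class EpochSummary(Callback):
    """Illustrative callback using the two-argument hook signature."""

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # `outputs` is no longer passed in; read logged metrics off the trainer instead.
        loss = trainer.callback_metrics.get("train_loss")
        if loss is not None:
            print(f"epoch {trainer.current_epoch}: train_loss={float(loss):.4f}")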