diff --git a/CHANGELOG.md b/CHANGELOG.md index e7d747cc29c48..db8a7b74bc749 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -167,12 +167,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) + - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) +- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260)) + + - Deprecated passing `process_position` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `process_position` directly to the list of callbacks ([#9222](https://github.com/PyTorchLightning/pytorch-lightning/pull/9222)) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index b57c87500be46..b67b304424d8f 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -264,7 +264,12 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") pass def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the training is interrupted by ``KeyboardInterrupt``.""" + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ pass def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index db0e59c4748ae..757489d5f372a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -232,7 +232,12 @@ def on_predict_end(self) -> None: callback.on_predict_end(self, self.lightning_module) def on_keyboard_interrupt(self): - """Called when the training is interrupted by KeyboardInterrupt.""" + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index cd7acfdc8a526..bd0457404ed51 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -43,6 +43,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + # TODO: Delete _check_on_keyboard_interrupt in v1.7 + self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -201,3 +203,12 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) + + def _check_on_keyboard_interrupt(self) -> None: + """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" + for callback in self.trainer.callbacks: + if is_overridden(method_name="on_keyboard_interrupt", instance=callback): + rank_zero_deprecation( + "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." + " Please use the `on_exception` callback hook instead." + ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3e84c725b9663..b94321a4a1d70 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -505,6 +505,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) + # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 729a94e15da11..8596f1c67b812 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,6 +41,8 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index c363638d565d2..5803db051c659 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -26,8 +26,8 @@ def test_callbacks_configured_in_model(tmpdir): """Test the callback system with callbacks added through the model hook.""" - model_callback_mock = Mock() - trainer_callback_mock = Mock() + model_callback_mock = Mock(spec=Callback, model=Callback()) + trainer_callback_mock = Mock(spec=Callback, model=Callback()) class TestModel(BoringModel): def configure_callbacks(self): @@ -79,7 +79,7 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback): def test_configure_callbacks_hook_multiple_calls(tmpdir): """Test that subsequent calls to `configure_callbacks` do not change the callbacks list.""" - model_callback_mock = Mock() + model_callback_mock = Mock(spec=Callback, model=Callback()) class TestModel(BoringModel): def configure_callbacks(self): diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 188b7f4a4a3fa..81d57377e48ea 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,7 +16,7 @@ import pytest -from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel @@ -118,6 +118,29 @@ def test_v1_7_0_test_tube_logger(_, tmpdir): _ = TestTubeLogger(tmpdir) +def test_v1_7_0_on_interrupt(tmpdir): + class HandleInterruptCallback(Callback): + def on_keyboard_interrupt(self, trainer, pl_module): + print("keyboard interrupt") + + model = BoringModel() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" + ): + trainer.fit(model) + + def test_v1_7_0_process_position_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"): _ = Trainer(process_position=5)