From 8ce82a3a45fb74345c55af68fb945a72225a453a Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sat, 28 Aug 2021 23:16:18 +0000 Subject: [PATCH 01/17] add on_exception callback hook --- CHANGELOG.md | 7 ++- pytorch_lightning/callbacks/base.py | 4 ++ pytorch_lightning/trainer/callback_hook.py | 5 ++ .../trainer/configuration_validator.py | 14 +++++- pytorch_lightning/trainer/trainer.py | 8 +++- pytorch_lightning/utilities/model_helpers.py | 2 + tests/deprecated_api/test_remove_1-7.py | 24 ++++++++++ tests/trainer/test_trainer.py | 46 +++++++++++++++++++ 8 files changed, 105 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d3dcedbe003a..5158617ada9a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -97,9 +97,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) -- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) +- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) +- Added `on_exception` callback hook ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -161,13 +162,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) + - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) -- - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) +- Deprecated `on_keyboard_interrupt` callback hook + ### Removed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index fdb22a44ed307..114519f43bdff 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -267,6 +267,10 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM """Called when the training is interrupted by ``KeyboardInterrupt``.""" pass + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + """Called when the training is interrupted by any exception.""" + pass + def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index bbfcbb22802a8..9c35ac64c8b96 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -236,6 +236,11 @@ def on_keyboard_interrupt(self): for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) + def on_exception(self, exception: BaseException) -> None: + """Called when the training is interrupted by any exception.""" + for callback in self.callbacks: + callback.on_exception(self, self.lightning_module, exception) + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: """Called when saving a model checkpoint.""" callback_states = {} diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index e4a3c6f8c00eb..9aa8aea9f4c48 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -43,6 +44,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 + self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -201,3 +204,12 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) + + def _check_on_keyboard_interrupt(self) -> None: + """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" + for callback in self.trainer.callbacks: + if is_overridden(method_name="on_keyboard_interrupt", instance=callback): + rank_zero_deprecation( + "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." + "Please use the `on_exception` callback hook instead." + ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f56a84b1d294d..f936aa20dd186 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,13 +500,15 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - except KeyboardInterrupt: + # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) + except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once if not self.interrupted: self.state.status = TrainerStatus.INTERRUPTED self.on_keyboard_interrupt() - except BaseException: + self.on_exception(exception) + except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: # try syncing remaing processes, kill otherwise @@ -514,8 +516,10 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self._on_exception() # reset bookkeeping self.state.stage = None + self.on_exception(exception) raise + def fit( self, model: "pl.LightningModule", diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 729a94e15da11..8596f1c67b812 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,6 +41,8 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 8c7b1a00d13d4..a04fc85026900 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -21,6 +21,7 @@ from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule +from pytorch_lightning import Callback def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): @@ -116,3 +117,26 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir): def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): _ = TestTubeLogger(tmpdir) + + +def test_v1_7_0_on_interrupt(tmpdir): + class HandleInterruptCallback(Callback): + def on_keyboard_interrupt(self, trainer, pl_module): + self.log("keyboard interrupt") + + model = BoringModel() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" + ): + trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5adbff1e41fd0..a7dfcc5573179 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -873,6 +873,52 @@ def on_keyboard_interrupt(self, trainer, pl_module): assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) +def test_on_exception_hook(tmpdir): + """Test the on_exception callback hook.""" + + model = EvalModelTemplate() + + class InterruptCallback(Callback): + def __init__(self): + super().__init__() + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + raise KeyboardInterrupt + + def on_test_start(self, trainer, pl_module): + raise MisconfigurationException + + class HandleInterruptCallback(Callback): + def __init__(self): + super().__init__() + self.exception = None + + def on_exception(self, trainer, pl_module, exception): + self.exception = exception + + interrupt_callback = InterruptCallback() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[interrupt_callback, handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + assert not trainer.interrupted + assert handle_interrupt_callback.exception is None + trainer.fit(model) + assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) + with pytest.raises(MisconfigurationException): + trainer.test(model) + assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) + + @pytest.mark.parametrize( "precision", [32, pytest.param(16, marks=RunIf(min_gpus=1, amp_native=True))], From cddb35859cb4ff836169b388ac0b6c9a4d921a2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Aug 2021 05:36:26 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 1 - tests/deprecated_api/test_remove_1-7.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9aa8aea9f4c48..c3ce6030e7d2d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f936aa20dd186..a0f99579537d8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -519,7 +519,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.on_exception(exception) raise - def fit( self, model: "pl.LightningModule", diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index a04fc85026900..77194f2b29600 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,12 +16,11 @@ import pytest -from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule -from pytorch_lightning import Callback def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): From 7e16fe97ed3fdaa872d536e83f2ae32159804856 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sun, 29 Aug 2021 05:37:09 +0000 Subject: [PATCH 03/17] small fixes/lints --- CHANGELOG.md | 5 +++-- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/trainer/configuration_validator.py | 1 - pytorch_lightning/trainer/trainer.py | 3 +-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5158617ada9a8..2d4e6ceaff5cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,7 +100,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) -- Added `on_exception` callback hook +- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -169,7 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook +- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) ### Removed diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 220ac589f130c..479a2eae0f8bc 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from pytorch_lightning.utilities import move_data_to_device from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation class ModelHooks: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9aa8aea9f4c48..b37e2b5b51822 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f936aa20dd186..6654e181d3e7e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,7 +500,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) + # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once @@ -519,7 +519,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.on_exception(exception) raise - def fit( self, model: "pl.LightningModule", From 58318c338768c051f0d492c72cea2ec7e1f7def2 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sun, 29 Aug 2021 22:36:20 +0000 Subject: [PATCH 04/17] separate prs --- CHANGELOG.md | 3 --- pytorch_lightning/callbacks/base.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 2 +- .../trainer/configuration_validator.py | 14 +---------- pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/utilities/model_helpers.py | 2 -- tests/deprecated_api/test_remove_1-7.py | 25 +------------------ tests/trainer/test_trainer.py | 2 +- 8 files changed, 5 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d4e6ceaff5cb..761187b4333a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,15 +163,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) - - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) - ### Removed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 114519f43bdff..b57c87500be46 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -268,7 +268,7 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM pass def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: - """Called when the training is interrupted by any exception.""" + """Called when any trainer execution is interrupted by an exception.""" pass def on_save_checkpoint( diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 9c35ac64c8b96..db0e59c4748ae 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -237,7 +237,7 @@ def on_keyboard_interrupt(self): callback.on_keyboard_interrupt(self, self.lightning_module) def on_exception(self, exception: BaseException) -> None: - """Called when the training is interrupted by any exception.""" + """Called when any trainer execution is interrupted by an exception.""" for callback in self.callbacks: callback.on_exception(self, self.lightning_module, exception) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index c3ce6030e7d2d..752cedd173be8 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytorch_lightning as pl -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: @@ -44,8 +43,6 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) - # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 - self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -204,12 +201,3 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) - - def _check_on_keyboard_interrupt(self) -> None: - """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" - for callback in self.trainer.callbacks: - if is_overridden(method_name="on_keyboard_interrupt", instance=callback): - rank_zero_deprecation( - "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." - "Please use the `on_exception` callback hook instead." - ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6654e181d3e7e..7709f34487d1d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,7 +500,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 8596f1c67b812..729a94e15da11 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,8 +41,6 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule - elif isinstance(instance, pl.Callback): - parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 77194f2b29600..8c7b1a00d13d4 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,7 +16,7 @@ import pytest -from pytorch_lightning import Callback, LightningDataModule, Trainer +from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel @@ -116,26 +116,3 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir): def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): _ = TestTubeLogger(tmpdir) - - -def test_v1_7_0_on_interrupt(tmpdir): - class HandleInterruptCallback(Callback): - def on_keyboard_interrupt(self, trainer, pl_module): - self.log("keyboard interrupt") - - model = BoringModel() - handle_interrupt_callback = HandleInterruptCallback() - - trainer = Trainer( - callbacks=[handle_interrupt_callback], - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - progress_bar_refresh_rate=0, - logger=False, - default_root_dir=tmpdir, - ) - with pytest.deprecated_call( - match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" - ): - trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a7dfcc5573179..32b57b48acf4d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -876,7 +876,7 @@ def on_keyboard_interrupt(self, trainer, pl_module): def test_on_exception_hook(tmpdir): """Test the on_exception callback hook.""" - model = EvalModelTemplate() + model = BoringModel() class InterruptCallback(Callback): def __init__(self): From b97c66803cc5ec25fd7b90235b630c2670c0c9af Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sun, 29 Aug 2021 22:51:01 +0000 Subject: [PATCH 05/17] deprecate on_keyboard_interrupt --- CHANGELOG.md | 2 ++ pytorch_lightning/callbacks/base.py | 7 +++++- pytorch_lightning/trainer/callback_hook.py | 7 +++++- .../trainer/configuration_validator.py | 13 +++++++++- pytorch_lightning/trainer/trainer.py | 1 + pytorch_lightning/utilities/model_helpers.py | 2 ++ tests/deprecated_api/test_remove_1-7.py | 25 ++++++++++++++++++- 7 files changed, 53 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 761187b4333a9..868c2935cd5b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,12 +163,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) + - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) +- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) ### Removed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index b57c87500be46..db1357b13b89e 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -264,7 +264,12 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") pass def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when the training is interrupted by ``KeyboardInterrupt``.""" + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by ``KeyboardInterrupt``. + """ pass def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index db0e59c4748ae..757489d5f372a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -232,7 +232,12 @@ def on_predict_end(self) -> None: callback.on_predict_end(self, self.lightning_module) def on_keyboard_interrupt(self): - """Called when the training is interrupted by KeyboardInterrupt.""" + r""" + .. deprecated:: v1.5 + This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. + + Called when any trainer execution is interrupted by KeyboardInterrupt. + """ for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 752cedd173be8..6bebcddbb5d26 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -13,10 +13,10 @@ # limitations under the License. import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: @@ -43,6 +43,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 + self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -201,3 +203,12 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) + + def _check_on_keyboard_interrupt(self) -> None: + """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" + for callback in self.trainer.callbacks: + if is_overridden(method_name="on_keyboard_interrupt", instance=callback): + rank_zero_deprecation( + "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." + "Please use the `on_exception` callback hook instead." + ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7709f34487d1d..6654e181d3e7e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,6 +500,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) + # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 729a94e15da11..8596f1c67b812 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,6 +41,8 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 8c7b1a00d13d4..0415a63928498 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,7 +16,7 @@ import pytest -from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel @@ -116,3 +116,26 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir): def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): _ = TestTubeLogger(tmpdir) + + +def test_v1_7_0_on_interrupt(tmpdir): + class HandleInterruptCallback(Callback): + def on_keyboard_interrupt(self, trainer, pl_module): + print("keyboard interrupt") + + model = BoringModel() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" + ): + trainer.fit(model) From 3c71e18808e57a79864ebf01e6b0fb6fc0287c14 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Sep 2021 19:35:45 +0000 Subject: [PATCH 06/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-7.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 418c222f149db..81d57377e48ea 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -141,7 +141,6 @@ def on_keyboard_interrupt(self, trainer, pl_module): trainer.fit(model) - def test_v1_7_0_process_position_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"): _ = Trainer(process_position=5) From 99c85b6ac95c54731ea83242e13a1284c2c3c302 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Wed, 1 Sep 2021 19:38:08 +0000 Subject: [PATCH 07/17] small fix --- pytorch_lightning/callbacks/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 8f8594c283c81..db1357b13b89e 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -276,10 +276,6 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e """Called when any trainer execution is interrupted by an exception.""" pass - def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: - """Called when any trainer execution is interrupted by an exception.""" - pass - def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: From b8b4de24e2e6c3288ce28fa2485b95b19bcd2d4a Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Wed, 1 Sep 2021 19:45:29 +0000 Subject: [PATCH 08/17] small fix --- pytorch_lightning/callbacks/base.py | 2 +- tests/trainer/test_trainer.py | 46 ----------------------------- 2 files changed, 1 insertion(+), 47 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db1357b13b89e..b67b304424d8f 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -268,7 +268,7 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM .. deprecated:: v1.5 This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7. - Called when any trainer execution is interrupted by ``KeyboardInterrupt``. + Called when any trainer execution is interrupted by KeyboardInterrupt. """ pass diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b871d5375d1d0..da092aa602a25 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -886,52 +886,6 @@ def on_keyboard_interrupt(self, trainer, pl_module): assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) -def test_on_exception_hook(tmpdir): - """Test the on_exception callback hook.""" - - model = BoringModel() - - class InterruptCallback(Callback): - def __init__(self): - super().__init__() - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - raise KeyboardInterrupt - - def on_test_start(self, trainer, pl_module): - raise MisconfigurationException - - class HandleInterruptCallback(Callback): - def __init__(self): - super().__init__() - self.exception = None - - def on_exception(self, trainer, pl_module, exception): - self.exception = exception - - interrupt_callback = InterruptCallback() - handle_interrupt_callback = HandleInterruptCallback() - - trainer = Trainer( - callbacks=[interrupt_callback, handle_interrupt_callback], - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - progress_bar_refresh_rate=0, - logger=False, - default_root_dir=tmpdir, - ) - assert not trainer.interrupted - assert handle_interrupt_callback.exception is None - trainer.fit(model) - assert trainer.interrupted - assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) - with pytest.raises(MisconfigurationException): - trainer.test(model) - assert trainer.interrupted - assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) - - @pytest.mark.parametrize( "precision", [32, pytest.param(16, marks=RunIf(min_gpus=1, amp_native=True))], From 7698965af2c7e090c731da7b53d8e9ca612d878f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 1 Sep 2021 23:14:30 +0200 Subject: [PATCH 09/17] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/configuration_validator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e853c3119877c..67dcc20cb59b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -174,7 +174,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) +- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) - Deprecated passing `process_position` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `process_position` directly to the list of callbacks ([#9222](https://github.com/PyTorchLightning/pytorch-lightning/pull/9222)) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 6bebcddbb5d26..37a1027783f5f 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -210,5 +210,5 @@ def _check_on_keyboard_interrupt(self) -> None: if is_overridden(method_name="on_keyboard_interrupt", instance=callback): rank_zero_deprecation( "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." - "Please use the `on_exception` callback hook instead." + " Please use the `on_exception` callback hook instead." ) From a6a45fc70418aef5ce1bbaf2beea631572a25eb6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 1 Sep 2021 23:18:40 +0200 Subject: [PATCH 10/17] note --- pytorch_lightning/trainer/configuration_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 37a1027783f5f..bd0457404ed51 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -43,7 +43,7 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) - # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 + # TODO: Delete _check_on_keyboard_interrupt in v1.7 self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: From 3ced67df2195a8282607f0f8fabfef9c6754c86a Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Thu, 2 Sep 2021 20:20:17 +0000 Subject: [PATCH 11/17] fix failing tests --- pytorch_lightning/utilities/model_helpers.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 8596f1c67b812..2c6f574da0da1 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -36,16 +36,6 @@ def is_overridden( # if `self.lightning_module` was passed as instance, it can be `None` return False - if parent is None: - if isinstance(instance, pl.LightningModule): - parent = pl.LightningModule - elif isinstance(instance, pl.LightningDataModule): - parent = pl.LightningDataModule - elif isinstance(instance, pl.Callback): - parent = pl.Callback - if parent is None: - raise ValueError("Expected a parent") - instance_attr = getattr(instance, method_name, None) # `functools.wraps()` support if hasattr(instance_attr, "__wrapped__"): @@ -60,6 +50,16 @@ def is_overridden( if instance_attr is None: return False + if parent is None: + if isinstance(instance, pl.LightningModule): + parent = pl.LightningModule + elif isinstance(instance, pl.LightningDataModule): + parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback + if parent is None: + raise ValueError("Expected a parent") + parent_attr = getattr(parent, method_name, None) if parent_attr is None: raise ValueError("The parent should define the method") From c6fdb6a8ff28ed569f9592e5f8ff1c9b58fc5cfe Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Thu, 2 Sep 2021 21:15:36 +0000 Subject: [PATCH 12/17] raise keyboardinterrupt --- pytorch_lightning/trainer/trainer.py | 3 ++- tests/trainer/test_trainer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 252b7cebb50d0..a0bc4c6e47f47 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -513,6 +513,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.state.status = TrainerStatus.INTERRUPTED self.on_keyboard_interrupt() self.on_exception(exception) + raise exception except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: @@ -522,7 +523,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: # reset bookkeeping self.state.stage = None self.on_exception(exception) - raise + raise exception def fit( self, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index da092aa602a25..5337dca750ab4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -876,7 +876,8 @@ def on_keyboard_interrupt(self, trainer, pl_module): assert not trainer.interrupted assert handle_interrupt_callback.exception is None assert handle_interrupt_callback.exc_info is None - trainer.fit(model) + with pytest.raises(KeyboardInterrupt): + trainer.fit(model) assert trainer.interrupted assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) From 97b53f162013e16803b746b2e8babba76f360d54 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Thu, 2 Sep 2021 23:26:47 +0000 Subject: [PATCH 13/17] fix tests --- cluster | 0 pytorch_lightning/trainer/trainer.py | 3 +-- pytorch_lightning/utilities/model_helpers.py | 20 ++++++++++---------- tests/callbacks/test_callbacks.py | 6 +++--- 4 files changed, 14 insertions(+), 15 deletions(-) create mode 100644 cluster diff --git a/cluster b/cluster new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a0bc4c6e47f47..2350ba302c149 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -505,7 +505,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 + # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once @@ -513,7 +513,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.state.status = TrainerStatus.INTERRUPTED self.on_keyboard_interrupt() self.on_exception(exception) - raise exception except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 2c6f574da0da1..8596f1c67b812 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -36,6 +36,16 @@ def is_overridden( # if `self.lightning_module` was passed as instance, it can be `None` return False + if parent is None: + if isinstance(instance, pl.LightningModule): + parent = pl.LightningModule + elif isinstance(instance, pl.LightningDataModule): + parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback + if parent is None: + raise ValueError("Expected a parent") + instance_attr = getattr(instance, method_name, None) # `functools.wraps()` support if hasattr(instance_attr, "__wrapped__"): @@ -50,16 +60,6 @@ def is_overridden( if instance_attr is None: return False - if parent is None: - if isinstance(instance, pl.LightningModule): - parent = pl.LightningModule - elif isinstance(instance, pl.LightningDataModule): - parent = pl.LightningDataModule - elif isinstance(instance, pl.Callback): - parent = pl.Callback - if parent is None: - raise ValueError("Expected a parent") - parent_attr = getattr(parent, method_name, None) if parent_attr is None: raise ValueError("The parent should define the method") diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index c363638d565d2..5803db051c659 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -26,8 +26,8 @@ def test_callbacks_configured_in_model(tmpdir): """Test the callback system with callbacks added through the model hook.""" - model_callback_mock = Mock() - trainer_callback_mock = Mock() + model_callback_mock = Mock(spec=Callback, model=Callback()) + trainer_callback_mock = Mock(spec=Callback, model=Callback()) class TestModel(BoringModel): def configure_callbacks(self): @@ -79,7 +79,7 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback): def test_configure_callbacks_hook_multiple_calls(tmpdir): """Test that subsequent calls to `configure_callbacks` do not change the callbacks list.""" - model_callback_mock = Mock() + model_callback_mock = Mock(spec=Callback, model=Callback()) class TestModel(BoringModel): def configure_callbacks(self): From f98ea3fa847f1e6a31568b0269ec720d1c2dc3bd Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Thu, 2 Sep 2021 23:29:23 +0000 Subject: [PATCH 14/17] fix test --- tests/trainer/test_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5337dca750ab4..da092aa602a25 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -876,8 +876,7 @@ def on_keyboard_interrupt(self, trainer, pl_module): assert not trainer.interrupted assert handle_interrupt_callback.exception is None assert handle_interrupt_callback.exc_info is None - with pytest.raises(KeyboardInterrupt): - trainer.fit(model) + trainer.fit(model) assert trainer.interrupted assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) From 26f484ed942529a9a49c1d349f2cfc9d3daab5d4 Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Thu, 2 Sep 2021 16:32:53 -0700 Subject: [PATCH 15/17] Delete cluster --- cluster | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 cluster diff --git a/cluster b/cluster deleted file mode 100644 index e69de29bb2d1d..0000000000000 From cef8d4dbb8e74ef7295c93f10b34c0fab12711af Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Thu, 2 Sep 2021 23:36:50 +0000 Subject: [PATCH 16/17] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 67dcc20cb59b1..db8a7b74bc749 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -174,7 +174,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) +- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260)) - Deprecated passing `process_position` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `process_position` directly to the list of callbacks ([#9222](https://github.com/PyTorchLightning/pytorch-lightning/pull/9222)) From 21b8abe84ca9f844c6f6724b9f0a98a30493e492 Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Sun, 5 Sep 2021 21:24:39 -0700 Subject: [PATCH 17/17] change raise exception to raise MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2350ba302c149..b94321a4a1d70 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -522,7 +522,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: # reset bookkeeping self.state.stage = None self.on_exception(exception) - raise exception + raise def fit( self,