From 8ce82a3a45fb74345c55af68fb945a72225a453a Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sat, 28 Aug 2021 23:16:18 +0000 Subject: [PATCH 1/8] add on_exception callback hook --- CHANGELOG.md | 7 ++- pytorch_lightning/callbacks/base.py | 4 ++ pytorch_lightning/trainer/callback_hook.py | 5 ++ .../trainer/configuration_validator.py | 14 +++++- pytorch_lightning/trainer/trainer.py | 8 +++- pytorch_lightning/utilities/model_helpers.py | 2 + tests/deprecated_api/test_remove_1-7.py | 24 ++++++++++ tests/trainer/test_trainer.py | 46 +++++++++++++++++++ 8 files changed, 105 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d3dcedbe003a..5158617ada9a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -97,9 +97,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) -- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) +- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) +- Added `on_exception` callback hook ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -161,13 +162,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) + - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) -- - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) +- Deprecated `on_keyboard_interrupt` callback hook + ### Removed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index fdb22a44ed307..114519f43bdff 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -267,6 +267,10 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM """Called when the training is interrupted by ``KeyboardInterrupt``.""" pass + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + """Called when the training is interrupted by any exception.""" + pass + def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index bbfcbb22802a8..9c35ac64c8b96 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -236,6 +236,11 @@ def on_keyboard_interrupt(self): for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.lightning_module) + def on_exception(self, exception: BaseException) -> None: + """Called when the training is interrupted by any exception.""" + for callback in self.callbacks: + callback.on_exception(self, self.lightning_module, exception) + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: """Called when saving a model checkpoint.""" callback_states = {} diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index e4a3c6f8c00eb..9aa8aea9f4c48 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -43,6 +44,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) + # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 + self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -201,3 +204,12 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) + + def _check_on_keyboard_interrupt(self) -> None: + """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" + for callback in self.trainer.callbacks: + if is_overridden(method_name="on_keyboard_interrupt", instance=callback): + rank_zero_deprecation( + "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." + "Please use the `on_exception` callback hook instead." + ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f56a84b1d294d..f936aa20dd186 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,13 +500,15 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - except KeyboardInterrupt: + # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) + except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once if not self.interrupted: self.state.status = TrainerStatus.INTERRUPTED self.on_keyboard_interrupt() - except BaseException: + self.on_exception(exception) + except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: # try syncing remaing processes, kill otherwise @@ -514,8 +516,10 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self._on_exception() # reset bookkeeping self.state.stage = None + self.on_exception(exception) raise + def fit( self, model: "pl.LightningModule", diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 729a94e15da11..8596f1c67b812 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,6 +41,8 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule + elif isinstance(instance, pl.Callback): + parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 8c7b1a00d13d4..a04fc85026900 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -21,6 +21,7 @@ from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule +from pytorch_lightning import Callback def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): @@ -116,3 +117,26 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir): def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): _ = TestTubeLogger(tmpdir) + + +def test_v1_7_0_on_interrupt(tmpdir): + class HandleInterruptCallback(Callback): + def on_keyboard_interrupt(self, trainer, pl_module): + self.log("keyboard interrupt") + + model = BoringModel() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" + ): + trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5adbff1e41fd0..a7dfcc5573179 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -873,6 +873,52 @@ def on_keyboard_interrupt(self, trainer, pl_module): assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) +def test_on_exception_hook(tmpdir): + """Test the on_exception callback hook.""" + + model = EvalModelTemplate() + + class InterruptCallback(Callback): + def __init__(self): + super().__init__() + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + raise KeyboardInterrupt + + def on_test_start(self, trainer, pl_module): + raise MisconfigurationException + + class HandleInterruptCallback(Callback): + def __init__(self): + super().__init__() + self.exception = None + + def on_exception(self, trainer, pl_module, exception): + self.exception = exception + + interrupt_callback = InterruptCallback() + handle_interrupt_callback = HandleInterruptCallback() + + trainer = Trainer( + callbacks=[interrupt_callback, handle_interrupt_callback], + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + progress_bar_refresh_rate=0, + logger=False, + default_root_dir=tmpdir, + ) + assert not trainer.interrupted + assert handle_interrupt_callback.exception is None + trainer.fit(model) + assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) + with pytest.raises(MisconfigurationException): + trainer.test(model) + assert trainer.interrupted + assert isinstance(handle_interrupt_callback.exception, MisconfigurationException) + + @pytest.mark.parametrize( "precision", [32, pytest.param(16, marks=RunIf(min_gpus=1, amp_native=True))], From cddb35859cb4ff836169b388ac0b6c9a4d921a2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Aug 2021 05:36:26 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 1 - tests/deprecated_api/test_remove_1-7.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9aa8aea9f4c48..c3ce6030e7d2d 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f936aa20dd186..a0f99579537d8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -519,7 +519,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.on_exception(exception) raise - def fit( self, model: "pl.LightningModule", diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index a04fc85026900..77194f2b29600 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,12 +16,11 @@ import pytest -from pytorch_lightning import LightningDataModule, Trainer +from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.datamodules import MNISTDataModule -from pytorch_lightning import Callback def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): From 7e16fe97ed3fdaa872d536e83f2ae32159804856 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sun, 29 Aug 2021 05:37:09 +0000 Subject: [PATCH 3/8] small fixes/lints --- CHANGELOG.md | 5 +++-- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/trainer/configuration_validator.py | 1 - pytorch_lightning/trainer/trainer.py | 3 +-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5158617ada9a8..2d4e6ceaff5cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,7 +100,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084)) -- Added `on_exception` callback hook +- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) @@ -169,7 +170,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook +- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) ### Removed diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 220ac589f130c..479a2eae0f8bc 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from pytorch_lightning.utilities import move_data_to_device from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation class ModelHooks: diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9aa8aea9f4c48..b37e2b5b51822 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.base import Callback import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f936aa20dd186..6654e181d3e7e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,7 +500,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) + # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once @@ -519,7 +519,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: self.on_exception(exception) raise - def fit( self, model: "pl.LightningModule", From 58318c338768c051f0d492c72cea2ec7e1f7def2 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Sun, 29 Aug 2021 22:36:20 +0000 Subject: [PATCH 4/8] separate prs --- CHANGELOG.md | 3 --- pytorch_lightning/callbacks/base.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 2 +- .../trainer/configuration_validator.py | 14 +---------- pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/utilities/model_helpers.py | 2 -- tests/deprecated_api/test_remove_1-7.py | 25 +------------------ tests/trainer/test_trainer.py | 2 +- 8 files changed, 5 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d4e6ceaff5cb..761187b4333a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -163,15 +163,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065)) - - Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098) - Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162)) -- Deprecated `on_keyboard_interrupt` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183)) - ### Removed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 114519f43bdff..b57c87500be46 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -268,7 +268,7 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM pass def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: - """Called when the training is interrupted by any exception.""" + """Called when any trainer execution is interrupted by an exception.""" pass def on_save_checkpoint( diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 9c35ac64c8b96..db0e59c4748ae 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -237,7 +237,7 @@ def on_keyboard_interrupt(self): callback.on_keyboard_interrupt(self, self.lightning_module) def on_exception(self, exception: BaseException) -> None: - """Called when the training is interrupted by any exception.""" + """Called when any trainer execution is interrupted by an exception.""" for callback in self.callbacks: callback.on_exception(self, self.lightning_module, exception) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index c3ce6030e7d2d..752cedd173be8 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytorch_lightning as pl -from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: @@ -44,8 +43,6 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None: elif self.trainer.state.fn == TrainerFn.PREDICTING: self.__verify_predict_loop_configuration(model) self.__verify_dp_batch_transfer_support(model) - # TODO(@daniellepintz): Delete _check_on_keyboard_interrupt in v1.7 - self._check_on_keyboard_interrupt() def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -204,12 +201,3 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod "The model taking a `dataloader_iter` argument in your `training_step` " "is incompatible with `truncated_bptt_steps > 0`." ) - - def _check_on_keyboard_interrupt(self) -> None: - """Checks if on_keyboard_interrupt is overriden and sends a deprecation warning.""" - for callback in self.trainer.callbacks: - if is_overridden(method_name="on_keyboard_interrupt", instance=callback): - rank_zero_deprecation( - "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." - "Please use the `on_exception` callback hook instead." - ) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6654e181d3e7e..7709f34487d1d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -500,7 +500,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: return trainer_fn(*args, **kwargs) - # TODO(@daniellepintz): treat KeyboardInterrupt as BaseException (delete the code below) in v1.7 except KeyboardInterrupt as exception: rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...") # user could press Ctrl+c many times... only shutdown once diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index 8596f1c67b812..729a94e15da11 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -41,8 +41,6 @@ def is_overridden( parent = pl.LightningModule elif isinstance(instance, pl.LightningDataModule): parent = pl.LightningDataModule - elif isinstance(instance, pl.Callback): - parent = pl.Callback if parent is None: raise ValueError("Expected a parent") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 77194f2b29600..8c7b1a00d13d4 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -16,7 +16,7 @@ import pytest -from pytorch_lightning import Callback, LightningDataModule, Trainer +from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.loggers import TestTubeLogger from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel @@ -116,26 +116,3 @@ def test_v1_7_0_deprecated_on_train_dataloader(tmpdir): def test_v1_7_0_test_tube_logger(_, tmpdir): with pytest.deprecated_call(match="The TestTubeLogger is deprecated since v1.5 and will be removed in v1.7"): _ = TestTubeLogger(tmpdir) - - -def test_v1_7_0_on_interrupt(tmpdir): - class HandleInterruptCallback(Callback): - def on_keyboard_interrupt(self, trainer, pl_module): - self.log("keyboard interrupt") - - model = BoringModel() - handle_interrupt_callback = HandleInterruptCallback() - - trainer = Trainer( - callbacks=[handle_interrupt_callback], - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - progress_bar_refresh_rate=0, - logger=False, - default_root_dir=tmpdir, - ) - with pytest.deprecated_call( - match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7" - ): - trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a7dfcc5573179..32b57b48acf4d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -876,7 +876,7 @@ def on_keyboard_interrupt(self, trainer, pl_module): def test_on_exception_hook(tmpdir): """Test the on_exception callback hook.""" - model = EvalModelTemplate() + model = BoringModel() class InterruptCallback(Callback): def __init__(self): From db516e3f99a2d8a66177bc7890cb9c7eed08c9b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Aug 2021 22:37:31 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/configuration_validator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 752cedd173be8..cd7acfdc8a526 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -13,10 +13,10 @@ # limitations under the License. import pytorch_lightning as pl from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class ConfigValidator: From 45c2f3b5694512b64feb095138ff773fe5334ace Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 30 Aug 2021 19:47:13 +0000 Subject: [PATCH 6/8] update test --- tests/trainer/test_trainer.py | 45 +++++------------------------------ 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 32b57b48acf4d..4bf705ee7e7d7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -834,45 +834,6 @@ def on_after_backward(self): assert not torch.isfinite(params).all() -def test_trainer_interrupted_flag(tmpdir): - """Test the flag denoting that a user interrupted training.""" - - model = EvalModelTemplate() - - class InterruptCallback(Callback): - def __init__(self): - super().__init__() - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - raise KeyboardInterrupt - - class HandleInterruptCallback(Callback): - def __init__(self): - super().__init__() - self.exc_info = None - - def on_keyboard_interrupt(self, trainer, pl_module): - self.exc_info = sys.exc_info() - - interrupt_callback = InterruptCallback() - handle_interrupt_callback = HandleInterruptCallback() - - trainer = Trainer( - callbacks=[interrupt_callback, handle_interrupt_callback], - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - progress_bar_refresh_rate=0, - logger=False, - default_root_dir=tmpdir, - ) - assert not trainer.interrupted - assert handle_interrupt_callback.exc_info is None - trainer.fit(model) - assert trainer.interrupted - assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) - - def test_on_exception_hook(tmpdir): """Test the on_exception callback hook.""" @@ -892,10 +853,14 @@ class HandleInterruptCallback(Callback): def __init__(self): super().__init__() self.exception = None + self.exc_info = None def on_exception(self, trainer, pl_module, exception): self.exception = exception + def on_keyboard_interrupt(self, trainer, pl_module): + self.exc_info = sys.exc_info() + interrupt_callback = InterruptCallback() handle_interrupt_callback = HandleInterruptCallback() @@ -910,9 +875,11 @@ def on_exception(self, trainer, pl_module, exception): ) assert not trainer.interrupted assert handle_interrupt_callback.exception is None + assert handle_interrupt_callback.exc_info is None trainer.fit(model) assert trainer.interrupted assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt) + assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt) with pytest.raises(MisconfigurationException): trainer.test(model) assert trainer.interrupted From e3f96a5308f77aea43a4f059aa23dd26cd7f01d5 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 30 Aug 2021 19:51:25 +0000 Subject: [PATCH 7/8] update test --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4bf705ee7e7d7..c16a172962aa1 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -835,7 +835,7 @@ def on_after_backward(self): def test_on_exception_hook(tmpdir): - """Test the on_exception callback hook.""" + """Test the on_exception callback hook and the trainer interrupted flag.""" model = BoringModel() From 05f6bef2bf05603188aa61101a2ab652cdc335f5 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Mon, 30 Aug 2021 21:05:54 +0000 Subject: [PATCH 8/8] fix failing tests --- pytorch_lightning/callbacks/lambda_function.py | 1 + .../trainer/connectors/logger_connector/fx_validator.py | 1 + tests/trainer/logging_/test_logger_connector.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py index ca9af484dbc0c..1813e7d19090f 100644 --- a/pytorch_lightning/callbacks/lambda_function.py +++ b/pytorch_lightning/callbacks/lambda_function.py @@ -75,6 +75,7 @@ def __init__( on_test_start: Optional[Callable] = None, on_test_end: Optional[Callable] = None, on_keyboard_interrupt: Optional[Callable] = None, + on_exception: Optional[Callable] = None, on_save_checkpoint: Optional[Callable] = None, on_load_checkpoint: Optional[Callable] = None, on_before_backward: Optional[Callable] = None, diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 871ba0fa86c96..50c237db80cce 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -62,6 +62,7 @@ class FxValidator: on_predict_batch_start=None, on_predict_batch_end=None, on_keyboard_interrupt=None, + on_exception=None, on_save_checkpoint=None, on_load_checkpoint=None, setup=None, diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index ed7711b32ffda..a0aca41d1397e 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -49,6 +49,7 @@ def test_fx_validator(tmpdir): "on_init_end", "on_init_start", "on_keyboard_interrupt", + "on_exception", "on_load_checkpoint", "on_pretrain_routine_end", "on_pretrain_routine_start", @@ -91,6 +92,7 @@ def test_fx_validator(tmpdir): "on_init_end", "on_init_start", "on_keyboard_interrupt", + "on_exception", "on_load_checkpoint", "on_pretrain_routine_end", "on_pretrain_routine_start",