Skip to content
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))


- Add support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))


- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))

### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
Expand Down Expand Up @@ -163,7 +165,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)

-

- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
pass

def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    """Called when any trainer execution is interrupted by an exception.

    No-op in the base ``Callback``; subclasses override this to react to the
    raised ``exception`` (e.g. cleanup or logging) before the trainer re-raises it.
    """
    pass

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> dict:
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
on_test_start: Optional[Callable] = None,
on_test_end: Optional[Callable] = None,
on_keyboard_interrupt: Optional[Callable] = None,
on_exception: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.warnings import rank_zero_deprecation


class ModelHooks:
Expand Down
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_exception(self, exception: BaseException) -> None:
    """Called when any trainer execution is interrupted by an exception.

    Fans the ``exception`` out to every registered callback's ``on_exception``
    hook, passing the trainer (``self``) and the current ``lightning_module``.
    """
    for callback in self.callbacks:
        callback.on_exception(self, self.lightning_module, exception)

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
"""Called when saving a model checkpoint."""
callback_states = {}
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn


class ConfigValidator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class FxValidator:
on_predict_batch_start=None,
on_predict_batch_end=None,
on_keyboard_interrupt=None,
on_exception=None,
on_save_checkpoint=None,
on_load_checkpoint=None,
setup=None,
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,20 +500,22 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
"""
try:
return trainer_fn(*args, **kwargs)
except KeyboardInterrupt:
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not self.interrupted:
self.state.status = TrainerStatus.INTERRUPTED
self.on_keyboard_interrupt()
except BaseException:
self.on_exception(exception)
except BaseException as exception:
self.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and self.world_size > 1:
# try syncing remaing processes, kill otherwise
self.training_type_plugin.reconciliate_processes(traceback.format_exc())
self._on_exception()
# reset bookkeeping
self.state.stage = None
self.on_exception(exception)
raise

def fit(
Expand Down
2 changes: 2 additions & 0 deletions tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_fx_validator(tmpdir):
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
"on_exception",
"on_load_checkpoint",
"on_pretrain_routine_end",
"on_pretrain_routine_start",
Expand Down Expand Up @@ -91,6 +92,7 @@ def test_fx_validator(tmpdir):
"on_init_end",
"on_init_start",
"on_keyboard_interrupt",
"on_exception",
"on_load_checkpoint",
"on_pretrain_routine_end",
"on_pretrain_routine_start",
Expand Down
19 changes: 16 additions & 3 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,10 @@ def on_after_backward(self):
assert not torch.isfinite(params).all()


def test_trainer_interrupted_flag(tmpdir):
"""Test the flag denoting that a user interrupted training."""
def test_on_exception_hook(tmpdir):
"""Test the on_exception callback hook and the trainer interrupted flag."""

model = EvalModelTemplate()
model = BoringModel()

class InterruptCallback(Callback):
def __init__(self):
Expand All @@ -846,11 +846,18 @@ def __init__(self):
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
raise KeyboardInterrupt

def on_test_start(self, trainer, pl_module):
raise MisconfigurationException

class HandleInterruptCallback(Callback):
def __init__(self):
super().__init__()
self.exception = None
self.exc_info = None

def on_exception(self, trainer, pl_module, exception):
self.exception = exception

def on_keyboard_interrupt(self, trainer, pl_module):
self.exc_info = sys.exc_info()

Expand All @@ -867,10 +874,16 @@ def on_keyboard_interrupt(self, trainer, pl_module):
default_root_dir=tmpdir,
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
assert handle_interrupt_callback.exc_info is None
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
with pytest.raises(MisconfigurationException):
trainer.test(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)


@pytest.mark.parametrize(
Expand Down