Skip to content

Commit 912fd31

Browse files
daniellepintzpre-commit-ci[bot]Bordaawaelchlicarmocca
authored
Deprecate on_keyboard_interrupt callback hook (#9260)
* add on_exception callback hook * deprecate on_keyboard_interrupt * Apply suggestions from code review * raise keyboard interrupt * Delete cluster * update changelog Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 49c0485 commit 912fd31

File tree

8 files changed

+57
-6
lines changed

8 files changed

+57
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
171171

172172
- Deprecated the `TestTubeLogger` ([#9065](https://github.com/PyTorchLightning/pytorch-lightning/pull/9065))
173173

174+
174175
- Deprecated `on_{train/val/test/predict}_dataloader()` from `LightningModule` and `LightningDataModule` [#9098](https://github.com/PyTorchLightning/pytorch-lightning/pull/9098)
175176

176177

177178
- Updated deprecation of `argparse_utils.py` from removal in 1.4 to 2.0 ([#9162](https://github.com/PyTorchLightning/pytorch-lightning/pull/9162))
178179

179180

181+
- Deprecated `on_keyboard_interrupt` callback hook in favor of new `on_exception` hook ([#9260](https://github.com/PyTorchLightning/pytorch-lightning/pull/9260))
182+
183+
180184
- Deprecated passing `process_position` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `process_position` directly to the list of callbacks ([#9222](https://github.com/PyTorchLightning/pytorch-lightning/pull/9222))
181185

182186

pytorch_lightning/callbacks/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,12 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
264264
pass
265265

266266
def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
267-
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
267+
r"""
268+
.. deprecated:: v1.5
269+
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
270+
271+
Called when any trainer execution is interrupted by KeyboardInterrupt.
272+
"""
268273
pass
269274

270275
def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:

pytorch_lightning/trainer/callback_hook.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,12 @@ def on_predict_end(self) -> None:
232232
callback.on_predict_end(self, self.lightning_module)
233233

234234
def on_keyboard_interrupt(self):
235-
"""Called when the training is interrupted by KeyboardInterrupt."""
235+
r"""
236+
.. deprecated:: v1.5
237+
This callback hook was deprecated in v1.5 in favor of `on_exception` and will be removed in v1.7.
238+
239+
Called when any trainer execution is interrupted by KeyboardInterrupt.
240+
"""
236241
for callback in self.callbacks:
237242
callback.on_keyboard_interrupt(self, self.lightning_module)
238243

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def verify_loop_configurations(self, model: "pl.LightningModule") -> None:
4343
elif self.trainer.state.fn == TrainerFn.PREDICTING:
4444
self.__verify_predict_loop_configuration(model)
4545
self.__verify_dp_batch_transfer_support(model)
46+
# TODO: Delete _check_on_keyboard_interrupt in v1.7
47+
self._check_on_keyboard_interrupt()
4648

4749
def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
4850
# -----------------------------------
@@ -201,3 +203,12 @@ def __check_training_step_requires_dataloader_iter(self, model: "pl.LightningMod
201203
"The model taking a `dataloader_iter` argument in your `training_step` "
202204
"is incompatible with `truncated_bptt_steps > 0`."
203205
)
206+
207+
def _check_on_keyboard_interrupt(self) -> None:
208+
"""Checks if on_keyboard_interrupt is overriden and sends a deprecation warning."""
209+
for callback in self.trainer.callbacks:
210+
if is_overridden(method_name="on_keyboard_interrupt", instance=callback):
211+
rank_zero_deprecation(
212+
"The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
213+
" Please use the `on_exception` callback hook instead."
214+
)

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
509509
"""
510510
try:
511511
return trainer_fn(*args, **kwargs)
512+
# TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
512513
except KeyboardInterrupt as exception:
513514
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
514515
# user could press Ctrl+c many times... only shutdown once

pytorch_lightning/utilities/model_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def is_overridden(
4141
parent = pl.LightningModule
4242
elif isinstance(instance, pl.LightningDataModule):
4343
parent = pl.LightningDataModule
44+
elif isinstance(instance, pl.Callback):
45+
parent = pl.Callback
4446
if parent is None:
4547
raise ValueError("Expected a parent")
4648

tests/callbacks/test_callbacks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
def test_callbacks_configured_in_model(tmpdir):
2727
"""Test the callback system with callbacks added through the model hook."""
2828

29-
model_callback_mock = Mock()
30-
trainer_callback_mock = Mock()
29+
model_callback_mock = Mock(spec=Callback, model=Callback())
30+
trainer_callback_mock = Mock(spec=Callback, model=Callback())
3131

3232
class TestModel(BoringModel):
3333
def configure_callbacks(self):
@@ -79,7 +79,7 @@ def assert_expected_calls(_trainer, model_callback, trainer_callback):
7979

8080
def test_configure_callbacks_hook_multiple_calls(tmpdir):
8181
"""Test that subsequent calls to `configure_callbacks` do not change the callbacks list."""
82-
model_callback_mock = Mock()
82+
model_callback_mock = Mock(spec=Callback, model=Callback())
8383

8484
class TestModel(BoringModel):
8585
def configure_callbacks(self):

tests/deprecated_api/test_remove_1-7.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818

19-
from pytorch_lightning import LightningDataModule, Trainer
19+
from pytorch_lightning import Callback, LightningDataModule, Trainer
2020
from pytorch_lightning.loggers import TestTubeLogger
2121
from tests.deprecated_api import _soft_unimport_module
2222
from tests.helpers import BoringModel
@@ -118,6 +118,29 @@ def test_v1_7_0_test_tube_logger(_, tmpdir):
118118
_ = TestTubeLogger(tmpdir)
119119

120120

121+
def test_v1_7_0_on_interrupt(tmpdir):
122+
class HandleInterruptCallback(Callback):
123+
def on_keyboard_interrupt(self, trainer, pl_module):
124+
print("keyboard interrupt")
125+
126+
model = BoringModel()
127+
handle_interrupt_callback = HandleInterruptCallback()
128+
129+
trainer = Trainer(
130+
callbacks=[handle_interrupt_callback],
131+
max_epochs=1,
132+
limit_val_batches=0.1,
133+
limit_train_batches=0.2,
134+
progress_bar_refresh_rate=0,
135+
logger=False,
136+
default_root_dir=tmpdir,
137+
)
138+
with pytest.deprecated_call(
139+
match="The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7"
140+
):
141+
trainer.fit(model)
142+
143+
121144
def test_v1_7_0_process_position_trainer_constructor(tmpdir):
122145
with pytest.deprecated_call(match=r"Setting `Trainer\(process_position=5\)` is deprecated in v1.5"):
123146
_ = Trainer(process_position=5)

0 commit comments

Comments
 (0)