Skip to content

Commit a61cc72

Browse files
authored
Merge pull request #9347 from PyTorchLightning/bugfix/timer-on-train-end
fix signature in callbacks to prevent deprecation warning
1 parent 645eabe commit a61cc72

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8+
## [1.4.6] - unreleased
9+
10+
- Fixed signature of `Timer.on_train_epoch_end` and `StochasticWeightAveraging.on_train_epoch_end` to prevent unwanted deprecation warnings ([#9347](https://github.com/PyTorchLightning/pytorch-lightning/pull/9347))
11+
12+
813
## [1.4.5] - 2021-08-31
914

1015
- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
216216

217217
trainer.accumulate_grad_batches = trainer.num_training_batches
218218

219-
def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
219+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None):
220220
trainer.fit_loop._skip_backward = False
221221

222222
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):

pytorch_lightning/callbacks/timer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
148148
return
149149
self._check_time_remaining(trainer)
150150

151-
def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
151+
def on_train_epoch_end(
152+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
153+
) -> None:
152154
if self._interval != Interval.epoch or self._duration is None:
153155
return
154156
self._check_time_remaining(trainer)

tests/callbacks/test_timer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2525
from tests.helpers import BoringModel
2626
from tests.helpers.runif import RunIf
27+
from tests.helpers.utils import no_warning_call
2728

2829

2930
def test_trainer_flag(caplog):
@@ -106,7 +107,9 @@ def test_timer_stops_training(tmpdir, caplog):
106107
timer = Timer(duration=duration)
107108

108109
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer])
109-
with caplog.at_level(logging.INFO):
110+
with caplog.at_level(logging.INFO), no_warning_call(
111+
DeprecationWarning, match="The signature of `Callback.on_train_epoch_end` has changed in v1.3"
112+
):
110113
trainer.fit(model)
111114
assert trainer.global_step > 1
112115
assert trainer.current_epoch < 999

0 commit comments

Comments
 (0)