Commit 986eec2

Merge ba013f0 into 46617d9
2 parents: 46617d9 + ba013f0

File tree: 13 files changed (+144, -38 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
 
+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+
 ### Changed

pytorch_lightning/callbacks/base.py

Lines changed: 19 additions & 5 deletions
@@ -17,7 +17,7 @@
 """
 
 import abc
-from typing import Any
+from typing import Any, Dict
 
 from pytorch_lightning.core.lightning import LightningModule
 
@@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training is interrupted by ``KeyboardInterrupt``."""
         pass
 
-    def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
-        """Called when saving a model checkpoint, use to persist state."""
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
+        """
+        Called when saving a model checkpoint, use to persist state.
+
+        Args:
+            trainer: the current Trainer instance.
+            pl_module: the current LightningModule instance.
+            checkpoint: the checkpoint dictionary that will be saved.
+
+        Returns:
+            The callback state.
+        """
         pass
 
-    def on_load_checkpoint(self, checkpointed_state) -> None:
-        """Called when loading a model checkpoint, use to reload state."""
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
+        """Called when loading a model checkpoint, use to reload state.
+
+        Args:
+            callback_state: the callback state returned by ``on_save_checkpoint``.
+        """
         pass
 
     def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
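
For context, a minimal sketch of a user callback written against the new hook signatures. The class name and the tracked counter are hypothetical and not part of this commit; only the hook signatures follow the diff above.

from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class BatchCounter(Callback):
    """Hypothetical callback that persists a simple counter across checkpoints."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, *args, **kwargs):
        self.batches_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> dict:
        # The full checkpoint dict is now passed in and can be inspected;
        # the returned dict is stored as this callback's state.
        return {'batches_seen': self.batches_seen}

    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        # Receives exactly what on_save_checkpoint returned.
        self.batches_seen = callback_state['batches_seen']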

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 7 additions & 6 deletions
@@ -18,6 +18,7 @@
 Monitor a metric and stop training when it stops improving.
 
 """
+from typing import Any, Dict
 
 import numpy as np
 import torch
@@ -140,19 +141,19 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self):
         return self.mode_dict[self.mode]
 
-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,
             'best_score': self.best_score,
             'patience': self.patience
         }
 
-    def on_load_checkpoint(self, checkpointed_state):
-        self.wait_count = checkpointed_state['wait_count']
-        self.stopped_epoch = checkpointed_state['stopped_epoch']
-        self.best_score = checkpointed_state['best_score']
-        self.patience = checkpointed_state['patience']
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.wait_count = callback_state['wait_count']
+        self.stopped_epoch = callback_state['stopped_epoch']
+        self.best_score = callback_state['best_score']
+        self.patience = callback_state['patience']
 
     def on_validation_end(self, trainer, pl_module):
         if trainer.running_sanity_check:

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 4 additions & 4 deletions
@@ -191,7 +191,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         self.save_checkpoint(trainer, pl_module)
 
-    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
             "best_model_score": self.best_model_score,
@@ -200,9 +200,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "dirpath": self.dirpath
         }
 
-    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
-        self.best_model_score = checkpointed_state["best_model_score"]
-        self.best_model_path = checkpointed_state["best_model_path"]
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.best_model_score = callback_state["best_model_score"]
+        self.best_model_path = callback_state["best_model_path"]
 
     def save_checkpoint(self, trainer, pl_module):
         """

pytorch_lightning/trainer/callback_hook.py

Lines changed: 22 additions & 5 deletions
@@ -14,10 +14,12 @@
 
 from abc import ABC
 from copy import deepcopy
-from typing import List
+from inspect import signature
+from typing import List, Dict, Any, Type, Callable
 
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class TrainerCallbackHookMixin(ABC):
@@ -197,14 +199,29 @@ def on_keyboard_interrupt(self):
         for callback in self.callbacks:
             callback.on_keyboard_interrupt(self, self.lightning_module)
 
-    def on_save_checkpoint(self):
+    @staticmethod
+    def __is_old_signature(fn: Callable) -> bool:
+        parameters = list(signature(fn).parameters)
+        if len(parameters) == 2 and parameters[1] != "args":
+            return True
+        return False
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
-            callback_class = type(callback)
-            state = callback.on_save_checkpoint(self, self.lightning_module)
+            if self.__is_old_signature(callback.on_save_checkpoint):
+                rank_zero_warn(
+                    "`Callback.on_save_checkpoint` signature has changed in v1.3."
+                    " A `checkpoint` parameter has been added."
+                    " Support for the old signature will be removed in v1.5",
+                    DeprecationWarning
+                )
+                state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
+            else:
+                state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback_class] = state
+                callback_states[type(callback)] = state
         return callback_states
 
     def on_load_checkpoint(self, checkpoint):
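
A rough illustration, not part of the diff, of how the `__is_old_signature` check above tells the two signatures apart: `inspect.signature` on a bound method drops `self`, so a pre-v1.3 hook exposes exactly two parameters, while the `parameters[1] != "args"` guard leaves hooks declared with `*args` (which can already accept the extra argument) untouched.

from inspect import signature


class OldStyleCallback:
    def on_save_checkpoint(self, trainer, pl_module):  # pre-v1.3 signature
        ...


class NewStyleCallback:
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):  # v1.3 signature
        ...


# `self` is excluded for bound methods, so the old hook shows two parameters
# and the new one shows three.
print(list(signature(OldStyleCallback().on_save_checkpoint).parameters))
# ['trainer', 'pl_module']               -> flagged as old, deprecation warning emitted
print(list(signature(NewStyleCallback().on_save_checkpoint).parameters))
# ['trainer', 'pl_module', 'checkpoint'] -> new signature, called with the checkpoint dict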

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 5 additions & 9 deletions
@@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if not has_reached_max_steps:
             current_epoch += 1
 
+        model = self.trainer.lightning_module
+
         checkpoint = {
             'epoch': current_epoch,
             'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
+            'state_dict': model.state_dict(),
         }
 
         if not weights_only:
-
             # dump callbacks
-            callback_states = self.trainer.on_save_checkpoint()
-            checkpoint['callbacks'] = callback_states
+            checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)
 
         optimizer_states = []
         for i, optimizer in enumerate(self.trainer.optimizers):
@@ -305,12 +306,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         elif self.trainer.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()
 
-        # add the hyper_parameters and state_dict from the model
-        model = self.trainer.lightning_module
-
-        # dump the module_arguments and state_dict from the model
-        checkpoint['state_dict'] = model.state_dict()
-
+        # dump hyper-parameters
        if model.hparams:
            if hasattr(model, '_hparams_name'):
                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
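
For orientation, a rough sketch of the dict that `dump_checkpoint` assembles after this change; the keys shown are those touched in this hunk, the values are placeholders. The model `state_dict` is now inserted before the callback hooks run, so callbacks receive the populated dict, and the states they return are stored under `checkpoint['callbacks']` keyed by callback class.

# Illustrative layout only; values are placeholders, not real output.
checkpoint = {
    'epoch': 1,
    'global_step': 100,
    'pytorch-lightning_version': '1.3.0',
    'state_dict': {},  # model.state_dict() in practice
    'callbacks': {
        # present only when weights_only is False; one entry per callback
        # class that returned a state, e.g.
        # EarlyStopping: {'wait_count': 0, 'stopped_epoch': 0, 'best_score': ..., 'patience': 3},
        # ModelCheckpoint: {'monitor': ..., 'best_model_score': ..., 'dirpath': ...},
    },
    # optimizer states, lr schedulers, amp state and hyper-parameters
    # are added by the code that follows this hunk.
}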

tests/callbacks/test_callbacks.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def test_trainer_callback_system(torch_save, tmpdir):
         call.on_validation_epoch_end(trainer, model),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
-        call.on_save_checkpoint(trainer, model),
+        call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
         call.on_train_end(trainer, model),
         call.on_fit_end(trainer, model),
         call.teardown(trainer, model, 'fit'),

tests/callbacks/test_early_stopping.py

Lines changed: 2 additions & 2 deletions
@@ -39,11 +39,11 @@ def __init__(self, expected_state, *args, **kwargs):
 
     def on_train_start(self, trainer, pl_module):
         if self.expected_state:
-            assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state
+            assert self.on_save_checkpoint(trainer, pl_module, {}) == self.expected_state
 
     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)
-        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())
+        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module, {}).copy())
 
 
 def test_resume_early_stopping_from_checkpoint(tmpdir):

tests/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -346,9 +346,9 @@ def __init__(self, expected_count, *args, **kwargs):
     def on_train_start(self, trainer, pl_module):
         torch.save = Mock(wraps=torch.save)
 
-    def on_save_checkpoint(self, trainer, pl_module):
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         # expect all ranks to run but only rank 0 will actually write the checkpoint file
-        super().on_save_checkpoint(trainer, pl_module)
+        super().on_save_checkpoint(trainer, pl_module, checkpoint)
         self.on_save_checkpoint_count += 1
 
     def on_train_end(self, trainer, pl_module):

tests/deprecated_api/test_remove_1-4.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Test deprecated functionality which will be removed in vX.Y.Z"""
+"""Test deprecated functionality which will be removed in v1.4.0"""
 import sys
 
 import pytest
@@ -243,5 +243,5 @@ def training_step(self, batch, batch_idx):
 
     trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)
 
-    with pytest.warns(DeprecationWarning, match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
+    with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
         trainer.fit(TestModel())
