|
14 | 14 |
|
15 | 15 | from abc import ABC |
16 | 16 | from copy import deepcopy |
17 | | -from typing import List |
| 17 | +from inspect import signature |
| 18 | +from typing import List, Dict, Any, Type, Callable |
18 | 19 |
|
19 | 20 | from pytorch_lightning.callbacks import Callback |
20 | 21 | from pytorch_lightning.core.lightning import LightningModule |
| 22 | +from pytorch_lightning.utilities import rank_zero_warn |
21 | 23 |
|
22 | 24 |
|
23 | 25 | class TrainerCallbackHookMixin(ABC): |
@@ -197,14 +199,29 @@ def on_keyboard_interrupt(self): |
197 | 199 | for callback in self.callbacks: |
198 | 200 | callback.on_keyboard_interrupt(self, self.lightning_module) |
199 | 201 |
|
200 | | - def on_save_checkpoint(self): |
| 202 | + @staticmethod |
| 203 | + def __is_old_signature(fn: Callable) -> bool: |
| 204 | + parameters = list(signature(fn).parameters) |
| 205 | + if len(parameters) == 2 and parameters[1] != "args": |
| 206 | + return True |
| 207 | + return False |
| 208 | + |
| 209 | + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: |
201 | 210 | """Called when saving a model checkpoint.""" |
202 | 211 | callback_states = {} |
203 | 212 | for callback in self.callbacks: |
204 | | - callback_class = type(callback) |
205 | | - state = callback.on_save_checkpoint(self, self.lightning_module) |
| 213 | + if self.__is_old_signature(callback.on_save_checkpoint): |
| 214 | + rank_zero_warn( |
| 215 | + "`Callback.on_save_checkpoint` signature has changed in v1.3." |
| 216 | + " A `checkpoint` parameter has been added." |
| 217 | + " Support for the old signature will be removed in v1.5", |
| 218 | + DeprecationWarning |
| 219 | + ) |
| 220 | + state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled |
| 221 | + else: |
| 222 | + state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) |
206 | 223 | if state: |
207 | | - callback_states[callback_class] = state |
| 224 | + callback_states[type(callback)] = state |
208 | 225 | return callback_states |
209 | 226 |
|
210 | 227 | def on_load_checkpoint(self, checkpoint): |
|
0 commit comments