|
| 1 | +import pytorch_lightning as pl |
| 2 | +from pytorch_lightning.utilities.migration.base import set_version, should_upgrade |
| 3 | + |
| 4 | + |
| 5 | +# v0.10.0 |
| 6 | +def migrate_model_checkpoint_early_stopping(checkpoint: dict) -> dict: |
| 7 | + from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
| 8 | + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
| 9 | + keys_mapping = { |
| 10 | + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), |
| 11 | + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), |
| 12 | + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), |
| 13 | + "early_stop_callback_wait": (EarlyStopping, "wait_count"), |
| 14 | + "early_stop_callback_patience": (EarlyStopping, "patience"), |
| 15 | + } |
| 16 | + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} |
| 17 | + |
| 18 | + for key, new_path in keys_mapping.items(): |
| 19 | + if key in checkpoint: |
| 20 | + value = checkpoint[key] |
| 21 | + callback_type, callback_key = new_path |
| 22 | + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} |
| 23 | + checkpoint["callbacks"][callback_type][callback_key] = value |
| 24 | + del checkpoint[key] |
| 25 | + return checkpoint |
| 26 | + |
| 27 | + |
| 28 | +# v1.3.1 |
| 29 | +def migrate_callback_state_identifiers(checkpoint): |
| 30 | + if "callbacks" not in checkpoint: |
| 31 | + return |
| 32 | + callbacks = checkpoint["callbacks"] |
| 33 | + checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items()) |
| 34 | + return checkpoint |
| 35 | + |
| 36 | + |
| 37 | +def migrate_checkpoint(checkpoint: dict): |
| 38 | + """ Applies all the above migrations in order. """ |
| 39 | + if should_upgrade(checkpoint, "0.10.0"): |
| 40 | + migrate_model_checkpoint_early_stopping(checkpoint) |
| 41 | + if should_upgrade(checkpoint, "1.3.0"): |
| 42 | + migrate_callback_state_identifiers(checkpoint) |
| 43 | + set_version(checkpoint, "1.3.0") |
| 44 | + set_version(checkpoint, pl.__version__) |
| 45 | + return checkpoint |
0 commit comments