 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
+"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version.
+
+When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially,
+see :func:`migrate_checkpoint`.
+"""
 
 import sys
-import threading
+from distutils.version import LooseVersion
 from types import ModuleType, TracebackType
+from typing import Any, Dict, Optional, Type
 
+import pytorch_lightning as pl
 import pytorch_lightning.utilities.argparse
 
-# Create a global lock to ensure no race condition with deleting sys modules
-_lock = threading.Lock()
+_CHECKPOINT = Dict[str, Any]
 
 
 class pl_legacy_patch:
     """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
     unpickling old checkpoints. The following patches apply.
 
     1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
-       version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898
+       version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
     2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
        but still needs to be available for import for legacy checkpoints.
 
@@ -38,20 +43,156 @@ class pl_legacy_patch:
             torch.load("path/to/legacy/checkpoint.ckpt")
     """
 
-    def __enter__(self) -> None:
-        _lock.acquire()
+    def __enter__(self) -> "pl_legacy_patch":
         # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
         legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils")
         sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module
 
         # `_gpus_arg_default` used to be imported from these locations
         legacy_argparse_module._gpus_arg_default = lambda x: x
         pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
+        return self
 
     def __exit__(
-        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        exc_traceback: Optional[TracebackType],
     ) -> None:
         if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
             delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
         del sys.modules["pytorch_lightning.utilities.argparse_utils"]
-        _lock.release()
+
+
+def get_version(checkpoint: _CHECKPOINT) -> str:
+    """Get the version of a Lightning checkpoint."""
+    return checkpoint["pytorch-lightning_version"]
+
+
+def set_version(checkpoint: _CHECKPOINT, version: str) -> None:
+    """Set the version of a Lightning checkpoint."""
+    checkpoint["pytorch-lightning_version"] = version
+
+
+def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
+    """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
+    return LooseVersion(get_version(checkpoint)) < LooseVersion(target)
+
+
+def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Applies all migrations below in order."""
+    if should_upgrade(checkpoint, "0.10.0"):
+        _migrate_model_checkpoint_early_stopping(checkpoint)
+    if should_upgrade(checkpoint, "1.6.0"):
+        _migrate_loop_global_step_to_progress_tracking(checkpoint)
+        _migrate_loop_current_epoch_to_progress_tracking(checkpoint)
+
+    set_version(checkpoint, pl.__version__)
+
+    # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert
+    # checkpoints permanently
+    return checkpoint
+
+
+def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """The checkpoint and early stopping keys were renamed.
+
+    Version: 0.10.0
+    Commit:
+    """
+    from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+    from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+
+    keys_mapping = {
+        "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
+        "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
+        "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
+        "early_stop_callback_wait": (EarlyStopping, "wait_count"),
+        "early_stop_callback_patience": (EarlyStopping, "patience"),
+    }
+    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
+
+    for key, new_path in keys_mapping.items():
+        if key in checkpoint:
+            value = checkpoint[key]
+            callback_type, callback_key = new_path
+            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
+            checkpoint["callbacks"][callback_type][callback_key] = value
+            del checkpoint[key]
+    return checkpoint
+
+
+def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    global_step = checkpoint["global_step"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    # for automatic optimization
+    optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
+    optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
+    # for manual optimization
+    optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
+    optim_step_progress["total"]["completed"] = global_step
+    return checkpoint
+
+
+def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
+    """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+    It will be overwritten by the loop's state if it was also saved.
+
+    Version: 1.6.0
+    Commit:
+    """
+    epoch = checkpoint["epoch"]
+    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
+    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
+    return checkpoint
+
+
+_FIT_LOOP_INITIAL_STATE_1_6_0 = {
+    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.batch_loop.manual_loop.state_dict": {},
+    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+        "optimizer": {
+            "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+            "zero_grad": {
+                "current": {"completed": 0, "ready": 0, "started": 0},
+                "total": {"completed": 0, "ready": 0, "started": 0},
+            },
+        },
+        "optimizer_position": 0,
+    },
+    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+    "epoch_loop.batch_loop.state_dict": {},
+    "epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+    "epoch_loop.state_dict": {"_batches_that_stepped": 0},
+    "epoch_loop.val_loop.dataloader_progress": {
+        "current": {"completed": 0, "ready": 0},
+        "total": {"completed": 0, "ready": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.batch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "is_last_batch": False,
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "epoch_loop.val_loop.epoch_loop.state_dict": {},
+    "epoch_loop.val_loop.state_dict": {},
+    "epoch_progress": {
+        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+    },
+    "state_dict": {},
+}
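
A minimal usage sketch of how these helpers might fit together when loading an old checkpoint outside the Trainer. It assumes the file above lives at ``pytorch_lightning/utilities/migration.py``; the checkpoint path is the same placeholder used in the ``pl_legacy_patch`` docstring, not a real file.

import torch

from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

# Keep unpickling of pre-1.4 checkpoints working while the file is read.
with pl_legacy_patch():
    checkpoint = torch.load("path/to/legacy/checkpoint.ckpt")

# Apply the migrations above in order and stamp the current Lightning version.
checkpoint = migrate_checkpoint(checkpoint)
print(checkpoint["pytorch-lightning_version"])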