 import os
 import re
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union

 import torch

@@ -49,28 +49,16 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False

-    def restore_weights(self) -> None:
-        """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
-        1. from HPC weights
-        2. from `resume_from_checkpoint` file
-        3. don't restore
+    def attempt_to_restore(self) -> None:
+        """Attempt to restore model/training states.
         """
         # clear cache before restore
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

-        # 1. Attempt to restore states from HPC checkpoint
-        dir_path_hpc = str(self.trainer.weights_save_path)
-        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
-        if max_suffix is not None:
-            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
-            self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
-            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
-
-        # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
+        # attempt to restore states
+        model: LightningModule = self.trainer.get_model()
+        self.attempt_to_apply_checkpoint(model)

         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -79,53 +67,95 @@ def restore_weights(self) -> None:
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

-    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
-        All restored states are listed in return value description of `dump_checkpoint`.
+    def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool:
+        """Attempt to apply checkpoint states to model/training with priority.
+
+        Priority:
+            1. from HPC weights
+            2. from `resume_from_checkpoint` file
+            3. don't apply
+
+        Returns:
+            True if applied else False
         """
-        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
-        fs = get_filesystem(checkpoint_path)
-        if not fs.exists(checkpoint_path):
-            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
-            return False
+        # Design Note:
+        # `attempt_to_restore` is responsible for the whole state restoration flow (e.g. OOM, parallel processing).
+        # This method is responsible for applying/assigning state values from a nullable checkpoint.

-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        restored: bool = False

-        # acquire the model
-        model = self.trainer.get_model()
+        # 1. Attempt to apply HPC checkpoint.
+        dir_path_hpc = str(self.trainer.weights_save_path)
+        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
+        if max_suffix is not None:
+            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
+            checkpoint = self.restore_states(model, checkpoint_path, self.trainer._device_type == DeviceType.GPU)
+            model.on_hpc_load(checkpoint)
+            restored = True
+            rank_zero_info(f'restored hpc model from: {checkpoint_path}')

-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
+        # 2. Attempt to apply `resume_from_checkpoint` file.
+        elif self.trainer.resume_from_checkpoint is not None:
+            address_checkpoint: str = self.trainer.resume_from_checkpoint
+            if get_filesystem(address_checkpoint).exists(address_checkpoint):
+                self.restore_states(model, address_checkpoint, self.trainer._device_type == DeviceType.GPU)
+                restored = True
+                rank_zero_info(f"States restored from the checkpoint file at {address_checkpoint}")
+            else:
+                rank_zero_warn(f"checkpoint file at {address_checkpoint} does not exist.")

-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
+        # 3. Do not apply, start from scratch.
+        else:
+            rank_zero_info("Start from scratch.")

-        # restore training state
-        self.restore_training_state(checkpoint)
+        return restored

-        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
-        return True
+    def restore_states(
+        self,
+        model: LightningModule,
+        checkpoint_path: str,
+        on_gpu: bool,
+    ) -> Dict[str, Any]:
+        """Restore all states from checkpoint in the specified path.

-    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
-        """
-        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
+        All restored states are listed in return value description of `dump_checkpoint`.
+
+        Args:
+            on_gpu: Whether trainer is on GPU or not.
         """
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
+        checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

-        # restore datamodule states
+        # restore states
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
+        self.restore_model_state(checkpoint, model, on_gpu)
+        self.restore_training_state(checkpoint)
+
+        return checkpoint

+    def restore_model_state(
+        self,
+        checkpoint: Dict[str, Any],
+        model: LightningModule,
+        on_gpu: bool,
+    ) -> None:
+        """Restore model state.
+        """
         # hook: give user access to checkpoint if needed.
         model.on_load_checkpoint(checkpoint)

         # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])

-    def restore_training_state(self, checkpoint):
-        """
-        Restore trainer state.
+        # moves the model to the GPU
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+        """Restore trainer state.
+
         Model will get its chance to update
         :param checkpoint:
         :return:
@@ -329,30 +359,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

         return checkpoint

-    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
-        All restored states are listed in return value description of `dump_checkpoint`.
-        """
-
-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # acquire the model
-        model = self.trainer.get_model()
-
-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
-
-        if self.trainer.root_gpu is not None:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore training state
-        self.restore_training_state(checkpoint)
-
-        # call hpc specific hook
-        model.on_hpc_load(checkpoint)
-
     def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.

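For orientation, here is a minimal, self-contained sketch of the restore flow this refactor introduces: one function decides which checkpoint to apply by priority (HPC first, then `resume_from_checkpoint`, else scratch), and a separate function loads the file and applies its states. The helper names, stub objects, and simplified logic below are assumptions for illustration only; the authoritative behavior is the diff above.

```python
# Illustrative sketch only (not PyTorch Lightning code): mirrors the split between
# `attempt_to_apply_checkpoint` (choose a checkpoint by priority) and
# `restore_states` (read the file once and apply its states).
import os
from typing import Any, Dict, Optional


def find_hpc_checkpoint(save_dir: str) -> Optional[str]:
    """Stand-in for max_ckpt_in_folder: pick the highest-numbered `hpc_ckpt_<n>.ckpt`."""
    suffixes = [
        int(name[len("hpc_ckpt_"):-len(".ckpt")])
        for name in os.listdir(save_dir)
        if name.startswith("hpc_ckpt_") and name.endswith(".ckpt")
    ] if os.path.isdir(save_dir) else []
    return f"{save_dir}/hpc_ckpt_{max(suffixes)}.ckpt" if suffixes else None


def restore_states(checkpoint_path: str) -> Dict[str, Any]:
    """Stand-in for CheckpointConnector.restore_states: load the checkpoint dict,
    then apply datamodule, model, and training state from it."""
    checkpoint: Dict[str, Any] = {"path": checkpoint_path}  # pl_load(...) in the real code
    # ... apply datamodule / model / training state here ...
    return checkpoint


def attempt_to_apply_checkpoint(save_dir: str, resume_from_checkpoint: Optional[str]) -> bool:
    """Priority: 1. HPC checkpoint, 2. `resume_from_checkpoint`, 3. start from scratch."""
    hpc_path = find_hpc_checkpoint(save_dir)
    if hpc_path is not None:
        checkpoint = restore_states(hpc_path)
        # the real connector additionally calls model.on_hpc_load(checkpoint) here
        return True
    if resume_from_checkpoint is not None and os.path.exists(resume_from_checkpoint):
        restore_states(resume_from_checkpoint)
        return True
    return False  # start from scratch
```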