 # limitations under the License.
 
 import os
+from pathlib import Path
 import re
+from typing import Union, Optional
 
 import torch
 
 import pytorch_lightning
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -52,16 +54,17 @@ def restore_weights(self, model: LightningModule):
         if self.trainer.on_gpu:
             torch.cuda.empty_cache()
 
-        # if script called from hpc resubmit, load weights
-        did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model)
+        # 1. Attempt to restore states from HPC checkpoint
+        dir_path_hpc = str(self.trainer.weights_save_path)
+        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
+        if max_suffix is not None:
+            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
+            self.hpc_load(checkpoint_path, self.trainer.on_gpu)
+            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
 
-        # clear cache after restore
-        if self.trainer.on_gpu:
-            torch.cuda.empty_cache()
-
-        if not did_restore_hpc_weights:
-            if self.trainer.resume_from_checkpoint is not None:
-                self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
+        # 2. Attempt to restore states from `resume_from_checkpoint` file
+        elif self.trainer.resume_from_checkpoint is not None:
+            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
 
         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
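For context, a rough sketch of how this precedence looks from the user side; the directory and checkpoint paths below are hypothetical, and the sketch assumes the usual `Trainer(weights_save_path=..., resume_from_checkpoint=...)` arguments that this connector reads:

# illustrative sketch only, not part of the patch; paths are hypothetical
from pytorch_lightning import Trainer

trainer = Trainer(
    weights_save_path='/shared/run',                 # scanned for hpc_ckpt_<n>.ckpt first
    resume_from_checkpoint='/shared/run/last.ckpt',  # used only if no hpc checkpoint is found
)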
@@ -72,24 +75,14 @@ def restore_weights(self, model: LightningModule):
 
     def restore(self, checkpoint_path: str, on_gpu: bool):
         """
-        Load model/training states from the checkpoint file through file-read and state-restore.
-        Also restores all training state like:
-        - epoch
-        - callbacks
-        - schedulers
-        - optimizer
-        In detail, check return value description of `dump_checkpoint`
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
+        All restored states are listed in the return value description of `dump_checkpoint`.
         """
 
-        # if on_gpu:
-        #     checkpoint = torch.load(checkpoint_path)
-        # else:
-        #     load on CPU first
-        # read a checkpoint dictionary object from the checkpoint file at `checkpoint_path`
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # restore states from the checkpoint dictionary object
-        # load model state
+        # acquire the model
         model = self.trainer.get_model()
 
         # restore model and datamodule state
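The `map_location=lambda storage, loc: storage` idiom above keeps the initial read on CPU regardless of which device the tensors were saved from; a minimal standalone sketch of the same idiom (the checkpoint path is hypothetical):

import torch

# map every storage back onto CPU, whatever device it was saved from
checkpoint = torch.load('example.ckpt', map_location=lambda storage, loc: storage)
# equivalent shorthand: torch.load('example.ckpt', map_location='cpu')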
@@ -106,14 +99,14 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
         """
 
-        # give the datamodule a chance to load something
+        # restore datamodule states
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
 
-        # give model a chance to restore something
+        # hook: give user access to checkpoint if needed.
         model.on_load_checkpoint(checkpoint)
 
-        # restore the state_dict on the model
+        # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])
 
     def restore_training_state(self, checkpoint):
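The `on_load_checkpoint` hook invoked above is the user-facing counterpart of `on_save_checkpoint`; a hedged sketch of a module using the pair, with a hypothetical class and checkpoint key:

from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.my_extra_state = 0  # hypothetical custom state

    def on_save_checkpoint(self, checkpoint):
        checkpoint['my_extra_state'] = self.my_extra_state  # stash custom state on save

    def on_load_checkpoint(self, checkpoint):
        self.my_extra_state = checkpoint['my_extra_state']  # read back before load_state_dict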
@@ -187,23 +180,6 @@ def restore_training_state(self, checkpoint):
         for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
             scheduler['scheduler'].load_state_dict(lrs_state)
 
-    def restore_hpc_weights_if_needed(self, model: LightningModule):
-        """If there is a set of hpc weights, use as signal to restore model."""
-        did_restore = False
-
-        # look for hpc weights
-        folderpath = str(self.trainer.weights_save_path)
-        fs = get_filesystem(folderpath)
-        if fs.exists(folderpath):
-            files = [os.path.basename(f['name']) for f in fs.listdir(folderpath)]
-            hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
-
-            # if hpc weights exist restore model
-            if len(hpc_weight_paths) > 0:
-                self.hpc_load(folderpath, self.trainer.on_gpu)
-                did_restore = True
-        return did_restore
-
     # ----------------------------------
     # PRIVATE OPS
     # ----------------------------------
@@ -216,7 +192,8 @@ def hpc_save(self, folderpath: str, logger):
         # save logger to make sure we get all the metrics
         logger.save()
 
-        ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
+        max_suffix = self.max_ckpt_in_folder(folderpath)
+        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
 
         fs.makedirs(folderpath, exist_ok=True)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
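A short worked example of the numbering above, assuming `max_ckpt_in_folder` behaves as defined later in this diff and the folder contents are hypothetical:

# folder holds hpc_ckpt_1.ckpt and hpc_ckpt_3.ckpt -> max_suffix == 3    -> writes hpc_ckpt_4.ckpt
# folder is empty or does not exist                -> max_suffix is None -> writes hpc_ckpt_1.ckpt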
@@ -333,36 +310,52 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         return checkpoint
 
-    def hpc_load(self, folderpath, on_gpu):
-        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
+    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
+        """
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
+        All restored states are listed in the return value description of `dump_checkpoint`.
+        """
 
-        # load on CPU first
-        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
+        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # load model state
+        # acquire the model
         model = self.trainer.get_model()
 
-        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        # restore model and datamodule state
         self.restore_model_state(model, checkpoint)
 
         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
 
-        # load training state (affects trainer only)
+        # restore training state
        self.restore_training_state(checkpoint)
 
-        # call model hook
+        # call hpc specific hook
         model.on_hpc_load(checkpoint)
 
-        log.info(f'restored hpc model from: {filepath}')
+    def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
+        """List files in `dir_path` whose names contain `name_key` and return the maximum suffix number.
+
+        Args:
+            dir_path: path of a directory which may contain files whose names include `name_key`
+
+        Returns:
+            The maximum suffix number, or None if no matching file is found.
+        """
+
+        # check directory existence
+        fs = get_filesystem(dir_path)
+        if not fs.exists(dir_path):
+            return None
 
-    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        fs = get_filesystem(path)
-        files = [os.path.basename(f["name"]) for f in fs.listdir(path)]
+        # check corresponding file existence
+        files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
         files = [x for x in files if name_key in x]
         if len(files) == 0:
-            return 0
+            return None
 
+        # extract suffix number
         ckpt_vs = []
         for name in files:
             name = name.split(name_key)[-1]
@@ -371,6 +364,13 @@ def max_ckpt_in_folder(self, path, name_key='ckpt_'):
 
         return max(ckpt_vs)
 
+    def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
+        """Get path of maximum-epoch checkpoint in the folder."""
+
+        max_suffix = self.max_ckpt_in_folder(folder_path)
+        ckpt_number = max_suffix if max_suffix is not None else 0
+        return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
+
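Expected behaviour of the two helpers above, shown on a hypothetical folder; note that when no hpc checkpoint exists, `get_max_ckpt_path_from_folder` still returns a `hpc_ckpt_0.ckpt` path, so callers should not assume the returned file is on disk:

# files in dir_path: ['hpc_ckpt_2.ckpt', 'hpc_ckpt_10.ckpt', 'events.out']
# max_ckpt_in_folder(dir_path, 'hpc_ckpt_')      -> 10
# max_ckpt_in_folder(missing_or_empty_dir)       -> None (callers must handle None)
# get_max_ckpt_path_from_folder('/shared/run')   -> '/shared/run/hpc_ckpt_10.ckpt'
# get_max_ckpt_path_from_folder(empty_dir)       -> f'{empty_dir}/hpc_ckpt_0.ckpt' (may not exist)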
     def save_checkpoint(self, filepath, weights_only: bool = False):
         """Save model/training states as a checkpoint file through state-dump and file-write.
 