Commit 16feb51

tarepan, tchaton, awaelchli, Borda, and s-rog authored
Refactor load in checkpoint connector (#4593)
* Refactor load step commentaries
* Refactor hpc ckpt suffix acquisition
* Refactor restore/hpc_load match
* Refactor hpc load trial
* Refactor checkpoint dir check
* Refactor unneeded function nest
* Refactor nested if
* Refactor duplicated cache clear
* Refactor attempt flow with if/elif
* Fix pep8
* Refactor hook commentary
* Fix pep8
* Refactor hpc load checkpoint path acquisition
* Fix pep8
* Fix doc
* Refactor None Union type with Optional

Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
1 parent 398f122 commit 16feb51
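
Among the steps listed above, "Refactor attempt flow with if/elif" is the behavioural core of the commit: the old `restore_hpc_weights_if_needed` round trip is replaced by a single if/elif inside `restore_weights`, so an existing HPC checkpoint takes priority over `resume_from_checkpoint` (see the first file's diff below). A minimal, self-contained sketch of that priority order, with stub return strings standing in for the real `hpc_load` / `restore` calls:

from typing import Optional


def restore_weights_sketch(max_hpc_suffix: Optional[int], resume_path: Optional[str]) -> str:
    # 1. an HPC checkpoint, if one exists, always wins
    if max_hpc_suffix is not None:
        return f"hpc_load('hpc_ckpt_{max_hpc_suffix}.ckpt')"
    # 2. otherwise fall back to `resume_from_checkpoint`
    elif resume_path is not None:
        return f"restore('{resume_path}')"
    # 3. neither -> start training from scratch
    return "fresh start"


assert restore_weights_sketch(7, "last.ckpt") == "hpc_load('hpc_ckpt_7.ckpt')"   # HPC takes priority
assert restore_weights_sketch(None, "last.ckpt") == "restore('last.ckpt')"
assert restore_weights_sketch(None, None) == "fresh start"

Only the control flow is mirrored here; the real method also clears the CUDA cache beforehand and synchronises workers with a barrier afterwards.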

File tree

3 files changed: +66, -62 lines


pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 58 additions & 58 deletions
@@ -13,14 +13,16 @@
 # limitations under the License.
 
 import os
+from pathlib import Path
 import re
+from typing import Union, Optional
 
 import torch
 
 import pytorch_lightning
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -52,16 +54,17 @@ def restore_weights(self, model: LightningModule):
         if self.trainer.on_gpu:
             torch.cuda.empty_cache()
 
-        # if script called from hpc resubmit, load weights
-        did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model)
+        # 1. Attempt to restore states from HPC checkpoint
+        dir_path_hpc = str(self.trainer.weights_save_path)
+        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
+        if max_suffix is not None:
+            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
+            self.hpc_load(checkpoint_path, self.trainer.on_gpu)
+            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
 
-        # clear cache after restore
-        if self.trainer.on_gpu:
-            torch.cuda.empty_cache()
-
-        if not did_restore_hpc_weights:
-            if self.trainer.resume_from_checkpoint is not None:
-                self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
+        # 2. Attempt to restore states from `resume_from_checkpoint` file
+        elif self.trainer.resume_from_checkpoint is not None:
+            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
 
         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -72,24 +75,14 @@ def restore_weights(self, model: LightningModule):
 
     def restore(self, checkpoint_path: str, on_gpu: bool):
         """
-        Load model/training states from the checkpoint file through file-read and state-restore.
-        Also restores all training state like:
-        - epoch
-        - callbacks
-        - schedulers
-        - optimizer
-        In detail, check return value description of `dump_checkpoint`
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
+        All restored states are listed in return value description of `dump_checkpoint`.
         """
 
-        # if on_gpu:
-        #     checkpoint = torch.load(checkpoint_path)
-        # else:
-        # load on CPU first
-        # read a checkpoint dictionary object from the checkpoint file at `checkpoint_path`
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # restore states from the checkpoint dictionary object
-        # load model state
+        # acquire the model
         model = self.trainer.get_model()
 
         # restore model and datamodule state
@@ -106,14 +99,14 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
         """
 
-        # give the datamodule a chance to load something
+        # restore datamodule states
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
 
-        # give model a chance to restore something
+        # hook: give user access to checkpoint if needed.
         model.on_load_checkpoint(checkpoint)
 
-        # restore the state_dict on the model
+        # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])
 
     def restore_training_state(self, checkpoint):
@@ -187,23 +180,6 @@ def restore_training_state(self, checkpoint):
         for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
             scheduler['scheduler'].load_state_dict(lrs_state)
 
-    def restore_hpc_weights_if_needed(self, model: LightningModule):
-        """If there is a set of hpc weights, use as signal to restore model."""
-        did_restore = False
-
-        # look for hpc weights
-        folderpath = str(self.trainer.weights_save_path)
-        fs = get_filesystem(folderpath)
-        if fs.exists(folderpath):
-            files = [os.path.basename(f['name']) for f in fs.listdir(folderpath)]
-            hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
-
-            # if hpc weights exist restore model
-            if len(hpc_weight_paths) > 0:
-                self.hpc_load(folderpath, self.trainer.on_gpu)
-                did_restore = True
-        return did_restore
-
     # ----------------------------------
     # PRIVATE OPS
     # ----------------------------------
@@ -216,7 +192,8 @@ def hpc_save(self, folderpath: str, logger):
         # save logger to make sure we get all the metrics
         logger.save()
 
-        ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
+        max_suffix = self.max_ckpt_in_folder(folderpath)
+        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
 
         fs.makedirs(folderpath, exist_ok=True)
         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
@@ -333,36 +310,52 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         return checkpoint
 
-    def hpc_load(self, folderpath, on_gpu):
-        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
+    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
+        """
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
+        All restored states are listed in return value description of `dump_checkpoint`.
+        """
 
-        # load on CPU first
-        checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
+        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
 
-        # load model state
+        # acquire the model
         model = self.trainer.get_model()
 
-        # restore states from 'PyTorch-Lightning checkpoint' dictionary object
+        # restore model and datamodule state
        self.restore_model_state(model, checkpoint)
 
         if self.trainer.root_gpu is not None:
             model.cuda(self.trainer.root_gpu)
 
-        # load training state (affects trainer only)
+        # restore training state
         self.restore_training_state(checkpoint)
 
-        # call model hook
+        # call hpc specific hook
         model.on_hpc_load(checkpoint)
 
-        log.info(f'restored hpc model from: {filepath}')
+    def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
+        """List up files in `dir_path` with name_key, then yield maximum suffix number.
+
+        Args:
+            dir_path: path of directory which may contain files whose name include `name_key`
+
+        Returns:
+            None if no-corresponding-file else maximum suffix number
+        """
+
+        # check directory existence
+        fs = get_filesystem(dir_path)
+        if not fs.exists(dir_path):
+            return None
 
-    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        fs = get_filesystem(path)
-        files = [os.path.basename(f["name"]) for f in fs.listdir(path)]
+        # check corresponding file existence
+        files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
         files = [x for x in files if name_key in x]
         if len(files) == 0:
-            return 0
+            return None
 
+        # extract suffix number
         ckpt_vs = []
         for name in files:
             name = name.split(name_key)[-1]
@@ -371,6 +364,13 @@ def max_ckpt_in_folder(self, path, name_key='ckpt_'):
 
         return max(ckpt_vs)
 
+    def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
+        """Get path of maximum-epoch checkpoint in the folder."""
+
+        max_suffix = self.max_ckpt_in_folder(folder_path)
+        ckpt_number = max_suffix if max_suffix is not None else 0
+        return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
+
     def save_checkpoint(self, filepath, weights_only: bool = False):
         """Save model/training states as a checkpoint file through state-dump and file-write.
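
For reference, a small standalone sketch of the contract the reworked `max_ckpt_in_folder` adopts above: it returns `None` when the directory is missing or contains no matching file, and the largest integer suffix otherwise (the old version returned 0 in the empty case). The helper below is illustrative only and uses the plain standard library instead of the connector's `get_filesystem`:

import re
import tempfile
from pathlib import Path
from typing import Optional, Union


def max_ckpt_suffix(dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]:
    dir_path = Path(dir_path)
    if not dir_path.exists():                   # no directory at all -> nothing to restore
        return None
    suffixes = []
    for f in dir_path.iterdir():
        if name_key in f.name:
            # keep the digits after the key, e.g. "hpc_ckpt_10.ckpt" -> 10
            digits = re.sub(r"[^0-9]", "", f.name.split(name_key)[-1])
            if digits:
                suffixes.append(int(digits))
    return max(suffixes) if suffixes else None  # None when no matching file exists


with tempfile.TemporaryDirectory() as tmp:
    assert max_ckpt_suffix(tmp, "hpc_ckpt_") is None        # empty folder -> None, not 0
    for n in (1, 2, 10):
        Path(tmp, f"hpc_ckpt_{n}.ckpt").touch()
    assert max_ckpt_suffix(tmp, "hpc_ckpt_") == 10          # highest suffix wins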

tests/base/develop_pipelines.py

Lines changed: 4 additions & 2 deletions
@@ -86,9 +86,11 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
         trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
             trainer.init_optimizers(pretrained_model)
 
-    # test HPC loading / saving
+    # test HPC saving
     trainer.checkpoint_connector.hpc_save(save_dir, logger)
-    trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)
+    # test HPC loading
+    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
+    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
 
 
 def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
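
The test above now resolves the save directory to a concrete checkpoint file via `get_max_ckpt_path_from_folder` before calling `hpc_load`, which takes a file path instead of a folder. A hedged, standalone sketch of that folder-to-path step, including the suffix-0 fallback; `max_hpc_ckpt_path` is a hypothetical stand-in, not the connector's implementation:

import re
import tempfile
from pathlib import Path
from typing import Union


def max_hpc_ckpt_path(folder: Union[str, Path]) -> str:
    suffixes = []
    for f in Path(folder).glob("hpc_ckpt_*.ckpt"):
        match = re.search(r"hpc_ckpt_(\d+)\.ckpt$", f.name)
        if match:
            suffixes.append(int(match.group(1)))
    # fall back to suffix 0 when the folder holds no HPC checkpoints yet
    ckpt_number = max(suffixes) if suffixes else 0
    return f"{folder}/hpc_ckpt_{ckpt_number}.ckpt"


with tempfile.TemporaryDirectory() as tmp:
    print(max_hpc_ckpt_path(tmp))           # .../hpc_ckpt_0.ckpt (fallback path)
    Path(tmp, "hpc_ckpt_2.ckpt").touch()
    Path(tmp, "hpc_ckpt_11.ckpt").touch()
    print(max_hpc_ckpt_path(tmp))           # .../hpc_ckpt_11.ckpt

The fallback means callers get a deterministic path even before any HPC checkpoint has been written.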

tests/models/data/horovod/train_default_model.py

Lines changed: 4 additions & 2 deletions
@@ -74,9 +74,11 @@ def run_test_from_config(trainer_options):
     for dataloader in test_loaders:
         run_prediction(dataloader, pretrained_model)
 
-    # test HPC loading / saving
+    # test HPC saving
     trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
-    trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)
+    # test HPC loading
+    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
+    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)
 
     if args.on_gpu:
         trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
