Commit 73a1e19

Merge d10fba8 into d2c2e50
2 parents d2c2e50 + d10fba8 commit 73a1e19

8 files changed: +111 -92 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -110,6 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


+- Refactored `hpc_load` and entangled logics in `CheckpointConnector` ([#5371](https://github.com/PyTorchLightning/pytorch-lightning/pull/5371))
+
+
 - Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))

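For orientation, here is a minimal sketch of how call sites migrate off the removed `hpc_load` (it mirrors the updated test helpers further down; the short training run and the checkpoint filename are illustrative only, and `BoringModel` comes from the repository's test suite):

import os

from pytorch_lightning import Trainer
from tests.base import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model)

ckpt_path = os.path.join(trainer.default_root_dir, "example.ckpt")
trainer.save_checkpoint(ckpt_path)

# Before: trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=False)
# After: restore_states() returns the raw checkpoint dict and the caller invokes the HPC hook.
checkpoint = trainer.checkpoint_connector.restore_states(model, ckpt_path, on_gpu=False)
trainer.get_model().on_hpc_load(checkpoint)

The key difference is that `restore_states` returns the checkpoint dictionary, so HPC-specific hooks such as `on_hpc_load` are now invoked explicitly by the caller.
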
pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 76 additions & 70 deletions
@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union

 import torch

@@ -49,28 +49,16 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False

-    def restore_weights(self) -> None:
-        """
-        Attempt to restore a checkpoint (e.g. weights) in this priority:
-        1. from HPC weights
-        2. from `resume_from_checkpoint` file
-        3. don't restore
+    def attempt_to_restore(self) -> None:
+        """Attempt to restore model/training states.
         """
         # clear cache before restore
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

-        # 1. Attempt to restore states from HPC checkpoint
-        dir_path_hpc = str(self.trainer.weights_save_path)
-        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
-        if max_suffix is not None:
-            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
-            self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
-            rank_zero_info(f'restored hpc model from: {checkpoint_path}')
-
-        # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)
+        # attempt to restore states
+        model: LightningModule = self.trainer.get_model()
+        self.attempt_to_apply_checkpoint(model)

         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
@@ -79,53 +67,95 @@ def restore_weights(self) -> None:
         if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

-    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
-        All restored states are listed in return value description of `dump_checkpoint`.
+    def attempt_to_apply_checkpoint(self, model: LightningModule) -> bool:
+        """Attempt to apply checkpoint states to model/training with priority.
+
+        Priority:
+            1. from HPC weights
+            2. from `resume_from_checkpoint` file
+            3. don't apply
+
+        Returns:
+            True if applied else False
         """
-        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
-        fs = get_filesystem(checkpoint_path)
-        if not fs.exists(checkpoint_path):
-            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
-            return False
+        # Design Note:
+        # `attempt_to_restore` has responsibility to whole state restoration flow (e.g. OOM, parallel processing).
+        # This method has responsibility to applying/assigning state value from nullable checkpoint.

-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        restored: bool = False

-        # acquire the model
-        model = self.trainer.get_model()
+        # 1. Attempt to apply HPC checkpoint.
+        dir_path_hpc = str(self.trainer.weights_save_path)
+        max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
+        if max_suffix is not None:
+            checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
+            checkpoint = self.restore_states(model, checkpoint_path, self.trainer._device_type == DeviceType.GPU)
+            model.on_hpc_load(checkpoint)
+            restored = True
+            rank_zero_info(f'restored hpc model from: {checkpoint_path}')

-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
+        # 2. Attempt to apply `resume_from_checkpoint` file.
+        elif self.trainer.resume_from_checkpoint is not None:
+            adress_checkpoint: str = self.trainer.resume_from_checkpoint
+            if get_filesystem(adress_checkpoint).exists(adress_checkpoint):
+                self.restore_states(model, adress_checkpoint, self.trainer._device_type == DeviceType.GPU)
+                restored = True
+                rank_zero_info(f"States restored from the checkpoint file at {adress_checkpoint}")
+            else:
+                rank_zero_warn(f"checkpoint file at {adress_checkpoint} does not exist.")

-        if on_gpu:
-            model.cuda(self.trainer.root_gpu)
+        # 3. Do not apply, start from scratch.
+        else:
+            rank_zero_info("Start from scratch.")

-        # restore training state
-        self.restore_training_state(checkpoint)
+        return restored

-        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
-        return True
+    def restore_states(
+        self,
+        model: LightningModule,
+        checkpoint_path: str,
+        on_gpu: bool,
+    ) -> Dict[str, Any]:
+        """Restore all states from checkpoint in the specified path.

-    def restore_model_state(self, model: LightningModule, checkpoint) -> None:
-        """
-        Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
+        Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
+        All restored states are listed in return value description of `dump_checkpoint`.
+
+        Args:
+            on_gpu: Whether trainer is on GPU or not.
         """
+        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
+        checkpoint: Dict[str, Any] = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

-        # restore datamodule states
+        # restore states
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_load_checkpoint(checkpoint)
+        self.restore_model_state(checkpoint, model, on_gpu)
+        self.restore_training_state(checkpoint)
+
+        return checkpoint

+    def restore_model_state(
+        self,
+        checkpoint: Dict[str, Any],
+        model: LightningModule,
+        on_gpu: bool,
+    ) -> None:
+        """Restore model state.
+        """
         # hook: give user access to checkpoint if needed.
         model.on_load_checkpoint(checkpoint)

         # restore model state_dict
         model.load_state_dict(checkpoint['state_dict'])

-    def restore_training_state(self, checkpoint):
-        """
-        Restore trainer state.
+        # moves the model to the GPU
+        if on_gpu:
+            model.cuda(self.trainer.root_gpu)
+
+    def restore_training_state(self, checkpoint: Dict[str, Any]) -> None:
+        """Restore trainer state.
+
         Model will get its change to update
         :param checkpoint:
         :return:
@@ -329,30 +359,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

         return checkpoint

-    def hpc_load(self, checkpoint_path: str, on_gpu: bool):
-        """
-        Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
-        All restored states are listed in return value description of `dump_checkpoint`.
-        """
-
-        # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
-        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # acquire the model
-        model = self.trainer.get_model()
-
-        # restore model and datamodule state
-        self.restore_model_state(model, checkpoint)
-
-        if self.trainer.root_gpu is not None:
-            model.cuda(self.trainer.root_gpu)
-
-        # restore training state
-        self.restore_training_state(checkpoint)
-
-        # call hpc specific hook
-        model.on_hpc_load(checkpoint)
-
     def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
         """List up files in `dir_path` with `name_key`, then yield maximum suffix number.

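To make the restore priority easier to follow, below is an illustrative, dependency-free sketch of the decision implemented by `attempt_to_apply_checkpoint`; the helper name `pick_checkpoint` and its arguments are hypothetical and not part of the commit:

import os
from typing import Optional


def pick_checkpoint(hpc_dir: str, resume_path: Optional[str]) -> Optional[str]:
    """Return the checkpoint that would be applied, or None to start from scratch."""
    # 1. Prefer the newest HPC checkpoint, i.e. hpc_ckpt_<n>.ckpt with the largest suffix.
    names = os.listdir(hpc_dir) if os.path.isdir(hpc_dir) else []
    suffixes = [
        int(name[len("hpc_ckpt_"):-len(".ckpt")])
        for name in names
        if name.startswith("hpc_ckpt_") and name.endswith(".ckpt")
        and name[len("hpc_ckpt_"):-len(".ckpt")].isdigit()
    ]
    if suffixes:
        return os.path.join(hpc_dir, f"hpc_ckpt_{max(suffixes)}.ckpt")
    # 2. Fall back to `resume_from_checkpoint` when the file actually exists.
    if resume_path is not None and os.path.exists(resume_path):
        return resume_path
    # 3. Otherwise apply nothing and start from scratch.
    return None


print(pick_checkpoint("checkpoints", None))  # None -> training would start from scratch
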
pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 2 deletions
@@ -158,8 +158,8 @@ def setup_training(self):
         if self.trainer.is_global_zero:
             ref_model.summarize(mode=self.trainer.weights_summary)

-        # restore training state and model weights before hpc is called
-        self.trainer.checkpoint_connector.restore_weights()
+        # restore model/training states before hpc is called
+        self.trainer.checkpoint_connector.attempt_to_restore()

         # on pretrain routine end
         self.trainer.on_pretrain_routine_end(ref_model)
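
As a rough usage sketch, the user-facing entry point is unchanged: passing `resume_from_checkpoint` still restores states during training setup, only routed through `attempt_to_restore` now. The checkpoint path below is a placeholder and `BoringModel` is the repository's test model:

from pytorch_lightning import Trainer
from tests.base import BoringModel

# setup_training() calls checkpoint_connector.attempt_to_restore(); a missing file
# produces a rank-zero warning and training starts from scratch.
trainer = Trainer(max_epochs=1, resume_from_checkpoint="checkpoints/last.ckpt")
trainer.fit(BoringModel())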

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 2 additions & 2 deletions
@@ -113,9 +113,9 @@ def scale_batch_size(trainer,
     garbage_collection_cuda()
     log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

-    # Restore initial state of model
+    # Restore initial state of model from temporary checkpoint, which is deleted after restore.
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
     fs = get_filesystem(str(save_path))
     if fs.exists(save_path):
         fs.rm(save_path)

pytorch_lightning/tuner/lr_finder.py

Lines changed: 2 additions & 2 deletions
@@ -190,9 +190,9 @@ def lr_find(
                                  'loss': trainer.callbacks[0].losses})
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

-    # Reset model state
+    # Restore initial state of model from temporary checkpoint, which is deleted after restore.
     if trainer.is_global_zero:
-        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
+        trainer.checkpoint_connector.restore_states(model, str(save_path), trainer._device_type == DeviceType.GPU)
     fs = get_filesystem(str(save_path))
     if fs.exists(save_path):
         fs.rm(save_path)
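
Both tuner call sites follow the same round trip: snapshot the model to a temporary checkpoint, run the search, then restore and delete the snapshot. A minimal sketch of that pattern, assuming the `restore_states()` signature introduced here (the snapshot filename is a placeholder):

import os

from pytorch_lightning import Trainer
from tests.base import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model)

save_path = os.path.join(trainer.default_root_dir, "temp_tuner_snapshot.ckpt")
trainer.save_checkpoint(save_path)  # snapshot before the search mutates the model

# ... a tuner would now try candidate batch sizes / learning rates ...

if trainer.is_global_zero:
    trainer.checkpoint_connector.restore_states(model, save_path, on_gpu=False)
    os.remove(save_path)  # the temporary checkpoint is deleted once states are restored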

tests/base/develop_pipelines.py

Lines changed: 4 additions & 2 deletions
@@ -14,6 +14,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import DistributedType
 from tests.base import BoringModel
@@ -50,7 +51,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50
     trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()


-def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
+def run_model_test(trainer_options, model: LightningModule, on_gpu: bool = True, version=None,
                    with_hpc: bool = True, min_acc: float = 0.25):

     reset_seed()
@@ -93,7 +94,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
         trainer.checkpoint_connector.hpc_save(save_dir, logger)
         # test HPC loading
         checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
-        trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
+        checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
+        trainer.get_model().on_hpc_load(checkpoint)


 def run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):

tests/models/data/horovod/train_default_model.py

Lines changed: 2 additions & 1 deletion
@@ -79,7 +79,8 @@ def run_test_from_config(trainer_options):
     trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
     # test HPC loading
     checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
-    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)
+    checkpoint = trainer.checkpoint_connector.restore_states(model, checkpoint_path, trainer.root_gpu)
+    trainer.get_model().on_hpc_load(checkpoint)

     if args.on_gpu:
         trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)

tests/models/test_restore.py

Lines changed: 20 additions & 13 deletions
@@ -16,6 +16,8 @@
 import os
 import pickle
 from copy import deepcopy
+from pathlib import Path
+from typing import Optional

 import cloudpickle
 import pytest
@@ -70,23 +72,28 @@ def test_model_properties_resume_from_checkpoint(tmpdir):
     trainer.fit(model)


-def test_try_resume_from_non_existing_checkpoint(tmpdir):
+def test_try_resume_from_non_existing_checkpoint(tmpdir: Path):
     """ Test that trying to resume from non-existing `resume_from_checkpoint` fail without error."""
     model = BoringModel()
-    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        logger=False,
-        callbacks=[checkpoint_cb],
-        limit_train_batches=0.1,
-        limit_val_batches=0.1,
-    )
+
+    def gen_trainer(name_ckpt: Optional[str]) -> Trainer:
+        path_dir_saved = tmpdir
+        path_file_loaded = None if name_ckpt is None else str(tmpdir / name_ckpt)
+        checkpoint_cb = ModelCheckpoint(dirpath=path_dir_saved, monitor="early_stop_on", save_last=True)
+        return Trainer(
+            resume_from_checkpoint=path_file_loaded,
+            max_epochs=1,
+            logger=False,
+            callbacks=[checkpoint_cb],
+            limit_train_batches=0.1,
+            limit_val_batches=0.1,
+        )
+
     # Generate checkpoint `last.ckpt` with BoringModel
-    trainer.fit(model)
+    gen_trainer(None).fit(model)
     # `True` if resume/restore successfully else `False`
-    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
-    assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
+    assert gen_trainer("last.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)
+    assert not gen_trainer("last_non_existing.ckpt").checkpoint_connector.attempt_to_apply_checkpoint(model)


 class CaptureCallbacksBeforeTraining(Callback):
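
A hedged, stand-alone illustration of the boolean contract this test exercises, assuming the refactored connector from this commit; the missing checkpoint filename is a placeholder:

from pytorch_lightning import Trainer
from tests.base import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, logger=False, resume_from_checkpoint="does_not_exist.ckpt")

# With no HPC checkpoint and a missing `resume_from_checkpoint` file, the connector
# warns and returns False rather than raising, so training simply starts from scratch.
assert trainer.checkpoint_connector.attempt_to_apply_checkpoint(model) is False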
