Skip to content

Commit 6b7b404

Browse files
awaelchlipre-commit-ci[bot]carmocca
authored
deprecate hpc_load() and integrate it with restore() (#7955)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 20a5e09 commit 6b7b404

File tree

7 files changed

+209
-26
lines changed

7 files changed

+209
-26
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
186186
- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
187187

188188

189+
- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))
190+
191+
189192
### Removed
190193

191194
- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@
2121

2222
import pytorch_lightning
2323
from pytorch_lightning.core.lightning import LightningModule
24-
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn
24+
from pytorch_lightning.utilities import (
25+
_OMEGACONF_AVAILABLE,
26+
DeviceType,
27+
rank_zero_deprecation,
28+
rank_zero_info,
29+
rank_zero_warn,
30+
)
2531
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
26-
from pytorch_lightning.utilities.cloud_io import load as pl_load
2732
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2833
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
2934

@@ -45,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]:
4550
dir_path_hpc = str(self.trainer.weights_save_path)
4651
max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
4752
if max_version is not None:
48-
return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
53+
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
4954

5055
def resume_start(self) -> None:
5156
"""
@@ -129,6 +134,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
129134
# hook: give user access to checkpoint if needed.
130135
model.on_load_checkpoint(checkpoint)
131136

137+
# call hpc specific hook
138+
if self.hpc_resume_path is not None:
139+
model.on_hpc_load(self._loaded_checkpoint)
140+
132141
# restore model state_dict
133142
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
134143

@@ -248,6 +257,7 @@ def restore_lr_schedulers(self) -> None:
248257
# ----------------------------------
249258
# PRIVATE OPS
250259
# ----------------------------------
260+
251261
def hpc_save(self, folderpath: str, logger):
252262
# make sure the checkpoint folder exists
253263
folderpath = str(folderpath) # because the tests pass a path object
@@ -365,29 +375,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
365375

366376
return checkpoint
367377

368-
def hpc_load(self, checkpoint_path: str, on_gpu: bool):
369-
"""
370-
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
371-
All restored states are listed in return value description of `dump_checkpoint`.
378+
def hpc_load(self, checkpoint_path: str) -> None:
372379
"""
380+
Attempts to restore the full training and model state from a HPC checkpoint file.
373381
374-
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
375-
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
376-
377-
# acquire the model
378-
model = self.trainer.lightning_module
379-
380-
# restore model and datamodule state
381-
self.restore_model_state(model, checkpoint)
382-
383-
if self.trainer.root_gpu is not None:
384-
model.cuda(self.trainer.root_gpu)
385-
386-
# restore training state
387-
self.restore_training_state(checkpoint)
388-
389-
# call hpc specific hook
390-
model.on_hpc_load(checkpoint)
382+
.. deprecated::v1.4
383+
Will be removed in v1.6. Use :meth:`restore` instead.
384+
"""
385+
rank_zero_deprecation(
386+
"`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6."
387+
" Use `CheckpointConnector.restore()` instead."
388+
)
389+
self.restore(checkpoint_path)
391390

392391
def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
393392
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.

tests/deprecated_api/test_remove_1-4.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx):
6666

6767
with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
6868
trainer.fit(TestModel())
69+
70+
71+
def test_v1_4_0_deprecated_hpc_load(tmpdir):
72+
model = BoringModel()
73+
trainer = Trainer(
74+
default_root_dir=tmpdir,
75+
max_steps=1,
76+
)
77+
trainer.fit(model)
78+
trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger)
79+
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir))
80+
with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"):
81+
trainer.checkpoint_connector.hpc_load(checkpoint_path)

tests/helpers/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def run_model_test(
9191
trainer.checkpoint_connector.hpc_save(save_dir, logger)
9292
# test HPC loading
9393
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
94-
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
94+
trainer.checkpoint_connector.restore(checkpoint_path)
9595

9696

9797
@torch.no_grad()

tests/models/data/horovod/train_default_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None:
8787
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
8888
# test HPC loading
8989
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
90-
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
90+
trainer.checkpoint_connector.restore(checkpoint_path)
9191

9292
if on_gpu:
9393
trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)

tests/trainer/connectors/test_callback_connector.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
import logging
215
from unittest.mock import Mock
316

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from unittest.mock import Mock
16+
17+
import torch
18+
19+
from pytorch_lightning import Trainer
20+
from tests.helpers import BoringModel
21+
22+
23+
class HPCHookdedModel(BoringModel):
24+
25+
def __init__(self):
26+
super().__init__()
27+
self.hpc_save_called = 0
28+
self.hpc_load_called = 0
29+
30+
def on_hpc_save(self, checkpoint):
31+
assert "state_dict" in checkpoint
32+
self.hpc_save_called += 1
33+
34+
def on_hpc_load(self, checkpoint):
35+
assert "state_dict" in checkpoint
36+
self.hpc_load_called += 1
37+
38+
39+
def test_hpc_hook_calls(tmpdir):
40+
model = HPCHookdedModel()
41+
trainer = Trainer(
42+
default_root_dir=tmpdir,
43+
max_steps=1,
44+
checkpoint_callback=False,
45+
logger=False,
46+
)
47+
trainer.fit(model)
48+
connector = trainer.checkpoint_connector
49+
connector.hpc_save(tmpdir, logger=Mock())
50+
assert model.hpc_save_called == 1
51+
assert model.hpc_load_called == 0
52+
53+
# new training run, restore from hpc checkpoint file automatically
54+
assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt"}
55+
trainer = Trainer(
56+
default_root_dir=tmpdir,
57+
max_steps=1,
58+
checkpoint_callback=False,
59+
logger=False,
60+
)
61+
trainer.fit(model)
62+
assert model.hpc_save_called == 1
63+
assert model.hpc_load_called == 1
64+
65+
66+
def test_preloaded_checkpoint_lifecycle(tmpdir):
67+
""" Tests that the preloaded checkpoint contents gets cleared from memory when it is not required anymore. """
68+
model = BoringModel()
69+
trainer = Trainer(
70+
default_root_dir=tmpdir,
71+
max_steps=1,
72+
)
73+
trainer.fit(model)
74+
75+
connector = trainer.checkpoint_connector
76+
77+
assert not trainer.resume_from_checkpoint
78+
assert not connector.resume_checkpoint_path
79+
assert not connector._loaded_checkpoint
80+
81+
connector.resume_start()
82+
assert not connector.resume_checkpoint_path
83+
assert not connector._loaded_checkpoint
84+
connector.resume_end()
85+
assert not connector.resume_checkpoint_path
86+
assert not connector._loaded_checkpoint
87+
88+
ckpt_path = trainer.checkpoint_callback.best_model_path
89+
trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
90+
connector = trainer.checkpoint_connector
91+
connector.resume_start()
92+
assert connector.resume_checkpoint_path == ckpt_path
93+
assert connector._loaded_checkpoint
94+
assert isinstance(connector._loaded_checkpoint, dict)
95+
connector.resume_end()
96+
assert not connector.resume_checkpoint_path
97+
assert not connector._loaded_checkpoint
98+
99+
100+
def test_hpc_restore_attempt(tmpdir):
101+
""" Test that restore() attempts to restore the hpc_ckpt with highest priority. """
102+
model = BoringModel()
103+
trainer = Trainer(
104+
default_root_dir=tmpdir,
105+
max_steps=1,
106+
checkpoint_callback=False,
107+
logger=False,
108+
)
109+
trainer.fit(model)
110+
111+
hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt"
112+
trainer.save_checkpoint(hpc_ckpt_path)
113+
assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"]
114+
115+
# set weights to zero
116+
for param in model.parameters():
117+
torch.nn.init.constant_(param, 0)
118+
119+
# case 1: restore hpc first, no explicit resume path provided
120+
trainer = Trainer(
121+
default_root_dir=tmpdir,
122+
max_steps=2,
123+
checkpoint_callback=False,
124+
logger=False,
125+
)
126+
trainer.fit(model)
127+
128+
for param in model.parameters():
129+
assert param.abs().sum() > 0
130+
torch.nn.init.constant_(param, 0)
131+
132+
# case 2: explicit resume path provided, restore hpc anyway
133+
trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing")
134+
trainer.fit(model)
135+
136+
for param in model.parameters():
137+
assert param.abs().sum() > 0
138+
139+
140+
def test_hpc_max_ckpt_version(tmpdir):
141+
""" Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """
142+
model = BoringModel()
143+
trainer = Trainer(
144+
default_root_dir=tmpdir,
145+
max_steps=1,
146+
)
147+
trainer.fit(model)
148+
trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
149+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
150+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
151+
trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")
152+
153+
assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
154+
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
155+
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None

0 commit comments

Comments
 (0)