diff --git a/CHANGELOG.md b/CHANGELOG.md index 784a1581ee97a..db0353845e0f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -186,6 +186,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) +- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6711ef3cb748e..8035f0c532764 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,9 +21,14 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import ( + _OMEGACONF_AVAILABLE, + DeviceType, + rank_zero_deprecation, + rank_zero_info, + rank_zero_warn, +) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem -from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS @@ -45,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]: dir_path_hpc = str(self.trainer.weights_save_path) max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_") if max_version is not None: - return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt" + return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") def resume_start(self) -> None: """ @@ -129,6 +134,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None: # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) + # call hpc specific hook + if self.hpc_resume_path is not None: + model.on_hpc_load(self._loaded_checkpoint) + # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(checkpoint) @@ -248,6 +257,7 @@ def restore_lr_schedulers(self) -> None: # ---------------------------------- # PRIVATE OPS # ---------------------------------- + def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists folderpath = str(folderpath) # because the tests pass a path object @@ -365,29 +375,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def hpc_load(self, checkpoint_path: str, on_gpu: bool): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. + def hpc_load(self, checkpoint_path: str) -> None: """ + Attempts to restore the full training and model state from a HPC checkpoint file. - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - # acquire the model - model = self.trainer.lightning_module - - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - if self.trainer.root_gpu is not None: - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - # call hpc specific hook - model.on_hpc_load(checkpoint) + .. deprecated::v1.4 + Will be removed in v1.6. Use :meth:`restore` instead. + """ + rank_zero_deprecation( + "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." + " Use `CheckpointConnector.restore()` instead." + ) + self.restore(checkpoint_path) def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 37d8abfdf905d..23df12586d328 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx): with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"): trainer.fit(TestModel()) + + +def test_v1_4_0_deprecated_hpc_load(tmpdir): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger) + checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir)) + with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"): + trainer.checkpoint_connector.hpc_load(checkpoint_path) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index f7a6484f6b27e..02a9e2dd0cfb2 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -91,7 +91,7 @@ def run_model_test( trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) @torch.no_grad() diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index ed0d33f5e8c82..c4cbaeb1363c9 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None: trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) if on_gpu: trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..501482d77a240 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from unittest.mock import Mock diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py new file mode 100644 index 0000000000000..6e152f5944b59 --- /dev/null +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -0,0 +1,155 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import Mock + +import torch + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +class HPCHookdedModel(BoringModel): + + def __init__(self): + super().__init__() + self.hpc_save_called = 0 + self.hpc_load_called = 0 + + def on_hpc_save(self, checkpoint): + assert "state_dict" in checkpoint + self.hpc_save_called += 1 + + def on_hpc_load(self, checkpoint): + assert "state_dict" in checkpoint + self.hpc_load_called += 1 + + +def test_hpc_hook_calls(tmpdir): + model = HPCHookdedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + checkpoint_callback=False, + logger=False, + ) + trainer.fit(model) + connector = trainer.checkpoint_connector + connector.hpc_save(tmpdir, logger=Mock()) + assert model.hpc_save_called == 1 + assert model.hpc_load_called == 0 + + # new training run, restore from hpc checkpoint file automatically + assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt"} + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + checkpoint_callback=False, + logger=False, + ) + trainer.fit(model) + assert model.hpc_save_called == 1 + assert model.hpc_load_called == 1 + + +def test_preloaded_checkpoint_lifecycle(tmpdir): + """ Tests that the preloaded checkpoint contents gets cleared from memory when it is not required anymore. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + + connector = trainer.checkpoint_connector + + assert not trainer.resume_from_checkpoint + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + connector.resume_start() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + connector.resume_end() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path) + connector = trainer.checkpoint_connector + connector.resume_start() + assert connector.resume_checkpoint_path == ckpt_path + assert connector._loaded_checkpoint + assert isinstance(connector._loaded_checkpoint, dict) + connector.resume_end() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + +def test_hpc_restore_attempt(tmpdir): + """ Test that restore() attempts to restore the hpc_ckpt with highest priority. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + checkpoint_callback=False, + logger=False, + ) + trainer.fit(model) + + hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt" + trainer.save_checkpoint(hpc_ckpt_path) + assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"] + + # set weights to zero + for param in model.parameters(): + torch.nn.init.constant_(param, 0) + + # case 1: restore hpc first, no explicit resume path provided + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + checkpoint_callback=False, + logger=False, + ) + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + torch.nn.init.constant_(param, 0) + + # case 2: explicit resume path provided, restore hpc anyway + trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing") + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + + +def test_hpc_max_ckpt_version(tmpdir): + """ Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt") + + assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt") + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None