From a2560f166c3c89ddeaf702302dadcc010c24c7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:32:13 +0200 Subject: [PATCH 01/14] deprecate --- CHANGELOG.md | 3 ++ .../connectors/checkpoint_connector.py | 40 +++++++------------ tests/deprecated_api/test_remove_1-4.py | 13 ++++++ tests/helpers/pipelines.py | 2 +- .../data/horovod/train_default_model.py | 2 +- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9bb747b70542..c075a28c8803e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -190,6 +190,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) +- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652)) + + ### Removed - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6711ef3cb748e..dca163be14e51 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,8 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn, \ + rank_zero_deprecation from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -248,6 +249,19 @@ def restore_lr_schedulers(self) -> None: # ---------------------------------- # PRIVATE OPS # ---------------------------------- + def hpc_load(self, checkpoint_path: str): + """ + Attempts to restore the full training and model state from a HPC checkpoint file. + .. deprecated:: + `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. + Use `CheckpointConnector.restore` instead. + """ + rank_zero_deprecation( + "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." + " Use `CheckpointConnector.restore()` instead." + ) + self.restore(checkpoint_path) + def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists folderpath = str(folderpath) # because the tests pass a path object @@ -365,30 +379,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint - def hpc_load(self, checkpoint_path: str, on_gpu: bool): - """ - Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc. - All restored states are listed in return value description of `dump_checkpoint`. - """ - - # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path` - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - - # acquire the model - model = self.trainer.lightning_module - - # restore model and datamodule state - self.restore_model_state(model, checkpoint) - - if self.trainer.root_gpu is not None: - model.cuda(self.trainer.root_gpu) - - # restore training state - self.restore_training_state(checkpoint) - - # call hpc specific hook - model.on_hpc_load(checkpoint) - def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. Args: diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 37d8abfdf905d..23df12586d328 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx): with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"): trainer.fit(TestModel()) + + +def test_v1_4_0_deprecated_hpc_load(tmpdir): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger) + checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir)) + with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"): + trainer.checkpoint_connector.hpc_load(checkpoint_path) diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index f7a6484f6b27e..02a9e2dd0cfb2 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -91,7 +91,7 @@ def run_model_test( trainer.checkpoint_connector.hpc_save(save_dir, logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) @torch.no_grad() diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index ed0d33f5e8c82..c4cbaeb1363c9 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None: trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) # test HPC loading checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) - trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu) + trainer.checkpoint_connector.restore(checkpoint_path) if on_gpu: trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1) From 88f201593c02a93cb0170327da5acbc96564534a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:33:40 +0200 Subject: [PATCH 02/14] test --- .../connectors/test_callback_connector.py | 13 +++ .../connectors/test_checkpoint_connector.py | 107 ++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 tests/trainer/connectors/test_checkpoint_connector.py diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 34149e2231bf5..501482d77a240 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from unittest.mock import Mock diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py new file mode 100644 index 0000000000000..03982a474d02f --- /dev/null +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -0,0 +1,107 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path + +import torch + +from pytorch_lightning import Trainer +from tests.helpers import BoringModel + + +def test_preloaded_checkpoint_lifecycle(tmpdir): + """ Tests that the preloaded checkpoint contents gets cleared from memory when it is not required anymore. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + + connector = trainer.checkpoint_connector + + assert not trainer.resume_from_checkpoint + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + connector.resume_start() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + connector.resume_end() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path) + connector = trainer.checkpoint_connector + connector.resume_start() + assert connector.resume_checkpoint_path == ckpt_path + assert connector._loaded_checkpoint + assert isinstance(connector._loaded_checkpoint, dict) + connector.resume_end() + assert not connector.resume_checkpoint_path + assert not connector._loaded_checkpoint + + +def test_hpc_restore_attempt(tmpdir): + """ Test that restore() attempts to restore the hpc_ckpt with highest priority. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + + hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt" + trainer.save_checkpoint(hpc_ckpt_path) + assert Path(hpc_ckpt_path).exists() + + # set weights to zero + for param in model.parameters(): + torch.nn.init.constant_(param, 0) + + # case 1: restore hpc first, no explicit resume path provided + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=2, + ) + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + torch.nn.init.constant_(param, 0) + + # case 2: explicit resume path provided, restore hpc anyway + trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint="not existing") + trainer.fit(model) + + for param in model.parameters(): + assert param.abs().sum() > 0 + + +def test_hpc_max_ckpt_version(tmpdir): + """ Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """ + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + ) + trainer.fit(model) + trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt") + trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt") + + assert trainer.checkpoint_connector.hpc_resume_path == tmpdir / "hpc_ckpt_33.ckpt" + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None \ No newline at end of file From b0c0b0709beb916f26b947900ac3d91466ea54d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:44:50 +0200 Subject: [PATCH 03/14] tests --- .../connectors/test_checkpoint_connector.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 03982a474d02f..fc21e132aeb31 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from pathlib import Path import torch @@ -59,12 +60,14 @@ def test_hpc_restore_attempt(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_steps=1, + checkpoint_callback=False, + logger=False, ) trainer.fit(model) hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt" trainer.save_checkpoint(hpc_ckpt_path) - assert Path(hpc_ckpt_path).exists() + assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"] # set weights to zero for param in model.parameters(): @@ -74,6 +77,8 @@ def test_hpc_restore_attempt(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_steps=2, + checkpoint_callback=False, + logger=False, ) trainer.fit(model) @@ -82,7 +87,11 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint="not existing") + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=3, + resume_from_checkpoint="not existing" + ) trainer.fit(model) for param in model.parameters(): @@ -104,4 +113,4 @@ def test_hpc_max_ckpt_version(tmpdir): assert trainer.checkpoint_connector.hpc_resume_path == tmpdir / "hpc_ckpt_33.ckpt" assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 - assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None \ No newline at end of file + assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None From 0f17119af300a49faeb59bf5a13f1da04e8c67e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 14:46:21 +0200 Subject: [PATCH 04/14] ypf --- .../trainer/connectors/checkpoint_connector.py | 9 +++++++-- tests/trainer/connectors/test_checkpoint_connector.py | 7 +------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index dca163be14e51..f203b28c09048 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,8 +21,13 @@ import pytorch_lightning from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn, \ - rank_zero_deprecation +from pytorch_lightning.utilities import ( + _OMEGACONF_AVAILABLE, + DeviceType, + rank_zero_deprecation, + rank_zero_info, + rank_zero_warn, +) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index fc21e132aeb31..a8c18e126733e 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from pathlib import Path import torch @@ -87,11 +86,7 @@ def test_hpc_restore_attempt(tmpdir): torch.nn.init.constant_(param, 0) # case 2: explicit resume path provided, restore hpc anyway - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=3, - resume_from_checkpoint="not existing" - ) + trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing") trainer.fit(model) for param in model.parameters(): From f62cd51bbadf44ff286f553b6655eca4c2af9b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:35:13 +0200 Subject: [PATCH 05/14] test hook calls --- .../connectors/checkpoint_connector.py | 5 +++ .../connectors/test_checkpoint_connector.py | 41 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index f203b28c09048..9581f9e76be20 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -135,6 +135,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None: # hook: give user access to checkpoint if needed. model.on_load_checkpoint(checkpoint) + # call hpc specific hook + if self.hpc_resume_path is not None: + model.on_hpc_load(self._loaded_checkpoint) + # restore model state_dict self.trainer.training_type_plugin.load_model_state_dict(checkpoint) @@ -366,6 +370,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: self.trainer.precision_plugin.on_save_checkpoint(checkpoint) + # dump hyper-parameters # dump hyper-parameters if model.hparams: if hasattr(model, '_hparams_name'): diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index a8c18e126733e..c2e3d56791cf2 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -19,6 +19,47 @@ from tests.helpers import BoringModel +class HPCHookdedModel(BoringModel): + + def __init__(self): + super().__init__() + self.hpc_save_called = 0 + self.hpc_load_called = 0 + + def on_hpc_save(self, checkpoint): + assert "state_dict" in checkpoint + self.hpc_save_called += 1 + + def on_hpc_load(self, checkpoint): + assert "state_dict" in checkpoint + self.hpc_load_called += 1 + + +def test_hpc_hook_calls(tmpdir): + model = HPCHookdedModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + checkpoint_callback=False, + ) + trainer.fit(model) + connector = trainer.checkpoint_connector + connector.hpc_save(tmpdir, logger=trainer.logger) + assert model.hpc_save_called == 1 + assert model.hpc_load_called == 0 + + # new training run, restore from hpc checkpoint file automatically + assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt", "lightning_logs"} + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=1, + checkpoint_callback=False, + ) + trainer.fit(model) + assert model.hpc_save_called == 1 + assert model.hpc_load_called == 1 + + def test_preloaded_checkpoint_lifecycle(tmpdir): """ Tests that the preloaded checkpoint contents gets cleared from memory when it is not required anymore. """ model = BoringModel() From 09dd67dbfd90308f8732a637b04a89351337ed89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:35:53 +0200 Subject: [PATCH 06/14] space --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 9581f9e76be20..ed00cc6c9b7ca 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -261,6 +261,7 @@ def restore_lr_schedulers(self) -> None: def hpc_load(self, checkpoint_path: str): """ Attempts to restore the full training and model state from a HPC checkpoint file. + .. deprecated:: `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. Use `CheckpointConnector.restore` instead. From 1dea5be6c3d1c2efebd116595a754e2f4df335b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 12 Jun 2021 13:37:19 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ed00cc6c9b7ca..24424f2c19201 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -261,7 +261,7 @@ def restore_lr_schedulers(self) -> None: def hpc_load(self, checkpoint_path: str): """ Attempts to restore the full training and model state from a HPC checkpoint file. - + .. deprecated:: `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. Use `CheckpointConnector.restore` instead. From 1a2c6e4dafa4c18fbbc3f7bb3c85013689619588 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 15:54:26 +0200 Subject: [PATCH 08/14] unused import --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 24424f2c19201..ed3c97dde80cb 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -29,7 +29,6 @@ rank_zero_warn, ) from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem -from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS From c4e6786086072e3b0de1d9aa0141073f47d0d02e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 12 Jun 2021 19:45:09 +0200 Subject: [PATCH 09/14] make windows test good --- tests/trainer/connectors/test_checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index c2e3d56791cf2..575186f21d0b4 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -147,6 +147,6 @@ def test_hpc_max_ckpt_version(tmpdir): trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt") trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt") - assert trainer.checkpoint_connector.hpc_resume_path == tmpdir / "hpc_ckpt_33.ckpt" + assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt") assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None From 37766ca3240706b82a5be474f02900648b8128cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 13 Jun 2021 19:08:00 +0200 Subject: [PATCH 10/14] Update pytorch_lightning/trainer/connectors/checkpoint_connector.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ed3c97dde80cb..299fd864500c7 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -370,7 +370,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: self.trainer.precision_plugin.on_save_checkpoint(checkpoint) - # dump hyper-parameters # dump hyper-parameters if model.hparams: if hasattr(model, '_hparams_name'): From 631d0eb71c96afccc9ed6fda25f0c60bfa80f971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 13 Jun 2021 19:44:58 +0200 Subject: [PATCH 11/14] join os path --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 299fd864500c7..59346c30dfd6c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -50,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]: dir_path_hpc = str(self.trainer.weights_save_path) max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_") if max_version is not None: - return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt" + return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") def resume_start(self) -> None: """ From fe4632b4e2d445a4ada05409c52f6c6c20ea8386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 13 Jun 2021 19:47:22 +0200 Subject: [PATCH 12/14] move --- .../connectors/checkpoint_connector.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 59346c30dfd6c..5bef4b198b4b6 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -257,19 +257,6 @@ def restore_lr_schedulers(self) -> None: # ---------------------------------- # PRIVATE OPS # ---------------------------------- - def hpc_load(self, checkpoint_path: str): - """ - Attempts to restore the full training and model state from a HPC checkpoint file. - - .. deprecated:: - `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. - Use `CheckpointConnector.restore` instead. - """ - rank_zero_deprecation( - "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." - " Use `CheckpointConnector.restore()` instead." - ) - self.restore(checkpoint_path) def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists @@ -388,6 +375,20 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint + def hpc_load(self, checkpoint_path: str) -> None: + """ + Attempts to restore the full training and model state from a HPC checkpoint file. + + .. deprecated:: + `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. + Use `CheckpointConnector.restore` instead. + """ + rank_zero_deprecation( + "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." + " Use `CheckpointConnector.restore()` instead." + ) + self.restore(checkpoint_path) + def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. Args: From 526d3a442795b0181f69a0f796b9a8247309f957 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 13 Jun 2021 19:49:16 +0200 Subject: [PATCH 13/14] mock logger --- tests/trainer/connectors/test_checkpoint_connector.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 575186f21d0b4..6e152f5944b59 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest.mock import Mock import torch @@ -41,19 +42,21 @@ def test_hpc_hook_calls(tmpdir): default_root_dir=tmpdir, max_steps=1, checkpoint_callback=False, + logger=False, ) trainer.fit(model) connector = trainer.checkpoint_connector - connector.hpc_save(tmpdir, logger=trainer.logger) + connector.hpc_save(tmpdir, logger=Mock()) assert model.hpc_save_called == 1 assert model.hpc_load_called == 0 # new training run, restore from hpc checkpoint file automatically - assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt", "lightning_logs"} + assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt"} trainer = Trainer( default_root_dir=tmpdir, max_steps=1, checkpoint_callback=False, + logger=False, ) trainer.fit(model) assert model.hpc_save_called == 1 From d64a9dcc93e6b4a51de09a77fdcb81a65582cda4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 13 Jun 2021 20:00:15 +0200 Subject: [PATCH 14/14] shorter message --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5bef4b198b4b6..8035f0c532764 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -379,9 +379,8 @@ def hpc_load(self, checkpoint_path: str) -> None: """ Attempts to restore the full training and model state from a HPC checkpoint file. - .. deprecated:: - `CheckpointConnector.hpc_load` was deprecated in v1.4 and will be removed in v1.6. - Use `CheckpointConnector.restore` instead. + .. deprecated::v1.4 + Will be removed in v1.6. Use :meth:`restore` instead. """ rank_zero_deprecation( "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6."