From 02f5f85833e81e7d002ce008eabd1144cf60d922 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 16 Dec 2020 20:35:51 +0100 Subject: [PATCH 01/21] resolve bug --- .../callbacks/model_checkpoint.py | 9 +- .../trainer/connectors/callback_connector.py | 8 ++ .../connectors/checkpoint_connector.py | 10 +- .../checkpointing/test_trainer_checkpoint.py | 96 +++++++++++++++++++ 4 files changed, 116 insertions(+), 7 deletions(-) create mode 100644 tests/checkpointing/test_trainer_checkpoint.py diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4ac800f456c06..e2231a6de9fda 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,10 +20,10 @@ """ -import os -import re from copy import deepcopy +import os from pathlib import Path +import re from typing import Any, Dict, Optional, Union import numpy as np @@ -32,8 +32,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -207,11 +207,14 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, "current_score": self.current_score, + "dirpath": self.dirpath } def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): self.best_model_score = checkpointed_state["best_model_score"] self.best_model_path = checkpointed_state["best_model_path"] + if "dirpath" in checkpointed_state: + self.dirpath = checkpointed_state["dirpath"] def save_checkpoint(self, trainer, pl_module): """ diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 9a8e12c9419ab..25af1f13bc025 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -53,6 +53,14 @@ def on_trainer_init( progress_bar_refresh_rate, process_position ) + def resolve_resume_from_checkpoint(self): + if not self._trainer_has_checkpoint_callbacks(): + return self.trainer.resume_from_checkpoint + checkpoint_callbacks = self.trainer.checkpoint_callbacks[0] + if os.path.exists(checkpoint_callbacks.best_model_path): + return checkpoint_callbacks.best_model_path + return self.trainer.resume_from_checkpoint + def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): if isinstance(checkpoint_callback, ModelCheckpoint): # TODO: deprecated, remove this block in v1.3.0 diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 429bddd88b77e..76f2b40396270 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -15,18 +15,19 @@ import os from pathlib import Path import re -from typing import Union, Optional +from typing import Optional, Union import torch import pytorch_lightning from pytorch_lightning import _logger as log +from pytorch_lightning.callbacks import ModelCheckpoint from 
pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import AMPType, APEX_AVAILABLE, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if APEX_AVAILABLE: from apex import amp @@ -64,7 +65,8 @@ def restore_weights(self, model: LightningModule): # 2. Attempt to restore states from `resume_from_checkpoint` file elif self.trainer.resume_from_checkpoint is not None: - self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) + resume_from_checkpoint = self.trainer.callback_connector.resolve_resume_from_checkpoint() + self.restore(resume_from_checkpoint, on_gpu=self.trainer.on_gpu) # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py new file mode 100644 index 0000000000000..a95ef37853d45 --- /dev/null +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -0,0 +1,96 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from argparse import Namespace +from copy import deepcopy +import os +from pathlib import Path +import pickle +import platform +import re +from unittest import mock +from unittest.mock import Mock + +import cloudpickle +from omegaconf import Container, OmegaConf +import pytest +import torch +import yaml + +import pytorch_lightning as pl +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import BoringModel +import tests.base.develop_utils as tutils + + +def test_finetunning_with_resume_from_checkpoint(tmpdir): + """ + This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test + """ + + seed_everything(3) + + checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) + + class ExtendedBoringModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + + model = ExtendedBoringModel() + model.validation_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=12, + limit_val_batches=6, + limit_test_batches=12, + callbacks=[checkpoint_callback], + ) + trainer.fit(model) + assert os.listdir(tmpdir) == ['epoch=00.ckpt', 'lightning_logs'] + + best_model_paths = [deepcopy(checkpoint_callback.best_model_path)] + results = [] + + for idx in range(3, 6): + # load from checkpoint + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=idx, + limit_train_batches=12, + limit_val_batches=12, + limit_test_batches=12, + resume_from_checkpoint=best_model_paths[-1], + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + results.append(trainer.test()[0]) + best_model_paths.append(deepcopy(trainer.callbacks[0].best_model_path)) + + for idx in range(len(results) - 1): + assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] + + for idx, best_model_path in enumerate(best_model_paths[1:]): + assert f"epoch={idx + 2}" in best_model_path From b60a0addb6137d06f3faac22a73af9e034949fb1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 20 Dec 2020 20:27:20 +0100 Subject: [PATCH 02/21] update code --- pytorch_lightning/callbacks/model_checkpoint.py | 4 +--- .../trainer/connectors/callback_connector.py | 7 ++++++- tests/checkpointing/test_model_checkpoint.py | 11 +++++------ 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 524db51c1fd41..b2e15cdf88a63 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -21,10 +21,8 @@ """ -import numbers -import os -import re from copy import deepcopy +import numbers import os from pathlib import Path import re diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 25af1f13bc025..8ef726557a54f 100644 --- 
a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -58,7 +58,12 @@ def resolve_resume_from_checkpoint(self): return self.trainer.resume_from_checkpoint checkpoint_callbacks = self.trainer.checkpoint_callbacks[0] if os.path.exists(checkpoint_callbacks.best_model_path): - return checkpoint_callbacks.best_model_path + resume_from_checkpoint_options = [ + checkpoint_callbacks.best_model_path, + self.trainer.resume_from_checkpoint + ] + resume_from_checkpoint_options.sort() + return resume_from_checkpoint_options[-1] return self.trainer.resume_from_checkpoint def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 106c34030051e..7c8da9a143f99 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,29 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from argparse import Namespace import os +from pathlib import Path import pickle import platform import re -from argparse import Namespace -from pathlib import Path from unittest import mock from unittest.mock import Mock import cloudpickle +from omegaconf import Container, OmegaConf import pytest import torch import yaml -from omegaconf import Container, OmegaConf import pytorch_lightning as pl -import tests.base.develop_utils as tutils -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel +import tests.base.develop_utils as tutils class LogInTwoMethods(BoringModel): @@ -762,7 +762,6 @@ def assert_checkpoint_log_dir(idx): assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches assert trainer.current_epoch == epochs - trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches From c42149e1e7669b3f3db04652bc0c1c412a9878ec Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 20 Dec 2020 20:31:07 +0100 Subject: [PATCH 03/21] add set -e --- tests/special_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 950e3776bbc7f..8d67cce28b39f 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Running special tests +set -e export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp From 9e433aaefe7d7519dec95ee7ab416338d2e778a6 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 21 Dec 2020 11:54:59 +0100 Subject: [PATCH 04/21] Update pytorch_lightning/callbacks/model_checkpoint.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9647ea4d6ad9b..ea47be70ca901 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -215,8 +215,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): self.best_model_score = checkpointed_state["best_model_score"] self.best_model_path = checkpointed_state["best_model_path"] - if "dirpath" in checkpointed_state: - self.dirpath = checkpointed_state["dirpath"] + self.dirpath = checkpointed_state.get("dirpath", self.dirpath) def save_checkpoint(self, trainer, pl_module): """ From 17cb6a14734116a5830ee85c76210f10c3c9439f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Dec 2020 08:49:59 +0100 Subject: [PATCH 05/21] update test --- tests/checkpointing/test_trainer_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index a95ef37853d45..ffcc50b9bc771 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -69,7 +69,7 @@ def validation_step(self, batch, batch_idx): callbacks=[checkpoint_callback], ) trainer.fit(model) - assert os.listdir(tmpdir) == ['epoch=00.ckpt', 'lightning_logs'] + assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'} best_model_paths = [deepcopy(checkpoint_callback.best_model_path)] results = [] From 014f79cb525cad5d2be579ef064372119bdaf910 Mon Sep 17 00:00:00 2001 From: chaton Date: Wed, 23 Dec 2020 12:42:39 +0100 Subject: [PATCH 06/21] Update tests/checkpointing/test_trainer_checkpoint.py Co-authored-by: Sean Naren --- tests/checkpointing/test_trainer_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index ffcc50b9bc771..98f44ecf4578a 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -37,7 +37,7 @@ import tests.base.develop_utils as tutils -def test_finetunning_with_resume_from_checkpoint(tmpdir): +def test_finetuning_with_resume_from_checkpoint(tmpdir): """ This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test """ From f09af65b6b28be66973c71c03719069f480f4a56 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 28 Dec 2020 12:02:28 +0100 Subject: [PATCH 07/21] Update tests/checkpointing/test_trainer_checkpoint.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- 
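Note (annotation below the `---` marker, not part of the commit itself): this review suggestion narrows the directory assertion to a single checkpoint file; the `lightning_logs` entry disappears once `logger=False` is added to the first Trainer later in the series. Comparing the raw `os.listdir()` result as a list is only safe here because exactly one entry is expected — `os.listdir()` returns entries in arbitrary order, so the order-insensitive set comparison used elsewhere in these tests is the more robust general pattern. A minimal sketch of that pattern (the helper name is illustrative, not something added by this patch):

import os

def assert_dir_contents(path, expected):
    # os.listdir() returns entries in arbitrary order, so compare as sets.
    assert set(os.listdir(path)) == set(expected)

# usage in a test: assert_dir_contents(tmpdir, {"epoch=00.ckpt"})
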
tests/checkpointing/test_trainer_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 98f44ecf4578a..13531e64757f1 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -69,7 +69,7 @@ def validation_step(self, batch, batch_idx): callbacks=[checkpoint_callback], ) trainer.fit(model) - assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'} + assert os.listdir(tmpdir) == ['epoch=00.ckpt'] best_model_paths = [deepcopy(checkpoint_callback.best_model_path)] results = [] From a2a9fa0764bf45394136d255a6af2e293169f7bb Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 12:26:19 +0100 Subject: [PATCH 08/21] update on comments --- .../callbacks/model_checkpoint.py | 5 ++++- .../trainer/connectors/callback_connector.py | 20 ++++++++----------- .../connectors/checkpoint_connector.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 2 +- .../checkpointing/test_trainer_checkpoint.py | 13 +++++++----- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 629d5c4e22b6d..1fe350038d514 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -214,7 +214,10 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): self.best_model_score = checkpointed_state["best_model_score"] self.best_model_path = checkpointed_state["best_model_path"] - self.dirpath = checkpointed_state.get("dirpath", self.dirpath) + dirpath = checkpointed_state.get("dirpath", self.dirpath) + # If the dirpath exists, checkpoints should be grouped together. 
+ if dirpath is not None and os.path.exists(dirpath): + self.dirpath = dirpath def save_checkpoint(self, trainer, pl_module): """ diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 8ef726557a54f..5d98e66eabd4f 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -53,18 +53,14 @@ def on_trainer_init( progress_bar_refresh_rate, process_position ) - def resolve_resume_from_checkpoint(self): - if not self._trainer_has_checkpoint_callbacks(): - return self.trainer.resume_from_checkpoint - checkpoint_callbacks = self.trainer.checkpoint_callbacks[0] - if os.path.exists(checkpoint_callbacks.best_model_path): - resume_from_checkpoint_options = [ - checkpoint_callbacks.best_model_path, - self.trainer.resume_from_checkpoint - ] - resume_from_checkpoint_options.sort() - return resume_from_checkpoint_options[-1] - return self.trainer.resume_from_checkpoint + def resolve_checkpoint_path(self): + if self.trainer.testing: + checkpoint_callbacks = self.trainer.checkpoint_callbacks[0] + checkpoint_path = checkpoint_callbacks.best_model_path + else: + checkpoint_path = self.trainer.resume_from_checkpoint + return self.trainer.resume_from_checkpoint if checkpoint_path == '' \ + else checkpoint_path def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): if isinstance(checkpoint_callback, ModelCheckpoint): diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 76f2b40396270..59df2c3f39160 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -65,8 +65,8 @@ def restore_weights(self, model: LightningModule): # 2. 
Attempt to restore states from `resume_from_checkpoint` file elif self.trainer.resume_from_checkpoint is not None: - resume_from_checkpoint = self.trainer.callback_connector.resolve_resume_from_checkpoint() - self.restore(resume_from_checkpoint, on_gpu=self.trainer.on_gpu) + checkpoint_path = self.trainer.callback_connector.resolve_checkpoint_path() + self.restore(checkpoint_path, on_gpu=self.trainer.on_gpu) # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fe4525006ebb9..65eab8ba6f92a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -171,7 +171,7 @@ def setup_training(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.trainer.model = model - # restore training and model before hpc is called + # restore model weights before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7c8da9a143f99..82b671d3b36dd 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -680,7 +680,7 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) trainer.test(model, verbose=False) - assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'} + assert os.listdir(tmpdir) == ['epoch=00.ckpt', 'lightning_logs'] assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)} diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 13531e64757f1..88378351bad4a 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from argparse import Namespace -from copy import deepcopy import os from pathlib import Path import pickle @@ -67,11 +66,12 @@ def validation_step(self, batch, batch_idx): limit_val_batches=6, limit_test_batches=12, callbacks=[checkpoint_callback], + logger=False, ) trainer.fit(model) assert os.listdir(tmpdir) == ['epoch=00.ckpt'] - best_model_paths = [deepcopy(checkpoint_callback.best_model_path)] + best_model_paths = [checkpoint_callback.best_model_path] results = [] for idx in range(3, 6): @@ -87,10 +87,13 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) results.append(trainer.test()[0]) - best_model_paths.append(deepcopy(trainer.callbacks[0].best_model_path)) + best_model_paths.append(trainer.callbacks[0].best_model_path) for idx in range(len(results) - 1): assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] - for idx, best_model_path in enumerate(best_model_paths[1:]): - assert f"epoch={idx + 2}" in best_model_path + for idx, best_model_path in enumerate(best_model_paths): + if idx == 0: + best_model_path.endswith(f"epoch={idx}.ckpt") + else: + best_model_path.endswith(f"epoch={idx + 2}.ckpt") From bd3d2dac85724fb6eca182d99b196f3a5ff9f1c5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 16:17:52 +0100 Subject: [PATCH 09/21] resolve test --- tests/checkpointing/test_trainer_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 88378351bad4a..2120e8f0ae71a 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from argparse import Namespace +from copy import deepcopy import os from pathlib import Path import pickle @@ -86,7 +87,8 @@ def validation_step(self, batch, batch_idx): progress_bar_refresh_rate=0, ) trainer.fit(model) - results.append(trainer.test()[0]) + trainer.test() + results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.callbacks[0].best_model_path) for idx in range(len(results) - 1): From 64e3b5b23bf8ae7002a3ea4aedb22f961f2bf273 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 17:09:26 +0100 Subject: [PATCH 10/21] convert to set --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 82b671d3b36dd..7c8da9a143f99 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -680,7 +680,7 @@ def validation_step(self, batch, batch_idx): ) trainer.fit(model) trainer.test(model, verbose=False) - assert os.listdir(tmpdir) == ['epoch=00.ckpt', 'lightning_logs'] + assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'} assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)} From ed84588cf9a966afc99571ab93857718bc389fd2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 13:49:56 +0100 Subject: [PATCH 11/21] update --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ---- .../trainer/connectors/callback_connector.py | 9 --------- .../trainer/connectors/checkpoint_connector.py | 5 ++--- tests/checkpointing/test_model_checkpoint.py | 5 +++-- tests/checkpointing/test_trainer_checkpoint.py | 4 ++-- 5 files changed, 7 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1fe350038d514..6ff2846aaf9f2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -214,10 +214,6 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]: def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): self.best_model_score = checkpointed_state["best_model_score"] self.best_model_path = checkpointed_state["best_model_path"] - dirpath = checkpointed_state.get("dirpath", self.dirpath) - # If the dirpath exists, checkpoints should be grouped together. 
- if dirpath is not None and os.path.exists(dirpath): - self.dirpath = dirpath def save_checkpoint(self, trainer, pl_module): """ diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 5d98e66eabd4f..9a8e12c9419ab 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -53,15 +53,6 @@ def on_trainer_init( progress_bar_refresh_rate, process_position ) - def resolve_checkpoint_path(self): - if self.trainer.testing: - checkpoint_callbacks = self.trainer.checkpoint_callbacks[0] - checkpoint_path = checkpoint_callbacks.best_model_path - else: - checkpoint_path = self.trainer.resume_from_checkpoint - return self.trainer.resume_from_checkpoint if checkpoint_path == '' \ - else checkpoint_path - def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): if isinstance(checkpoint_callback, ModelCheckpoint): # TODO: deprecated, remove this block in v1.3.0 diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 59df2c3f39160..cececec98c51a 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -64,9 +64,8 @@ def restore_weights(self, model: LightningModule): rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. Attempt to restore states from `resume_from_checkpoint` file - elif self.trainer.resume_from_checkpoint is not None: - checkpoint_path = self.trainer.callback_connector.resolve_checkpoint_path() - self.restore(checkpoint_path, on_gpu=self.trainer.on_gpu) + elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: + self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) # wait for all to catch up self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights') diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7c8da9a143f99..de61536d97cd0 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -760,8 +760,9 @@ def assert_checkpoint_log_dir(idx): model = ExtendedBoringModel() trainer.test(model) assert not trainer.checkpoint_connector.has_trained - assert trainer.global_step == epochs * limit_train_batches - assert trainer.current_epoch == epochs + # resume_from_checkpoint is resumed when calling `.fit` + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 2120e8f0ae71a..e0e46646f6026 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -96,6 +96,6 @@ def validation_step(self, batch, batch_idx): for idx, best_model_path in enumerate(best_model_paths): if idx == 0: - best_model_path.endswith(f"epoch={idx}.ckpt") + assert best_model_path.endswith(f"epoch=0{idx}.ckpt") else: - best_model_path.endswith(f"epoch={idx + 2}.ckpt") + assert f"epoch={idx + 1}" in best_model_path From 41355d5e2dc0720bad9d602705d75565b155cab2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 14:01:16 +0100 Subject: [PATCH 12/21] add error 
triggering --- .drone.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.drone.yml b/.drone.yml index b0b6c3df1b699..472861852cae7 100644 --- a/.drone.yml +++ b/.drone.yml @@ -30,6 +30,7 @@ steps: MKL_THREADING_LAYER: GNU commands: + - set -e - python --version - pip --version - nvidia-smi From 53455af8a7751ae7a31a4428ab99d3a91e6d29d3 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 4 Jan 2021 17:18:20 +0000 Subject: [PATCH 13/21] update --- pytorch_lightning/plugins/rpc_plugin.py | 9 +++++++-- tests/plugins/test_ddp_sequential_plugin.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 492bddaff0c77..40439e35df5d7 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -33,7 +33,8 @@ class RPCPlugin(DDPPlugin): that need to be addressed when using RPC communication when building custom RPC Plugins. """ - def __init__(self, **kwargs): + def __init__(self, rpc_timeout_sec: float = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC, **kwargs): + self.rpc_timeout_sec = rpc_timeout_sec self.rpc_initialized = False super().__init__(**kwargs) @@ -41,7 +42,11 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None: os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') - rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + rpc_timeout=self.rpc_timeout_sec + ) + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc._set_rpc_timeout(self.rpc_timeout_sec) self.rpc_initialized = True def rpc_save_model(self, diff --git a/tests/plugins/test_ddp_sequential_plugin.py b/tests/plugins/test_ddp_sequential_plugin.py index 23b0b9128b349..9271c8d68a33f 100644 --- a/tests/plugins/test_ddp_sequential_plugin.py +++ b/tests/plugins/test_ddp_sequential_plugin.py @@ -47,7 +47,8 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): limit_test_batches=2, gpus=2, distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], + plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], + enable_pl_optimizer=True, ) trainer.fit(model) @@ -163,6 +164,7 @@ def step(self, x): def training_step(self, batch, batch_idx): opt = self.optimizers() + print(opt) output = self.sequential_module(batch) loss = self.loss(output) self.log("train_loss", loss, on_epoch=True, prog_bar=True) From fa8d9521f2538aae89fbc7c2813f60e432c098a6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 18:46:18 +0100 Subject: [PATCH 14/21] update on comments --- pytorch_lightning/plugins/rpc_plugin.py | 2 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/checkpointing/test_trainer_checkpoint.py | 18 ++---------------- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 492bddaff0c77..776ac17c3d4eb 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Optional +from typing import Optional import torch diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 65eab8ba6f92a..0ec6c3c3f3146 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -171,7 +171,7 @@ def setup_training(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.trainer.model = model - # restore model weights before hpc is called + # restore training state and model weights before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index e0e46646f6026..9e93a8c297481 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -11,30 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from argparse import Namespace from copy import deepcopy import os -from pathlib import Path -import pickle -import platform -import re -from unittest import mock -from unittest.mock import Mock -import cloudpickle -from omegaconf import Container, OmegaConf -import pytest import torch -import yaml import pytorch_lightning as pl from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel -import tests.base.develop_utils as tutils def test_finetuning_with_resume_from_checkpoint(tmpdir): @@ -89,7 +75,7 @@ def validation_step(self, batch, batch_idx): trainer.fit(model) trainer.test() results.append(deepcopy(trainer.callback_metrics)) - best_model_paths.append(trainer.callbacks[0].best_model_path) + best_model_paths.append(trainer.checkpoint_callback.best_model_path) for idx in range(len(results) - 1): assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] From ef75de54a689807f73f7a57a0458b82e3994edd4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 18:48:20 +0100 Subject: [PATCH 15/21] update --- pytorch_lightning/plugins/rpc_plugin.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 7cf06ed80e786..4e0db2bc5027e 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -22,6 +22,9 @@ if RPC_AVAILABLE: from torch.distributed import rpc + from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC +else: + DEFAULT_RPC_TIMEOUT_SEC = 60. class RPCPlugin(DDPPlugin): @@ -33,7 +36,7 @@ class RPCPlugin(DDPPlugin): that need to be addressed when using RPC communication when building custom RPC Plugins. 
""" - def __init__(self, rpc_timeout_sec: float = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC, **kwargs): + def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs): self.rpc_timeout_sec = rpc_timeout_sec self.rpc_initialized = False super().__init__(**kwargs) @@ -44,9 +47,9 @@ def init_rpc_connection(self, os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') rpc_backend_options = rpc.TensorPipeRpcBackendOptions( rpc_timeout=self.rpc_timeout_sec - ) - rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) - rpc._set_rpc_timeout(self.rpc_timeout_sec) + ) + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc._set_rpc_timeout(self.rpc_timeout_sec) self.rpc_initialized = True def rpc_save_model(self, From d85662d094977866e8d858b087473a77ccc44ab3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 19:31:12 +0100 Subject: [PATCH 16/21] resolve import --- pytorch_lightning/plugins/rpc_plugin.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 4e0db2bc5027e..1be833f4b86fe 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -20,11 +20,13 @@ from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import RPC_AVAILABLE +DEFAULT_RPC_TIMEOUT_SEC = 60. if RPC_AVAILABLE: from torch.distributed import rpc - from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC -else: - DEFAULT_RPC_TIMEOUT_SEC = 60. + try: + from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC + except ModuleNotFoundError: + DEFAULT_RPC_TIMEOUT_SEC = 60. class RPCPlugin(DDPPlugin): @@ -45,10 +47,7 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None: os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') - rpc_backend_options = rpc.TensorPipeRpcBackendOptions( - rpc_timeout=self.rpc_timeout_sec - ) - rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) rpc._set_rpc_timeout(self.rpc_timeout_sec) self.rpc_initialized = True From 6c4948c85c2652f0127636f90c17125861a3a360 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 19:56:52 +0100 Subject: [PATCH 17/21] update --- pytorch_lightning/plugins/rpc_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 1be833f4b86fe..e476cba365710 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -25,7 +25,7 @@ from torch.distributed import rpc try: from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC - except ModuleNotFoundError: + except (ModuleNotFoundError, ImportError): DEFAULT_RPC_TIMEOUT_SEC = 60. 
From 31977620edf471fba4a61227239c537569fddbc1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 20:32:21 +0100 Subject: [PATCH 18/21] update --- tests/plugins/test_ddp_sequential_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/plugins/test_ddp_sequential_plugin.py b/tests/plugins/test_ddp_sequential_plugin.py index 9271c8d68a33f..8b21c36e73065 100644 --- a/tests/plugins/test_ddp_sequential_plugin.py +++ b/tests/plugins/test_ddp_sequential_plugin.py @@ -164,7 +164,6 @@ def step(self, x): def training_step(self, batch, batch_idx): opt = self.optimizers() - print(opt) output = self.sequential_module(batch) loss = self.loss(output) self.log("train_loss", loss, on_epoch=True, prog_bar=True) From b8f64bf772ca339c8589060c37b6cd1fa49b9ef6 Mon Sep 17 00:00:00 2001 From: chaton Date: Mon, 4 Jan 2021 21:13:02 +0100 Subject: [PATCH 19/21] Update pytorch_lightning/plugins/rpc_plugin.py Co-authored-by: Jirka Borovec --- pytorch_lightning/plugins/rpc_plugin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index e476cba365710..7ebbac32560cf 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -23,10 +23,8 @@ DEFAULT_RPC_TIMEOUT_SEC = 60. if RPC_AVAILABLE: from torch.distributed import rpc - try: + if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"): from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC - except (ModuleNotFoundError, ImportError): - DEFAULT_RPC_TIMEOUT_SEC = 60. class RPCPlugin(DDPPlugin): From 34986bbb08a15e638fece8b2ef61016ca7f2111b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 21:40:30 +0100 Subject: [PATCH 20/21] update --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b9b705459510..6084dfe1c993f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [1.1.3rc] - 2020-12-29 +## [1.1.3] - 2021-01-05 ### Added @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Skip restore from `resume_from_checkpoint` in while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161)) ## [1.1.2] - 2020-12-23 From 9252a0612715f7984e5b27ba2cbfe4092900147d Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 4 Jan 2021 22:38:08 +0100 Subject: [PATCH 21/21] add _module_available --- pytorch_lightning/plugins/rpc_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 7ebbac32560cf..a1464f3c70e0b 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -18,7 +18,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import RPC_AVAILABLE +from pytorch_lightning.utilities import _module_available, RPC_AVAILABLE DEFAULT_RPC_TIMEOUT_SEC = 60. if RPC_AVAILABLE: