
Commit 56437e9

tchaton authored and Borda committed
[bug-fix] Trainer.test points to latest best_model_path (#5161)
* resolve bug
* update code
* add set -e
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-authored-by: Adrian Wälchli <[email protected]>
* update test
* Update tests/checkpointing/test_trainer_checkpoint.py
  Co-authored-by: Sean Naren <[email protected]>
* Update tests/checkpointing/test_trainer_checkpoint.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* update on comments
* resolve test
* convert to set
* update
* add error triggering
* update
* update on comments
* update
* resolve import
* update
* update
* Update pytorch_lightning/plugins/rpc_plugin.py
  Co-authored-by: Jirka Borovec <[email protected]>
* update

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit d5b3678)
1 parent 704e00e commit 56437e9

File tree

10 files changed: +111 additions, −12 deletions


.drone.yml

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ steps:
     MKL_THREADING_LAYER: GNU

   commands:
+    - set -e
     - python --version
     - pip --version
     - nvidia-smi

CHANGELOG.md

Lines changed: 2 additions & 1 deletion

@@ -76,12 +76,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Skip restore from `resume_from_checkpoint` while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161))
+
 - Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))

 - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))


-
 ## [1.1.2] - 2020-12-23

 ### Added
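
In practice the entry above means `resume_from_checkpoint` only affects `.fit()`; `.test()` now evaluates the checkpoint callback's latest `best_model_path` instead. A hedged usage sketch (the path and settings are illustrative; `BoringModel` is the test helper used elsewhere in this commit):

from pytorch_lightning import Trainer
from tests.base import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, resume_from_checkpoint="checkpoints/epoch=00.ckpt")
trainer.fit(model)   # restores training state from resume_from_checkpoint
trainer.test()       # skips that restore and loads the current best_model_path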

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions

@@ -199,6 +199,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
+            "dirpath": self.dirpath
         }

     def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
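
The new `dirpath` entry rides along in the callback state. A minimal sketch, using a simplified stand-in class rather than Lightning's actual implementation, of how that entry can be round-tripped on restore so a resumed run keeps writing to the same directory:

class TinyCheckpointState:
    """Illustrative stand-in mirroring the keys saved by ModelCheckpoint above."""

    def __init__(self, dirpath, best_model_path=None, best_model_score=None):
        self.dirpath = dirpath
        self.best_model_path = best_model_path
        self.best_model_score = best_model_score

    def on_save_checkpoint(self):
        # same keys as the diff above, including the newly added "dirpath"
        return {
            "best_model_score": self.best_model_score,
            "best_model_path": self.best_model_path,
            "dirpath": self.dirpath,
        }

    def on_load_checkpoint(self, checkpointed_state):
        # restoring dirpath lets a resumed run keep saving into the same folder
        self.best_model_score = checkpointed_state["best_model_score"]
        self.best_model_path = checkpointed_state["best_model_path"]
        self.dirpath = checkpointed_state.get("dirpath", self.dirpath)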

pytorch_lightning/plugins/rpc_plugin.py

Lines changed: 7 additions & 2 deletions

@@ -18,10 +18,13 @@

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-from pytorch_lightning.utilities import _RPC_AVAILABLE
+from pytorch_lightning.utilities import _RPC_AVAILABLE, _module_available

+DEFAULT_RPC_TIMEOUT_SEC = 60.
 if _RPC_AVAILABLE:
     from torch.distributed import rpc
+    if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"):
+        from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC


 class RPCPlugin(DDPPlugin):
@@ -33,7 +36,8 @@ class RPCPlugin(DDPPlugin):
     that need to be addressed when using RPC communication when building custom RPC Plugins.
     """

-    def __init__(self, **kwargs):
+    def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs):
+        self.rpc_timeout_sec = rpc_timeout_sec
         self.rpc_initialized = False
         super().__init__(**kwargs)

@@ -42,6 +46,7 @@ def init_rpc_connection(self,
                             world_size: int) -> None:
         os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
         rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
+        rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self.rpc_initialized = True

     def rpc_save_model(self,
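
The new `rpc_timeout_sec` argument flows from the plugin constructor into `rpc._set_rpc_timeout` right after `init_rpc`. A hedged usage sketch, assuming the `DDPSequentialPlugin` import path of this release and mirroring the 5-minute timeout used in the updated test further down:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin

# DDPSequentialPlugin subclasses RPCPlugin, so the new keyword is forwarded
# through **kwargs; the 5 * 60 value is illustrative.
plugin = DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)
trainer = Trainer(gpus=2, distributed_backend="ddp", plugins=[plugin])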

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 import torch

 import pytorch_lightning
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -62,7 +63,7 @@ def restore_weights(self, model: LightningModule) -> None:
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')

         # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
+        elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

         # wait for all to catch up
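
The only behavioural change here is the extra `not self.trainer.testing` condition. A minimal stand-alone sketch of that guard (the function name and arguments are illustrative, not Trainer internals):

def should_restore_from_resume_checkpoint(resume_path, testing):
    # During `.test()` the weights come from best_model_path, so the
    # resume_from_checkpoint restore is skipped to avoid clobbering them.
    return resume_path is not None and not testing

assert should_restore_from_resume_checkpoint("ckpt/epoch=00.ckpt", testing=False)
assert not should_restore_from_resume_checkpoint("ckpt/epoch=00.ckpt", testing=True)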

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 1 deletion

@@ -184,7 +184,7 @@ def setup_training(self, model: LightningModule):
         # if cluster resets state, the model will update with the saved weights
         self.trainer.model = model

-        # restore training and model before hpc is called
+        # restore training state and model weights before hpc is called
         self.trainer.checkpoint_connector.restore_weights(model)

         # on pretrain routine end

tests/checkpointing/test_model_checkpoint.py

Lines changed: 7 additions & 6 deletions

@@ -11,20 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from argparse import Namespace
 import os
+from pathlib import Path
 import pickle
 import platform
 import re
-from argparse import Namespace
-from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock

 import cloudpickle
+from omegaconf import Container, OmegaConf
 import pytest
 import torch
 import yaml
-from omegaconf import Container, OmegaConf

 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel
+import tests.base.develop_utils as tutils


 class LogInTwoMethods(BoringModel):
@@ -745,9 +746,9 @@ def assert_checkpoint_log_dir(idx):
     model = ExtendedBoringModel()
     trainer.test(model)
     assert not trainer.checkpoint_connector.has_trained
-    assert trainer.global_step == epochs * limit_train_batches
-    assert trainer.current_epoch == epochs
-
+    # resume_from_checkpoint is resumed when calling `.fit`
+    assert trainer.global_step == 0
+    assert trainer.current_epoch == 0
     trainer.fit(model)
     assert not trainer.checkpoint_connector.has_trained
     assert trainer.global_step == epochs * limit_train_batches
tests/checkpointing/test_trainer_checkpoint.py

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+import os
+
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from tests.base import BoringModel
+
+
+def test_finetuning_with_resume_from_checkpoint(tmpdir):
+    """
+    This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
+    """
+
+    seed_everything(3)
+
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
+
+    class ExtendedBoringModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+
+    model = ExtendedBoringModel()
+    model.validation_epoch_end = None
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=12,
+        limit_val_batches=6,
+        limit_test_batches=12,
+        callbacks=[checkpoint_callback],
+        logger=False,
+    )
+    trainer.fit(model)
+    assert os.listdir(tmpdir) == ['epoch=00.ckpt']
+
+    best_model_paths = [checkpoint_callback.best_model_path]
+    results = []
+
+    for idx in range(3, 6):
+        # load from checkpoint
+        trainer = pl.Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=idx,
+            limit_train_batches=12,
+            limit_val_batches=12,
+            limit_test_batches=12,
+            resume_from_checkpoint=best_model_paths[-1],
+            progress_bar_refresh_rate=0,
+        )
+        trainer.fit(model)
+        trainer.test()
+        results.append(deepcopy(trainer.callback_metrics))
+        best_model_paths.append(trainer.checkpoint_callback.best_model_path)
+
+    for idx in range(len(results) - 1):
+        assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]
+
+    for idx, best_model_path in enumerate(best_model_paths):
+        if idx == 0:
+            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
+        else:
+            assert f"epoch={idx + 1}" in best_model_path
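
The final assertions rely on the `filename="{epoch:02d}"` template passed to `ModelCheckpoint` above, which writes each epoch's checkpoint as `epoch=NN.ckpt` (confirmed by the earlier `os.listdir(tmpdir) == ['epoch=00.ckpt']` check). A small illustrative reconstruction of that naming, not the callback's actual formatting code:

def expected_checkpoint_name(epoch: int) -> str:
    # "{epoch:02d}" expands to e.g. "epoch=00", "epoch=01", ...
    return f"epoch={epoch:02d}.ckpt"

assert expected_checkpoint_name(0) == "epoch=00.ckpt"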

tests/plugins/test_ddp_sequential_plugin.py

Lines changed: 2 additions & 1 deletion

@@ -47,7 +47,8 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
         limit_test_batches=2,
         gpus=2,
         distributed_backend="ddp",
-        plugins=[DDPSequentialPlugin(balance=[2, 1])],
+        plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
+        enable_pl_optimizer=True,
     )

     trainer.fit(model)

tests/special_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Running special tests
+set -e
 export PL_RUNNING_SPECIAL_TESTS=1
 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
 python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp
