
Commit 9410ded

Merge branch 'master' into feature/5311-flatten-dict
2 parents: 6ad1b4e + d568533

File tree: 12 files changed, +125 −20 lines


.drone.yml

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ steps:
       MKL_THREADING_LAYER: GNU
 
     commands:
+      - set -e
       - python --version
       - pip --version
       - nvidia-smi

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
-## [1.1.3rc] - 2020-12-29
+## [1.1.3] - 2021-01-05
 
 ### Added
 
@@ -25,11 +25,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Skip restore from `resume_from_checkpoint` while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161))
+
 - Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))
 
 - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))
 
-
 - Fixed casted key to string in `_flatten_dict` ([#5354](https://github.com/PyTorchLightning/pytorch-lightning/pull/5354))
 
 
docs/source/transfer_learning.rst

Lines changed: 11 additions & 5 deletions
@@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)
 
     class ImagenetTransferLearning(LightningModule):
         def __init__(self):
+            super().__init__()
+
             # init a pretrained resnet
-            num_target_classes = 10
-            self.feature_extractor = models.resnet50(pretrained=True)
-            self.feature_extractor.eval()
+            backbone = models.resnet50(pretrained=True)
+            num_filters = backbone.fc.in_features
+            layers = list(backbone.children())[:-1]
+            self.feature_extractor = torch.nn.Sequential(*layers)
 
             # use the pretrained model to classify cifar-10 (10 image classes)
-            self.classifier = nn.Linear(2048, num_target_classes)
+            num_target_classes = 10
+            self.classifier = nn.Linear(num_filters, num_target_classes)
 
         def forward(self, x):
-            representations = self.feature_extractor(x)
+            self.feature_extractor.eval()
+            with torch.no_grad():
+                representations = self.feature_extractor(x).flatten(1)
             x = self.classifier(representations)
             ...
 
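
For orientation, a minimal usage sketch of the updated example (assuming the `ImagenetTransferLearning` class above is in scope and the torchvision weights can be downloaded; the batch shape is only illustrative):

    import torch

    model = ImagenetTransferLearning()
    model.eval()
    images = torch.randn(4, 3, 224, 224)   # fake batch of 4 RGB images
    with torch.no_grad():
        logits = model(images)             # shape (4, 10): one score per target class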

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions
@@ -208,6 +208,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
+            "dirpath": self.dirpath
         }
 
     def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
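
For context, `dirpath` is now part of the callback state returned by `on_save_checkpoint`, so the directory the callback writes to travels with the checkpoint alongside `best_model_path`. A hedged configuration sketch (the argument values here are placeholders):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # The directory given here is now persisted in the callback state
    # saved with every checkpoint.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath="checkpoints/", save_top_k=3)
    trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=5)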

pytorch_lightning/metrics/classification/precision_recall.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
 
     def compute(self):
         """
-        Computes accuracy over state.
+        Computes recall over state.
         """
         if self.average == 'micro':
             return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS)
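
As a quick illustration of the corrected docstring, a minimal sketch of the metric in use (assuming `Recall` is importable from `pytorch_lightning.metrics` in this release):

    import torch
    from pytorch_lightning.metrics import Recall

    recall = Recall(average='micro')
    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])
    recall.update(preds, target)
    print(recall.compute())   # true positives / actual positives, accumulated over all updates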

pytorch_lightning/plugins/rpc_plugin.py

Lines changed: 8 additions & 3 deletions
@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-from pytorch_lightning.utilities import RPC_AVAILABLE
+from pytorch_lightning.utilities import _module_available, RPC_AVAILABLE
 
+DEFAULT_RPC_TIMEOUT_SEC = 60.
 if RPC_AVAILABLE:
     from torch.distributed import rpc
+    if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"):
+        from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC
 
 
 class RPCPlugin(DDPPlugin):
@@ -33,7 +36,8 @@ class RPCPlugin(DDPPlugin):
     that need to be addressed when using RPC communication when building custom RPC Plugins.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs):
+        self.rpc_timeout_sec = rpc_timeout_sec
         self.rpc_initialized = False
         super().__init__(**kwargs)
 
@@ -42,6 +46,7 @@ def init_rpc_connection(self,
                             world_size: int) -> None:
         os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
         rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
+        rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self.rpc_initialized = True
 
     def rpc_save_model(self,
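
For illustration, the new timeout can be supplied when constructing a plugin built on `RPCPlugin`; a hedged sketch (the trainer arguments are placeholders, and in practice a concrete RPC-based subclass would usually be used):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.rpc_plugin import RPCPlugin

    # Raise the RPC timeout to 120 s for slow cross-node initialisation;
    # the default falls back to 60 s when torch does not expose its own constant.
    plugin = RPCPlugin(rpc_timeout_sec=120.)
    trainer = Trainer(gpus=2, accelerator="ddp", plugins=[plugin])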

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 1 deletion
@@ -21,6 +21,7 @@
 
 import pytorch_lightning
 from pytorch_lightning import _logger as log
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import AMPType, APEX_AVAILABLE, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -63,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None:
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')
 
         # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
+        elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
             self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
 
         # wait for all to catch up
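
A rough sketch of the resulting behaviour (the checkpoint path and model are placeholders; the assertions mirror the updated test further down):

    from pytorch_lightning import Trainer
    from tests.base import BoringModel   # any LightningModule works here

    model = BoringModel()
    trainer = Trainer(resume_from_checkpoint="last.ckpt", max_epochs=5)

    trainer.test(model)               # testing no longer restores the checkpoint,
    assert trainer.global_step == 0   # so the training state is left untouched

    trainer.fit(model)                # fitting still restores it as before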

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def setup_training(self, model: LightningModule):
         # if cluster resets state, the model will update with the saved weights
         self.trainer.model = model
 
-        # restore training and model before hpc is called
+        # restore training state and model weights before hpc is called
         self.trainer.checkpoint_connector.restore_weights(model)
 
         # on pretrain routine end

tests/checkpointing/test_model_checkpoint.py

Lines changed: 7 additions & 6 deletions
@@ -11,20 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from argparse import Namespace
 import os
+from pathlib import Path
 import pickle
 import platform
 import re
-from argparse import Namespace
-from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
 
 import cloudpickle
+from omegaconf import Container, OmegaConf
 import pytest
 import torch
 import yaml
-from omegaconf import Container, OmegaConf
 
 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel
+import tests.base.develop_utils as tutils
 
 
 class LogInTwoMethods(BoringModel):
@@ -760,9 +761,9 @@ def assert_checkpoint_log_dir(idx):
     model = ExtendedBoringModel()
     trainer.test(model)
     assert not trainer.checkpoint_connector.has_trained
-    assert trainer.global_step == epochs * limit_train_batches
-    assert trainer.current_epoch == epochs
-
+    # resume_from_checkpoint is resumed when calling `.fit`
+    assert trainer.global_step == 0
+    assert trainer.current_epoch == 0
     trainer.fit(model)
     assert not trainer.checkpoint_connector.has_trained
     assert trainer.global_step == epochs * limit_train_batches
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+import os
+
+import torch
+
+import pytorch_lightning as pl
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from tests.base import BoringModel
+
+
+def test_finetuning_with_resume_from_checkpoint(tmpdir):
+    """
+    This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
+    """
+
+    seed_everything(3)
+
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
+
+    class ExtendedBoringModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+
+    model = ExtendedBoringModel()
+    model.validation_epoch_end = None
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=12,
+        limit_val_batches=6,
+        limit_test_batches=12,
+        callbacks=[checkpoint_callback],
+        logger=False,
+    )
+    trainer.fit(model)
+    assert os.listdir(tmpdir) == ['epoch=00.ckpt']
+
+    best_model_paths = [checkpoint_callback.best_model_path]
+    results = []
+
+    for idx in range(3, 6):
+        # load from checkpoint
+        trainer = pl.Trainer(
+            default_root_dir=tmpdir,
+            max_epochs=idx,
+            limit_train_batches=12,
+            limit_val_batches=12,
+            limit_test_batches=12,
+            resume_from_checkpoint=best_model_paths[-1],
+            progress_bar_refresh_rate=0,
+        )
+        trainer.fit(model)
+        trainer.test()
+        results.append(deepcopy(trainer.callback_metrics))
+        best_model_paths.append(trainer.checkpoint_callback.best_model_path)
+
+    for idx in range(len(results) - 1):
+        assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]
+
+    for idx, best_model_path in enumerate(best_model_paths):
+        if idx == 0:
+            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
+        else:
+            assert f"epoch={idx + 1}" in best_model_path
