
Commit b046f18

Merge branch 'master' into multi_opt
2 parents ab4efc7 + d568533

15 files changed: +168 additions, -35 deletions

.drone.yml

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ steps:
       MKL_THREADING_LAYER: GNU

   commands:
+    - set -e
     - python --version
     - pip --version
     - nvidia-smi

CHANGELOG.md

Lines changed: 3 additions & 2 deletions

@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


-## [1.1.3rc] - 2020-12-29
+## [1.1.3] - 2021-01-05

 ### Added

@@ -25,12 +25,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Skip restore from `resume_from_checkpoint` in while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161))
+
 - Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))

 - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))


-
 ## [1.1.2] - 2020-12-23

 ### Added

docs/source/transfer_learning.rst

Lines changed: 11 additions & 5 deletions

@@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)

     class ImagenetTransferLearning(LightningModule):
         def __init__(self):
+            super().__init__()
+
             # init a pretrained resnet
-            num_target_classes = 10
-            self.feature_extractor = models.resnet50(pretrained=True)
-            self.feature_extractor.eval()
+            backbone = models.resnet50(pretrained=True)
+            num_filters = backbone.fc.in_features
+            layers = list(backbone.children())[:-1]
+            self.feature_extractor = torch.nn.Sequential(*layers)

             # use the pretrained model to classify cifar-10 (10 image classes)
-            self.classifier = nn.Linear(2048, num_target_classes)
+            num_target_classes = 10
+            self.classifier = nn.Linear(num_filters, num_target_classes)

         def forward(self, x):
-            representations = self.feature_extractor(x)
+            self.feature_extractor.eval()
+            with torch.no_grad():
+                representations = self.feature_extractor(x).flatten(1)
             x = self.classifier(representations)
             ...
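Why the new `.flatten(1)` call matters: with the final `fc` layer removed, the ResNet backbone ends in an adaptive average pool and returns a 4-D tensor, which has to be flattened before the linear classifier. A minimal shape-check sketch, assuming torchvision is installed (this snippet is illustrative and not part of the commit):

    import torch
    from torchvision import models

    backbone = models.resnet50(pretrained=False)
    feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-1])

    x = torch.randn(2, 3, 224, 224)
    feats = feature_extractor(x)   # shape (2, 2048, 1, 1) after the average-pool layer
    feats = feats.flatten(1)       # shape (2, 2048), matching backbone.fc.in_features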

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions

@@ -208,6 +208,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
+            "dirpath": self.dirpath
         }

     def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
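With this change, the state that `ModelCheckpoint.on_save_checkpoint` returns also records the directory the callback was writing to. A rough, hedged picture of that state as a runnable snippet (paths and scores are made up; only the keys mirror the diff above):

    import torch

    # illustrative only; real values depend on the run and on the monitored metric
    checkpoint_callback_state = {
        "best_model_score": torch.tensor(0.123),
        "best_model_path": "/tmp/checkpoints/epoch=4-step=99.ckpt",
        "current_score": torch.tensor(0.154),
        "dirpath": "/tmp/checkpoints",  # newly persisted: where checkpoints were written
    }
    print(checkpoint_callback_state["dirpath"])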

pytorch_lightning/core/lightning.py

Lines changed: 14 additions & 6 deletions

@@ -14,15 +14,15 @@

 """nn.Module with additional great features."""

-from abc import ABC
-from argparse import Namespace
 import collections
 import copy
 import inspect
 import os
-from pathlib import Path
 import re
 import tempfile
+from abc import ABC
+from argparse import Namespace
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
@@ -1331,9 +1331,17 @@ def tbptt_split_batch(self, batch, split_size):

         return splits

-    def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
-        model_summary = ModelSummary(self, mode=mode)
-        log.info("\n" + str(model_summary))
+    def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
+        model_summary = None
+
+        if mode in ModelSummary.MODES:
+            model_summary = ModelSummary(self, mode=mode)
+            log.info("\n" + str(model_summary))
+        elif mode is not None:
+            raise MisconfigurationException(
+                f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}"
+            )
+
         return model_summary

     def freeze(self) -> None:
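Based on the new signature, `summarize` now tolerates `mode=None` and validates the mode string instead of assuming it is valid. A hedged usage sketch (assumes `ModelSummary.MODES` contains the usual "top"/"full" strings; the tiny module here is only for illustration):

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    class TinyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(8, 2)

        def forward(self, x):
            return self.layer(x)

    model = TinyModel()
    summary = model.summarize(mode="top")      # logs and returns a ModelSummary
    assert model.summarize(mode=None) is None  # None is now accepted and simply returns None

    try:
        model.summarize(mode="oops")           # an unrecognized mode now raises
    except MisconfigurationException as err:
        print(err)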

pytorch_lightning/metrics/classification/precision_recall.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):

     def compute(self):
         """
-        Computes accuracy over state.
+        Computes recall over state.
         """
         if self.average == 'micro':
             return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS)
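The corrected docstring matches what the micro branch actually computes: recall = true positives / actual positives, with a small epsilon for numerical stability. A tiny standalone sketch of that formula (the epsilon value is a stand-in for the library's `METRIC_EPS` constant, not taken from the commit):

    import torch

    METRIC_EPS = 1e-6  # stand-in for the library constant of the same name

    # per-class counts as accumulated in the metric state
    true_positives = torch.tensor([3., 1., 2.])
    actual_positives = torch.tensor([4., 2., 2.])

    micro_recall = true_positives.sum().float() / (actual_positives.sum() + METRIC_EPS)
    print(micro_recall)  # ~0.75: 6 of the 8 actual positives were recovered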

pytorch_lightning/plugins/rpc_plugin.py

Lines changed: 8 additions & 3 deletions

@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Optional
+from typing import Optional

 import torch

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-from pytorch_lightning.utilities import RPC_AVAILABLE
+from pytorch_lightning.utilities import _module_available, RPC_AVAILABLE

+DEFAULT_RPC_TIMEOUT_SEC = 60.
 if RPC_AVAILABLE:
     from torch.distributed import rpc
+    if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"):
+        from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC


 class RPCPlugin(DDPPlugin):
@@ -33,7 +36,8 @@ class RPCPlugin(DDPPlugin):
     that need to be addressed when using RPC communication when building custom RPC Plugins.
     """

-    def __init__(self, **kwargs):
+    def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs):
+        self.rpc_timeout_sec = rpc_timeout_sec
         self.rpc_initialized = False
         super().__init__(**kwargs)

@@ -42,6 +46,7 @@ def init_rpc_connection(self,
                             world_size: int) -> None:
         os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
         rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
+        rpc._set_rpc_timeout(self.rpc_timeout_sec)
         self.rpc_initialized = True

     def rpc_save_model(self,
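With the new constructor argument, the RPC timeout can be tuned when the plugin is created rather than being fixed to the torch default. A hedged sketch of how this might be wired into a `Trainer` (passing plugins via the `plugins` argument is assumed from the `DDPPlugin` base; real sequential-model setups typically use an `RPCPlugin` subclass):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.rpc_plugin import RPCPlugin

    # raise the RPC timeout from the 60 s fallback to 5 minutes
    plugin = RPCPlugin(rpc_timeout_sec=300.0)

    trainer = Trainer(gpus=2, accelerator="ddp", plugins=[plugin])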

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 1 deletion

@@ -21,6 +21,7 @@

 import pytorch_lightning
 from pytorch_lightning import _logger as log
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import AMPType, APEX_AVAILABLE, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -63,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None:
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')

         # 2. Attempt to restore states from `resume_from_checkpoint` file
-        elif self.trainer.resume_from_checkpoint is not None:
+        elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
             self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

         # wait for all to catch up
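The added guard means a `resume_from_checkpoint` path is only restored when fitting; when the same `Trainer` runs in testing mode the restore step is skipped. A hedged sketch of the scenario this addresses (`model`, `train_dataloader`, and `test_dataloader` are placeholders for any LightningModule and DataLoaders, not part of the commit):

    from pytorch_lightning import Trainer

    # resume_from_checkpoint is honoured by .fit(), but after this change
    # it is no longer re-applied while the trainer is in testing mode
    trainer = Trainer(resume_from_checkpoint="last.ckpt", max_epochs=10)
    trainer.fit(model, train_dataloader)                     # restores training state from last.ckpt
    trainer.test(model, test_dataloaders=test_dataloader)    # skips the resume path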

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions

@@ -311,7 +311,6 @@ def __init__(
         self.plugin_connector = PluginConnector(self)

         # training state
-        self.weights_summary = weights_summary
         self.model = None
         self.shown_warnings = set()

@@ -374,7 +373,8 @@ def __init__(
             max_steps,
             min_steps,
             num_sanity_val_steps,
-            automatic_optimization
+            automatic_optimization,
+            weights_summary,
         )
         self.evaluation_loop.on_trainer_init()

pytorch_lightning/trainer/training_loop.py

Lines changed: 17 additions & 7 deletions

@@ -49,7 +49,14 @@ def __init__(self, trainer):
         self._cur_grad_norm_dict = None

     def on_trainer_init(
-        self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization
+        self,
+        max_epochs,
+        min_epochs,
+        max_steps,
+        min_steps,
+        num_sanity_val_steps,
+        automatic_optimization,
+        weights_summary,
     ):
         self.trainer.global_step = 0
         self.trainer.current_epoch = 0
@@ -73,6 +80,12 @@ def on_trainer_init(
         else:
             self.trainer.num_sanity_val_steps = num_sanity_val_steps

+        self.trainer.weights_summary = weights_summary
+        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
+            raise MisconfigurationException(
+                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
+            )
+
     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
@@ -161,17 +174,14 @@ def setup_training(self, model: LightningModule):
         ref_model.on_pretrain_routine_start()

         # print model summary
-        if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
-            if self.trainer.weights_summary in ModelSummary.MODES:
-                ref_model.summarize(mode=self.trainer.weights_summary)
-            else:
-                raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
+        if self.trainer.is_global_zero and not self.trainer.testing:
+            ref_model.summarize(mode=self.trainer.weights_summary)

         # track model now.
         # if cluster resets state, the model will update with the saved weights
         self.trainer.model = model

-        # restore training and model before hpc is called
+        # restore training state and model weights before hpc is called
         self.trainer.checkpoint_connector.restore_weights(model)

         # on pretrain routine end
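Together with the `trainer.py` change above, the `weights_summary` value is now validated once at `Trainer` construction instead of at the start of training, and the summary call itself relies on `summarize` handling `None`. A hedged usage sketch (assuming `ModelSummary.MODES` is the usual "full"/"top" pair):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    Trainer(weights_summary="top")     # fine: summary of top-level modules
    Trainer(weights_summary=None)      # fine: no summary is printed
    try:
        Trainer(weights_summary="typo")  # now rejected at construction time
    except MisconfigurationException as err:
        print(err)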
