Commit be44f3c

Merge branch 'master' into feat-sync_step
2 parents: e4b401e + ec0fb7a

File tree: 21 files changed, +203 -100 lines

.drone.yml

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ steps:
     MKL_THREADING_LAYER: GNU

   commands:
+    - set -e
     - python --version
     - pip --version
     - nvidia-smi
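Note: `set -e` makes the shell abort the step at the first failing command, so a failure in an early command (for example a broken `pip` install) can no longer be masked by later commands that happen to succeed.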

.github/workflows/ci_test-base.yml

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-name: CI base testing
+name: CI basic testing

 # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
 on:  # Trigger the workflow on push or pull request, but only for the master branch

CHANGELOG.md

Lines changed: 3 additions & 2 deletions

@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-## [1.1.3rc] - 2020-12-29
+## [1.1.3] - 2021-01-05

 ### Added

@@ -25,12 +25,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Skip restore from `resume_from_checkpoint` while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161))
+
 - Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))

 - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))


-
 ## [1.1.2] - 2020-12-23

 ### Added

docs/source/transfer_learning.rst

Lines changed: 11 additions & 5 deletions

@@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)

     class ImagenetTransferLearning(LightningModule):
         def __init__(self):
+            super().__init__()
+
             # init a pretrained resnet
-            num_target_classes = 10
-            self.feature_extractor = models.resnet50(pretrained=True)
-            self.feature_extractor.eval()
+            backbone = models.resnet50(pretrained=True)
+            num_filters = backbone.fc.in_features
+            layers = list(backbone.children())[:-1]
+            self.feature_extractor = torch.nn.Sequential(*layers)

             # use the pretrained model to classify cifar-10 (10 image classes)
-            self.classifier = nn.Linear(2048, num_target_classes)
+            num_target_classes = 10
+            self.classifier = nn.Linear(num_filters, num_target_classes)

         def forward(self, x):
-            representations = self.feature_extractor(x)
+            self.feature_extractor.eval()
+            with torch.no_grad():
+                representations = self.feature_extractor(x).flatten(1)
             x = self.classifier(representations)
             ...
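The reworked docs example derives the classifier width from the backbone instead of hard-coding 2048, and freezes the extractor at inference time. A minimal standalone sketch of that surgery, assuming torchvision is installed (`pretrained=True` downloads the ResNet-50 weights on first use):

import torch
from torchvision import models

# same surgery as in the updated docs example
backbone = models.resnet50(pretrained=True)
num_filters = backbone.fc.in_features              # 2048 for resnet50
layers = list(backbone.children())[:-1]            # drop the final fc layer
feature_extractor = torch.nn.Sequential(*layers)

# frozen feature extraction, as in the new forward()
feature_extractor.eval()
with torch.no_grad():
    x = torch.randn(4, 3, 224, 224)                # dummy image batch
    representations = feature_extractor(x).flatten(1)

assert representations.shape == (4, num_filters)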

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions

@@ -208,6 +208,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
+            "dirpath": self.dirpath
         }

     def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
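Persisting `dirpath` lets a resumed run tell whether the checkpoint directory it restores from is the one it will keep writing to. A rough sketch of the round-trip this enables; the class below is an illustrative stand-in, not the real `ModelCheckpoint`:

from typing import Any, Dict

class CheckpointStateDemo:
    """Illustrative stand-in for ModelCheckpoint's state round-trip."""

    def __init__(self, dirpath: str = "checkpoints/"):
        self.dirpath = dirpath
        self.best_model_score = None
        self.best_model_path = ""
        self.current_score = None

    def on_save_checkpoint(self) -> Dict[str, Any]:
        # dirpath now travels with the rest of the callback state
        return {
            "best_model_score": self.best_model_score,
            "best_model_path": self.best_model_path,
            "current_score": self.current_score,
            "dirpath": self.dirpath,
        }

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        if checkpointed_state.get("dirpath") != self.dirpath:
            print("checkpoint was written to a different dirpath")
        self.best_model_score = checkpointed_state["best_model_score"]
        self.best_model_path = checkpointed_state["best_model_path"]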

pytorch_lightning/loggers/__init__.py

Lines changed: 12 additions & 27 deletions

@@ -24,40 +24,25 @@
     'CSVLogger',
 ]

-try:
-    # needed to prevent ImportError and duplicated logs.
-    environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
+from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger
+from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger
+from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE, NeptuneLogger
+from pytorch_lightning.loggers.test_tube import _TESTTUBE_AVAILABLE, TestTubeLogger
+from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger

-    from pytorch_lightning.loggers.comet import CometLogger
-except ImportError:  # pragma: no-cover
-    del environ["COMET_DISABLE_AUTO_LOGGING"]  # pragma: no-cover
-else:
+if _COMET_AVAILABLE:
     __all__.append('CometLogger')
+    # needed to prevent ImportError and duplicated logs.
+    environ["COMET_DISABLE_AUTO_LOGGING"] = "1"

-try:
-    from pytorch_lightning.loggers.mlflow import MLFlowLogger
-except ImportError:  # pragma: no-cover
-    pass  # pragma: no-cover
-else:
+if _MLFLOW_AVAILABLE:
     __all__.append('MLFlowLogger')

-try:
-    from pytorch_lightning.loggers.neptune import NeptuneLogger
-except ImportError:  # pragma: no-cover
-    pass  # pragma: no-cover
-else:
+if _NEPTUNE_AVAILABLE:
     __all__.append('NeptuneLogger')

-try:
-    from pytorch_lightning.loggers.test_tube import TestTubeLogger
-except ImportError:  # pragma: no-cover
-    pass  # pragma: no-cover
-else:
+if _TESTTUBE_AVAILABLE:
     __all__.append('TestTubeLogger')

-try:
-    from pytorch_lightning.loggers.wandb import WandbLogger
-except ImportError:  # pragma: no-cover
-    pass  # pragma: no-cover
-else:
+if _WANDB_AVAILABLE:
     __all__.append('WandbLogger')
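Each optional logger module now computes a `_*_AVAILABLE` flag once at import time, and the package `__init__` consults the flags instead of wrapping every import in try/except. A rough sketch of what a helper like `_module_available` can look like; the real implementation lives in `pytorch_lightning.utilities` and may differ in detail:

import importlib.util

def _module_available(module_path: str) -> bool:
    """Best-effort check whether a (possibly dotted) module can be imported."""
    try:
        return importlib.util.find_spec(module_path) is not None
    except (ImportError, AttributeError):
        # a broken or missing parent package can make find_spec itself raise
        return False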

pytorch_lightning/loggers/comet.py

Lines changed: 16 additions & 18 deletions

@@ -21,17 +21,18 @@
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union

-try:
-    import comet_ml
+import torch
+from torch import is_tensor

-except ModuleNotFoundError:  # pragma: no-cover
-    comet_ml = None
-    CometExperiment = None
-    CometExistingExperiment = None
-    CometOfflineExperiment = None
-    API = None
-    generate_guid = None
-else:
+from pytorch_lightning import _logger as log
+from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
+from pytorch_lightning.utilities import rank_zero_only, _module_available
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+_COMET_AVAILABLE = _module_available("comet_ml")
+
+if _COMET_AVAILABLE:
+    import comet_ml
     from comet_ml import ExistingExperiment as CometExistingExperiment
     from comet_ml import Experiment as CometExperiment
     from comet_ml import OfflineExperiment as CometOfflineExperiment

@@ -41,14 +42,11 @@
     except ImportError:  # pragma: no-cover
         # For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
         from comet_ml.papi import API  # pragma: no-cover
-
-import torch
-from torch import is_tensor
-
-from pytorch_lightning import _logger as log
-from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import rank_zero_only
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+else:
+    # needed for test mocks, these tests shall be updated
+    comet_ml = None
+    CometExperiment, CometExistingExperiment, CometOfflineExperiment = None, None, None
+    API = None


 class CometLogger(LightningLoggerBase):
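With the flag exported, downstream code can test availability explicitly instead of catching ImportError. A hedged usage sketch (the constructor arguments are illustrative):

from pytorch_lightning.loggers.comet import _COMET_AVAILABLE, CometLogger

if _COMET_AVAILABLE:
    logger = CometLogger(save_dir="comet_logs/")  # illustrative: offline mode, no API key
else:
    logger = None  # let the Trainer fall back to its default logger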

pytorch_lightning/loggers/mlflow.py

Lines changed: 12 additions & 8 deletions

@@ -21,21 +21,25 @@
 from time import time
 from typing import Any, Dict, Optional, Union

-try:
-    import mlflow
-    from mlflow.tracking import MlflowClient
-except ModuleNotFoundError:  # pragma: no-cover
-    mlflow = None
-    MlflowClient = None
-

 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, _module_available
+

 LOCAL_FILE_URI_PREFIX = "file:"


+_MLFLOW_AVAILABLE = _module_available("mlflow")
+try:
+    import mlflow
+    from mlflow.tracking import MlflowClient
+# todo: there seems to be still some remaining import error with Conda env
+except ImportError:
+    _MLFLOW_AVAILABLE = False
+    mlflow, MlflowClient = None, None
+
+
 class MLFlowLogger(LightningLoggerBase):
     """
     Log using `MLflow <https://mlflow.org>`_.
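Unlike the other loggers, mlflow keeps a try/except on top of the `_module_available` flag: as the inline todo notes, the module can be reported importable yet still fail to import under some Conda environments, so the flag is forced to False when the import itself raises.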

pytorch_lightning/loggers/neptune.py

Lines changed: 11 additions & 9 deletions

@@ -17,21 +17,23 @@
 --------------
 """
 from argparse import Namespace
-from typing import Any, Dict, Iterable, List, Optional, Union
-
-try:
-    import neptune
-    from neptune.experiments import Experiment
-except ImportError:  # pragma: no-cover
-    neptune = None
-    Experiment = None
+from typing import Any, Dict, Iterable, Optional, Union

 import torch
 from torch import is_tensor

 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_only, _module_available
+
+_NEPTUNE_AVAILABLE = _module_available("neptune")
+
+if _NEPTUNE_AVAILABLE:
+    import neptune
+    from neptune.experiments import Experiment
+else:
+    # needed for test mocks, these tests shall be updated
+    neptune, Experiment = None, None


 class NeptuneLogger(LightningLoggerBase):
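neptune.py follows the same pattern as comet.py: a `_NEPTUNE_AVAILABLE` flag guards the import, the unused `List` is dropped from the typing import, and the None fallbacks remain only so existing test mocks keep working.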

pytorch_lightning/loggers/tensorboard.py

Lines changed: 1 addition & 1 deletion

@@ -217,7 +217,7 @@ def save(self) -> None:
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

         # save the metatags file if it doesn't exist
-        if not os.path.isfile(hparams_file):
+        if not self._fs.isfile(hparams_file):
             save_hparams_to_yaml(hparams_file, self.hparams)

     @rank_zero_only
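`self._fs` is the logger's fsspec filesystem handle, so the existence check now runs against wherever the logs actually live; `os.path.isfile` only ever inspects the local disk and would report a remote `hparams.yaml` as missing. A small sketch of the difference, using fsspec's in-memory filesystem as a stand-in for a remote store such as `s3://`:

import os
import fsspec

# "memory://" stands in for a remote filesystem such as s3:// or gs://
fs = fsspec.filesystem("memory")
fs.pipe("/exp/hparams.yaml", b"lr: 0.01\n")   # write a small file into it

print(os.path.isfile("/exp/hparams.yaml"))    # False: os only checks the local disk
print(fs.isfile("/exp/hparams.yaml"))         # True: asks the owning filesystem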
