`__
@@ -103,8 +107,9 @@ def __init__(
self._log_model = log_model
self._prefix = prefix
self._kwargs = kwargs
- # logging multiple Trainer on a single W&B run (k-fold, etc)
+ # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
self._step_offset = 0
+ self.warning_cache = WarningCache()
def __getstate__(self):
state = self.__dict__.copy()
@@ -134,6 +139,8 @@ def experiment(self) -> Run:
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
+ # offset logging step when resuming a run
+ self._step_offset = self._experiment.step
# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
self._save_dir = self._experiment.dir
@@ -154,6 +161,10 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
metrics = self._add_prefix(metrics)
+ if step is not None and step + self._step_offset < self.experiment.step:
+ self.warning_cache.warn(
+ 'Trying to log at a previous step. Use `commit=False` when logging metrics manually.'
+ )
self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
@property
diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py
index 26a607a2955b8..07f5545df8a54 100644
--- a/pytorch_lightning/setup_tools.py
+++ b/pytorch_lightning/setup_tools.py
@@ -14,12 +14,12 @@
# limitations under the License.
import os
import re
-import warnings
from typing import Iterable, List
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
+import warnings
-from pytorch_lightning import PROJECT_ROOT, __homepage__, __version__
+from pytorch_lightning import __homepage__, __version__, _PROJECT_ROOT
_PATH_BADGES = os.path.join('.', 'docs', 'source', '_images', 'badges')
# badge to download
@@ -37,7 +37,7 @@
def _load_requirements(path_dir: str , file_name: str = 'requirements.txt', comment_char: str = '#') -> List[str]:
"""Load requirements from a file
- >>> _load_requirements(PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['numpy...', 'torch...', ...]
"""
with open(os.path.join(path_dir, file_name), 'r') as file:
@@ -155,7 +155,7 @@ def _download_badge(url_badge: str, badge_name: str, target_dir: str) -> str:
def _load_long_description(path_dir: str) -> str:
"""Load readme as decribtion
- >>> _load_long_description(PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ >>> _load_long_description(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...'
"""
path_readme = os.path.join(path_dir, "README.md")
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 28025859814cc..6d206f3dd929e 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -91,11 +91,13 @@ def check_dataloader_idx(self, result: Result) -> bool:
random_key = list(result.keys())[-1]
return result["meta"][random_key]["dataloader_idx"] is not None
- def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
+ def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
results = {}
- add_dataloader_idx = self.check_dataloader_idx(latest_result)
- func = getattr(latest_result, func_name)
- results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
+ for opt_idx in latest_result_opt:
+ latest_result = latest_result_opt[opt_idx]
+ add_dataloader_idx = self.check_dataloader_idx(latest_result)
+ func = getattr(latest_result, func_name)
+ results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
return results
def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
@@ -156,6 +158,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
assert isinstance(result, Result)
if dataloader_idx is None:
dataloader_idx = 0
+
if extra_info is None:
extra_info = {}
@@ -166,6 +169,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
if dataloader_idx not in self._internals:
self._internals[dataloader_idx] = {}
self._internals_reduced[dataloader_idx] = defaultdict(dict)
+ self._latest_ref[dataloader_idx] = {}
# extract infos
opt_idx = extra_info["opt_idx"]
@@ -173,7 +177,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)
- self._latest_ref[dataloader_idx] = result
+ self._latest_ref[dataloader_idx][opt_idx] = result
# [dataloader_idx] is a list
else:
@@ -181,7 +185,11 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
self._internals.setdefault(dataloader_idx, [])
self._internals[dataloader_idx].append(result)
- self._latest_ref[dataloader_idx] = result
+ if dataloader_idx not in self._latest_ref:
+ self._latest_ref[dataloader_idx] = {}
+ self._latest_ref[dataloader_idx][0] = {}
+
+ self._latest_ref[dataloader_idx][0] = result
def auto_reduce_results_on_epoch_end(self) -> None:
"""
@@ -206,13 +214,9 @@ def auto_reduce_results_on_epoch_end(self) -> None:
# TODO: How to start training in middle of epoch
opt_outputs = epoch_metrics[opt_idx]
- num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
- assert num_batch_idx >= 0
- batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
-
# reduce across time first
time_reduced_outputs = []
- for batch_idx in batch_indexes:
+ for batch_idx in opt_outputs.keys():
tbptt_outs = opt_outputs[batch_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
if len(tbptt_outs) > 1:
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index db51fb8014de0..04fa3f4cc842b 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -56,7 +56,7 @@ def __init__(self, window_length: int):
def reset(self) -> None:
"""Empty the accumulator."""
- self = TensorRunningAccum(self.window_length)
+ self.__init__(self.window_length)
def last(self):
"""Get the last added element."""
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 014e0a62679dd..06cdc43674d1b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -133,7 +133,7 @@ def __init__(
distributed_backend: Optional[str] = None,
automatic_optimization: Optional[bool] = None,
move_metrics_to_cpu: bool = False,
- enable_pl_optimizer: bool = True,
+ enable_pl_optimizer: bool = False,
multiple_trainloader_mode: str = 'max_size_cycle',
):
r"""
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9724f05247c00..be5d781939c04 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -15,14 +15,14 @@
import os
import warnings
from functools import wraps
+from typing import Any, Optional, Union
import torch
+
from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any
if torch.distributed.is_available():
- from torch.distributed import ReduceOp
- from torch.distributed import group
+ from torch.distributed import ReduceOp, group
else:
class ReduceOp:
SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
if group is None:
group = torch.distributed.group.WORLD
- if reduce_op is None:
- reduce_op = torch.distributed.ReduceOp.SUM
- elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
- reduce_op = torch.distributed.ReduceOp.SUM
+ op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+ if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
divide_by_world_size = True
# sync all processes before reduction
torch.distributed.barrier(group=group)
- torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+ torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)
@@ -207,6 +206,6 @@ def all_gather_ddp_if_available(
if sync_grads:
return AllGatherGrad.apply(tensor, group)
else:
- with torch.no_grad:
+ with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor
diff --git a/requirements/examples.txt b/requirements/examples.txt
index 6e48778cb222a..c87d10a39346f 100644
--- a/requirements/examples.txt
+++ b/requirements/examples.txt
@@ -1,2 +1,2 @@
torchvision>=0.4.1
-gym>=0.17.0
+gym>=0.17.0
\ No newline at end of file
diff --git a/requirements/test.txt b/requirements/test.txt
index 3cb538a98d7c8..632f40e0287b4 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -17,3 +17,4 @@ pre-commit>=1.0
cloudpickle>=1.3
nltk>=3.3
+pandas # needed in benchmarks
diff --git a/setup.cfg b/setup.cfg
index 4475fb11266d0..7b685fb8dc0e5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -99,6 +99,10 @@ max-line-length = 120
files = pytorch_lightning, pl_examples, benchmarks, tests
disallow_untyped_defs = True
ignore_missing_imports = True
+show_error_codes = True
+warn_redundant_casts = True
+warn_unused_configs = True
+warn_unused_ignores = True
# todo: add proper typing to this module...
[mypy-pytorch_lightning.callbacks.*]
diff --git a/tests/__init__.py b/tests/__init__.py
index 981d685430da9..e0ec83a2efbca 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,18 +1,31 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
import numpy as np
-TEST_ROOT = os.path.dirname(__file__)
-PROJECT_ROOT = os.path.dirname(TEST_ROOT)
-TEMP_PATH = os.path.join(PROJECT_ROOT, 'test_temp')
+_TEST_ROOT = os.path.dirname(__file__)
+_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
+_TEMP_PATH = os.path.join(_PROJECT_ROOT, 'test_temp')
# todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages
-if PROJECT_ROOT not in os.getenv('PYTHONPATH', ""):
+if _PROJECT_ROOT not in os.getenv('PYTHONPATH', ""):
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
- os.environ['PYTHONPATH'] = f'{PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
+ os.environ['PYTHONPATH'] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
# generate a list of random seeds for each test
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
-if not os.path.isdir(TEMP_PATH):
- os.mkdir(TEMP_PATH)
+if not os.path.isdir(_TEMP_PATH):
+ os.mkdir(_TEMP_PATH)
diff --git a/tests/base/datasets.py b/tests/base/datasets.py
index 854d69b54eaf8..33d3801c432ab 100644
--- a/tests/base/datasets.py
+++ b/tests/base/datasets.py
@@ -22,10 +22,10 @@
from torch import Tensor
from torch.utils.data import Dataset
-from tests import PROJECT_ROOT
+from tests import _PROJECT_ROOT
#: local path to test datasets
-PATH_DATASETS = os.path.join(PROJECT_ROOT, 'Datasets')
+PATH_DATASETS = os.path.join(_PROJECT_ROOT, 'Datasets')
class MNIST(Dataset):
@@ -63,8 +63,13 @@ class MNIST(Dataset):
TEST_FILE_NAME = 'test.pt'
cache_folder_name = 'complete'
- def __init__(self, root: str = PATH_DATASETS, train: bool = True,
- normalize: tuple = (0.5, 1.0), download: bool = True):
+ def __init__(
+ self,
+ root: str = PATH_DATASETS,
+ train: bool = True,
+ normalize: tuple = (0.5, 1.0),
+ download: bool = True,
+ ):
super().__init__()
self.root = root
self.train = train # training set or test set
diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py
index 3db8eb022288a..6eb19d3c4b1e4 100644
--- a/tests/base/develop_utils.py
+++ b/tests/base/develop_utils.py
@@ -19,7 +19,7 @@
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
-from tests import TEMP_PATH, RANDOM_PORTS
+from tests import _TEMP_PATH, RANDOM_PORTS
from tests.base.model_template import EvalModelTemplate
@@ -63,7 +63,7 @@ def get_data_path(expt_logger, path_dir=None):
if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir:
path_dir = expt_logger.save_dir
else:
- path_dir = TEMP_PATH
+ path_dir = _TEMP_PATH
path_expt = os.path.join(path_dir, name, 'version_%s' % version)
# try if the new sub-folder exists, typical case for test-tube
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index c00c712bb3b13..070bb4e9f6989 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -33,6 +33,8 @@ def test_trainer_callback_system(torch_save):
limit_train_batches=3,
limit_test_batches=2,
progress_bar_refresh_rate=0,
+ # todo: enabled since internally we wrap the model for optimizer step, this should be fixed
+ enable_pl_optimizer=True
)
# no call yet
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 27f484c63d87c..1f3e44f58173e 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -905,3 +905,42 @@ def __init__(self, hparams):
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
+
+
+@pytest.mark.parametrize('max_epochs', [3, 4])
+@pytest.mark.parametrize(
+ 'save_top_k, expected',
+ [
+ (1, ['curr_epoch.ckpt']),
+ (2, ['curr_epoch.ckpt', 'curr_epoch-v0.ckpt']),
+ ]
+)
+def test_model_checkpoint_file_already_exists(tmpdir, max_epochs, save_top_k, expected):
+ """
+ Test that version is added to filename if required and it already exists in dirpath.
+ """
+ model_checkpoint = ModelCheckpoint(
+ dirpath=tmpdir,
+ filename='curr_epoch',
+ save_top_k=save_top_k,
+ monitor='epoch',
+ mode='max',
+ )
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ callbacks=[model_checkpoint],
+ max_epochs=max_epochs,
+ limit_train_batches=2,
+ limit_val_batches=2,
+ logger=None,
+ weights_summary=None,
+ progress_bar_refresh_rate=0,
+ )
+
+ model = BoringModel()
+ trainer.fit(model)
+ ckpt_files = os.listdir(tmpdir)
+ assert set(ckpt_files) == set(expected)
+
+ epochs_in_ckpt_files = [pl_load(os.path.join(tmpdir, f))['epoch'] - 1 for f in ckpt_files]
+ assert sorted(epochs_in_ckpt_files) == list(range(max_epochs - save_top_k, max_epochs))
diff --git a/tests/collect_env_details.py b/tests/collect_env_details.py
index 1d443795d2876..2b8c4b3fafeed 100644
--- a/tests/collect_env_details.py
+++ b/tests/collect_env_details.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Diagnose your system and show basic information
This server mainly to get detail info for better bug reporting.
diff --git a/tests/conftest.py b/tests/conftest.py
index ad4b7169456a8..c6a14a99b2478 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,21 @@
-import sys
-import threading
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from functools import partial, wraps
from http.server import SimpleHTTPRequestHandler
+import sys
+import threading
import pytest
import torch.multiprocessing as mp
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index ea40814b18861..0aec3b22f74a9 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -74,7 +74,9 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
with mock.patch('pytorch_lightning.loggers.test_tube.Experiment'):
_test_loggers_fit_test(tmpdir, TestTubeLogger)
- with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+ with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
+ wandb.run = None
+ wandb.init().step = 0
_test_loggers_fit_test(tmpdir, WandbLogger)
@@ -366,7 +368,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
logger.experiment.log.assert_called_once_with({"tmp-test": 1.0}, global_step=0)
# WandB
- with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
+ with mock.patch('pytorch_lightning.loggers.wandb.wandb') as wandb:
logger = _instantiate_logger(WandbLogger, save_idr=tmpdir, prefix=prefix)
+ wandb.run = None
+ wandb.init().step = 0
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log.assert_called_once_with({'tmp-test': 1.0}, step=0)
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index fa503f5d8eeb1..398ee45ef4aa0 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -22,8 +22,14 @@
from tests.base import EvalModelTemplate
+def get_warnings(recwarn):
+ warnings_text = '\n'.join(str(w.message) for w in recwarn.list)
+ recwarn.clear()
+ return warnings_text
+
+
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, recwarn):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
@@ -34,6 +40,9 @@ def test_wandb_logger_init(wandb):
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
+ # mock wandb step
+ wandb.init().step = 0
+
# test wandb.init not called if there is a W&B run
wandb.init().log.reset_mock()
wandb.init.reset_mock()
@@ -49,15 +58,28 @@ def test_wandb_logger_init(wandb):
logger.log_metrics({'acc': 1.0}, step=3)
wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
+ # log hyper parameters
logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
wandb.init().config.update.assert_called_once_with(
{'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},
allow_val_change=True,
)
+ # watch a model
logger.watch('model', 'log', 10)
wandb.init().watch.assert_called_once_with('model', log='log', log_freq=10)
+ # verify warning for logging at a previous step
+ assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+ # current step from wandb should be 6 (last logged step)
+ logger.experiment.step = 6
+ # logging at step 2 should raise a warning (step_offset is still 3)
+ logger.log_metrics({'acc': 1.0}, step=2)
+ assert 'Trying to log at a previous step' in get_warnings(recwarn)
+ # logging again at step 2 should not display again the same warning
+ logger.log_metrics({'acc': 1.0}, step=2)
+ assert 'Trying to log at a previous step' not in get_warnings(recwarn)
+
assert logger.name == wandb.init().project_name()
assert logger.version == wandb.init().id
@@ -71,6 +93,7 @@ def test_wandb_pickle(wandb, tmpdir):
class Experiment:
""" """
id = 'the_id'
+ step = 0
def project_name(self):
return 'the_project_name'
@@ -108,8 +131,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
assert logger.name is None
# mock return values of experiment
+ wandb.run = None
+ wandb.init().step = 0
logger.experiment.id = '1'
logger.experiment.project_name.return_value = 'project'
+ logger.experiment.step = 0
for _ in range(2):
_ = logger.experiment
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index cd61da7c008bc..169552ce1bd75 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -47,7 +47,7 @@ def test_multi_gpu_none_backend(tmpdir):
tpipes.run_model_test(trainer_options, model)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize('gpus', [1, [0], [1]])
def test_single_gpu_model(tmpdir, gpus):
"""Make sure single GPU works (DP mode)."""
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 3a2ae8750443f..2f11c7df5f26f 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -44,9 +44,9 @@
from horovod.common.util import nccl_built
nccl_built()
except (ImportError, ModuleNotFoundError, AttributeError):
- HOROVOD_NCCL_AVAILABLE = False
+ _HOROVOD_NCCL_AVAILABLE = False
finally:
- HOROVOD_NCCL_AVAILABLE = True
+ _HOROVOD_NCCL_AVAILABLE = True
def _run_horovod(trainer_options, on_gpu=False):
@@ -105,7 +105,7 @@ def test_horovod_cpu_implicit(enable_pl_optimizer, tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_horovod_multi_gpu(tmpdir):
"""Test Horovod with multi-GPU support."""
@@ -125,7 +125,7 @@ def test_horovod_multi_gpu(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_horovod_apex(tmpdir):
@@ -149,7 +149,7 @@ def test_horovod_apex(tmpdir):
@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
def test_horovod_amp(tmpdir):
@@ -172,7 +172,7 @@ def test_horovod_amp(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
-@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_horovod_transfer_batch_to_gpu(tmpdir):
class TestTrainingStepModel(EvalModelTemplate):
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index f7cb581951783..950e3776bbc7f 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 3bce379c1115c..91a8631a73287 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -1,6 +1,20 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import os
-import time
from pathlib import Path
+import time
import numpy as np
import pytest
diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
index d5a985489a909..f418db2bd72a5 100644
--- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -26,8 +26,8 @@
from torch.utils.data import Dataset
import pytorch_lightning as pl
-from pytorch_lightning import Trainer, callbacks
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import callbacks, Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.deterministic_model import DeterministicModel
@@ -687,6 +687,7 @@ class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
acc = self.step(batch[0])
self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+ self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
return acc
def validation_step(self, batch, batch_idx):
@@ -706,9 +707,46 @@ def validation_step(self, batch, batch_idx):
trainer.fit(model)
assert trainer.logged_metrics['foo'] == fake_result
+ assert trainer.logged_metrics['foo_2'] == 2
assert trainer.logged_metrics['bar'] == fake_result
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+ reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
+ """
+ Tests to ensure that the sync_dist flag works with ddp
+ """
+ class TestLoggingSyncDistModel(BoringModel):
+ def training_step(self, batch, batch_idx):
+ acc = self.step(batch[0])
+ self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+ return acc
+
+ def validation_step(self, batch, batch_idx):
+ self.training_step_called = True
+ output = self.layer(batch)
+ loss = self.loss(batch, output)
+ self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
+ return {"x": loss}
+
+ model = TestLoggingSyncDistModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=2,
+ weights_summary=None,
+ accelerator="ddp",
+ gpus=2,
+ )
+ trainer.fit(model)
+
+ assert trainer.logged_metrics['foo'] == 2
+ assert trainer.logged_metrics['bar'] == 2
+
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_logging_sync_dist_true_gpu(tmpdir):
"""
@@ -818,3 +856,47 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
'on_epoch_end': 5,
'on_train_epoch_end': 6}
assert trainer.callback_metrics == expected
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_metric_are_properly_reduced(tmpdir):
+ class TestingModel(BoringModel):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.val_acc = pl.metrics.Accuracy()
+
+ def training_step(self, batch, batch_idx):
+ output = super().training_step(batch, batch_idx)
+ self.log("train_loss", output["loss"])
+ return output
+
+ def validation_step(self, batch, batch_idx):
+ preds = torch.tensor([[0.9, 0.1]], device=self.device)
+ targets = torch.tensor([1], device=self.device)
+ if batch_idx < 8:
+ preds = torch.tensor([[0.1, 0.9]], device=self.device)
+ self.val_acc(preds, targets)
+ self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
+ return super().validation_step(batch, batch_idx)
+
+ early_stop = EarlyStopping(monitor='val_acc', mode='max')
+
+ checkpoint = ModelCheckpoint(
+ monitor='val_acc',
+ save_last=True,
+ save_top_k=2,
+ mode='max',
+ )
+
+ model = TestingModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ gpus=1,
+ max_epochs=2,
+ limit_train_batches=5,
+ limit_val_batches=32,
+ callbacks=[early_stop, checkpoint])
+ trainer.fit(model)
+
+ assert trainer.callback_metrics["val_acc"] == 8 / 32.
+ assert "train_loss" in trainer.callback_metrics
diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py
new file mode 100644
index 0000000000000..78b6f8f7ff84a
--- /dev/null
+++ b/tests/trainer/optimization/test_multiple_optimizers.py
@@ -0,0 +1,63 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests to ensure that the behaviours related to multiple optimizers works
+"""
+import torch
+
+import pytorch_lightning as pl
+from tests.base.boring_model import BoringModel
+
+
+def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
+ """
+ This tests ensures reduction works in un-balanced logging settings
+ """
+ class TestModel(BoringModel):
+
+ loss_1 = []
+ loss_2 = []
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ output = self.layer(batch)
+ loss = self.loss(batch, output)
+ if optimizer_idx == 0 and self.trainer.global_step > 10:
+ self.log("loss_1", loss, on_epoch=True, prog_bar=True)
+ self.loss_1.append(loss.detach().clone())
+ elif optimizer_idx == 1:
+ self.log("loss_2", loss, on_epoch=True, prog_bar=True)
+ self.loss_2.append(loss.detach().clone())
+ return {"loss": loss}
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+ optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+ return [optimizer, optimizer2]
+
+ model = TestModel()
+ model.training_epoch_end = None
+
+ # Initialize a trainer
+ trainer = pl.Trainer(
+ default_root_dir=tmpdir,
+ max_epochs=1,
+ )
+
+ trainer.fit(model)
+
+ assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
+ assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
+ # test loss are properly reduced
+ assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
+ assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py
index 6195d7ddeb0b0..1a1203e8f2dd6 100644
--- a/tests/trainer/test_supporters.py
+++ b/tests/trainer/test_supporters.py
@@ -17,10 +17,32 @@
import torch
from torch.utils.data import TensorDataset
-from pytorch_lightning.trainer.supporters import CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator
+from pytorch_lightning.trainer.supporters import (
+ CycleIterator, CombinedLoader, CombinedDataset, CombinedLoaderIterator, TensorRunningAccum)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+def test_tensor_running_accum_reset():
+ """ Test that reset would set all attributes to the initialization state """
+
+ window_length = 10
+
+ accum = TensorRunningAccum(window_length=window_length)
+ assert accum.last() is None
+ assert accum.mean() is None
+
+ accum.append(torch.tensor(1.5))
+ assert accum.last() == torch.tensor(1.5)
+ assert accum.mean() == torch.tensor(1.5)
+
+ accum.reset()
+ assert accum.window_length == window_length
+ assert accum.memory is None
+ assert accum.current_idx == 0
+ assert accum.last_idx is None
+ assert not accum.rotated
+
+
def test_cycle_iterator():
"""Test the cycling function of `CycleIterator`"""
iterator = CycleIterator(range(100), 1000)