diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a20ee5914854..e4a2d76c1fdd0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+## [1.2.7] - 2021-04-06
+
+### Fixed
+
+- Fixed a bug with omegaconf and `xm.save` ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
+- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
+- Sanitize None params during pruning ([#6836](https://github.com/PyTorchLightning/pytorch-lightning/pull/6836))
+- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
+- Fixed a TPU hang on Colab after training completes ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
+- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
+
+
## [1.2.6] - 2021-03-30
### Changed
diff --git a/README.md b/README.md
index 9d085e2631d89..d658953cb8014 100644
--- a/README.md
+++ b/README.md
@@ -91,19 +91,6 @@ Lightning is rigorously tested across multiple GPUs, TPUs, CPUs and against major
-
- Bleeding edge build status (1.2)
-
-
-
- 
- 
- 
- 
- 
-
-
-
---
## How To Use
@@ -132,22 +119,22 @@ pip install pytorch-lightning
conda install pytorch-lightning -c conda-forge
```
- #### Install stable - future 1.1.x
+ #### Install stable 1.2.x
- the actual status of 1.1 [stable] is following:
+ the current status of 1.2 [stable] is the following:
- 
- 
- 
- 
- 
+ 
+ 
+ 
+ 
+ 
Install future release from the source
```bash
- pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.1.x --upgrade
+ pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@release/1.2.x --upgrade
```
- #### Install bleeding-edge - future 1.2
+ #### Install bleeding-edge - future 1.3
Install nightly from the source (no guarantees)
```bash
@@ -356,27 +343,27 @@ class LitAutoEncoder(pl.LightningModule):
- [MNIST on TPUs](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb)
###### Contrastive Learning
-- [BYOL](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol)
-- [CPC v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2)
-- [Moco v2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2)
-- [SIMCLR](https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)
+- [BYOL](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#byol)
+- [CPC v2](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2)
+- [Moco v2](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2)
+- [SIMCLR](https://lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#simclr)
###### NLP
- [BERT](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/04-transformers-text-classification.ipynb)
-- [GPT-2](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
+- [GPT-2](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
###### Reinforcement Learning
-- [DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dqn-models)
-- [Dueling-DQN](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dueling-dqn)
-- [Reinforce](https://pytorch-lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#reinforce)
+- [DQN](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dqn-models)
+- [Dueling-DQN](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#dueling-dqn)
+- [Reinforce](https://lightning-bolts.readthedocs.io/en/latest/reinforce_learn.html#reinforce)
###### Vision
- [GAN](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/03-basic-gan.ipynb)
###### Classic ML
-- [Logistic Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#logistic-regression)
-- [Linear Regression](https://pytorch-lightning-bolts.readthedocs.io/en/latest/classic_ml.html#linear-regression)
+- [Logistic Regression](https://lightning-bolts.readthedocs.io/en/latest/classic_ml.html#logistic-regression)
+- [Linear Regression](https://lightning-bolts.readthedocs.io/en/latest/classic_ml.html#linear-regression)
---
diff --git a/docs/source/advanced/tpu.rst b/docs/source/advanced/tpu.rst
index b9688ce425b5f..09a614f31c854 100644
--- a/docs/source/advanced/tpu.rst
+++ b/docs/source/advanced/tpu.rst
@@ -64,8 +64,7 @@ To get a TPU on colab, follow these steps:
.. code-block::
- !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
- !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
+ !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
5. Once the above is done, install PyTorch Lightning (v 0.7.0+).
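
As a quick sanity check that the wheel above installed correctly (hypothetical snippet; assumes the Colab TPU runtime is attached):

```python
import torch_xla.core.xla_model as xm

# should print an XLA device such as "xla:1" once the TPU runtime is attached
print(xm.xla_device())
```
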
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 7f0df33a351e4..da8ae5971aea6 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -907,30 +907,6 @@ use_amp
~~~~~~~
True if using Automatic Mixed Precision (AMP)
-------------
-
-use_ddp
-~~~~~~~
-True if using ddp
-
-------------
-
-use_ddp2
-~~~~~~~~
-True if using ddp2
-
-------------
-
-use_dp
-~~~~~~
-True if using dp
-
-------------
-
-use_tpu
-~~~~~~~
-True if using TPUs
-
--------------
automatic_optimization
diff --git a/docs/source/ecosystem/bolts.rst b/docs/source/ecosystem/bolts.rst
index f3a4ab9c858be..c10097fa4bd05 100644
--- a/docs/source/ecosystem/bolts.rst
+++ b/docs/source/ecosystem/bolts.rst
@@ -1,11 +1,11 @@
Bolts
=====
-`PyTorch Lightning Bolts `_, is our official collection
+`PyTorch Lightning Bolts `_, is our official collection
of prebuilt models across many research domains.
.. code-block:: bash
- pip install pytorch-lightning-bolts
+ pip install lightning-bolts
In bolts we have:
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 73691c6dd76f5..dd46e910ff541 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -71,10 +71,10 @@ Examples
--------
You can do pretty much anything with callbacks.
-- `Add a MLP to fine-tune self-supervised networks `_.
-- `Find how to modify an image input to trick the classification result `_.
-- `Interpolate the latent space of any variational model `_.
-- `Log images to Tensorboard for any model `_.
+- `Add a MLP to fine-tune self-supervised networks `_.
+- `Find how to modify an image input to trick the classification result `_.
+- `Interpolate the latent space of any variational model `_.
+- `Log images to Tensorboard for any model `_.
--------------
@@ -85,7 +85,7 @@ Lightning has a few built-in callbacks.
.. note::
For a richer collection of callbacks, check out our
- `bolts library `_.
+ `bolts library `_.
.. currentmodule:: pytorch_lightning.callbacks
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 81011cbf14724..1432badf2038f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -79,14 +79,14 @@ PyTorch Lightning Documentation
ecosystem/pytorch_ecoystem
ecosystem/community_examples
- Autoencoder
- BYOL
- DQN
- GAN
- GPT-2
- Image-GPT
- SimCLR
- VAE
+ Autoencoder
+ BYOL
+ DQN
+ GAN
+ GPT-2
+ Image-GPT
+ SimCLR
+ VAE
.. toctree::
:maxdepth: 1
diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst
index 2ee31304299e0..5625140cc12cf 100644
--- a/docs/source/starter/introduction_guide.rst
+++ b/docs/source/starter/introduction_guide.rst
@@ -572,9 +572,7 @@ Next, install the required xla library (adds support for PyTorch on TPUs)
.. code-block:: shell
- !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
-
- !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
+ !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
of this program. This means that without taking any care you will download the dataset N times which
diff --git a/notebooks/07-cifar10-baseline.ipynb b/notebooks/07-cifar10-baseline.ipynb
index 9f3209a8bbc02..8e9394d653846 100644
--- a/notebooks/07-cifar10-baseline.ipynb
+++ b/notebooks/07-cifar10-baseline.ipynb
@@ -61,7 +61,7 @@
"id": "ziAQCrE-TYWG"
},
"source": [
- "! pip install pytorch-lightning pytorch-lightning-bolts -qU"
+ "! pip install pytorch-lightning lightning-bolts -qU"
],
"execution_count": null,
"outputs": []
diff --git a/pl_examples/README.md b/pl_examples/README.md
index bed553322edf3..30a891f6b9bfc 100644
--- a/pl_examples/README.md
+++ b/pl_examples/README.md
@@ -1,6 +1,6 @@
# Examples
Our most robust examples showing all sorts of implementations
-can be found in our sister library [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2).
+can be found in our sister library [PyTorch-Lightning-Bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2).
---
@@ -15,5 +15,5 @@ In this folder we add 3 simple examples:
## Domain examples
This folder contains older examples. You should instead use the examples
-in [PyTorch-Lightning-Bolts](https://pytorch-lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
+in [PyTorch-Lightning-Bolts](https://lightning-bolts.readthedocs.io/en/latest/convolutional.html#gpt-2)
for advanced use cases.
diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py
index 6cfb6109f04fc..db59f52b103b2 100644
--- a/pl_examples/basic_examples/conv_sequential_example.py
+++ b/pl_examples/basic_examples/conv_sequential_example.py
@@ -202,7 +202,7 @@ def instantiate_datamodule(args):
if __name__ == "__main__":
cli_lightning_logo()
- assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
+ assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via `pip install lightning-bolts`"
assert _FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
parser = ArgumentParser(description="Pipe Example")
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index b25e5e06e8b86..f37e3bb31cc5e 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -77,7 +77,7 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
# When `current_epoch` is 10, feature_extractor will start training.
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
- module=pl_module.feature_extractor,
+ modules=pl_module.feature_extractor,
optimizer=optimizer,
train_bn=True,
)
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 46331e004c1c7..649243f7600ba 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -146,9 +146,10 @@ def total_val_batches(self) -> int:
validation dataloader is of infinite size.
"""
total_val_batches = 0
- if not self.trainer.disable_validation:
- is_val_epoch = (self.trainer.current_epoch) % self.trainer.check_val_every_n_epoch == 0
+ if self.trainer.enable_validation:
+ is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
total_val_batches = sum(self.trainer.num_val_batches) if is_val_epoch else 0
+
return total_val_batches
@property
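
The `+ 1` matters because `trainer.current_epoch` is zero-based while validation runs at the end of an epoch. A minimal standalone sketch of the corrected check (plain Python, no `Trainer` involved; the helper name is made up):

```python
def total_val_batches(current_epoch: int, check_val_every_n_epoch: int, num_val_batches: int) -> int:
    """Standalone sketch of the progress-bar logic: epochs are zero-based and
    validation happens at the end of an epoch, hence the ``+ 1``."""
    is_val_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    return num_val_batches if is_val_epoch else 0

# with check_val_every_n_epoch=2, validation runs after the 2nd and 4th epochs (indices 1 and 3)
assert [total_val_batches(e, 2, 5) for e in range(4)] == [0, 5, 0, 5]
```
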
diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index 3f82ab3565403..36622af0edaff 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -422,7 +422,9 @@ def sanitize_parameters_to_prune(
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
if parameters_to_prune is None:
- parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)]
+ parameters_to_prune = [
+ (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
+ ]
elif (
isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0
and all(len(p) == 2 for p in parameters_to_prune)
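
Context for the `hasattr` to `getattr(..., None)` change: a module can expose a parameter attribute that is set to `None` (for example a bias-free `nn.Linear`), so `hasattr` alone lets `None` entries through. A small illustration, not the callback itself:

```python
import torch.nn as nn

layer = nn.Linear(32, 32, bias=False)

# `bias` is still an attribute, it is just set to None, so hasattr() alone is not enough
assert hasattr(layer, "bias") and layer.bias is None

# the sanitized filter keeps only parameters that actually exist
names = ["weight", "bias"]
to_prune = [(layer, n) for n in names if getattr(layer, n, None) is not None]
assert to_prune == [(layer, "weight")]
```
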
diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/swa.py
index c8cf367cb4d5e..cc4bbd516a87c 100644
--- a/pytorch_lightning/callbacks/swa.py
+++ b/pytorch_lightning/callbacks/swa.py
@@ -189,14 +189,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
+ _scheduler_config = _get_default_scheduler_config()
+ assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
+ _scheduler_config["scheduler"] = self._swa_scheduler
if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
- trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+ trainer.lr_schedulers[0] = _scheduler_config
else:
- _scheduler_config = _get_default_scheduler_config()
- _scheduler_config["scheduler"] = self._swa_scheduler
trainer.lr_schedulers.append(_scheduler_config)
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
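
The point of swapping in a full scheduler config, rather than only the `"scheduler"` key, is that any user-supplied `interval`/`frequency` settings are reset to once-per-epoch for SWA. A rough sketch of the idea; `default_scheduler_config` below is a stand-in for `_get_default_scheduler_config` and may omit keys the real helper carries:

```python
# stand-in for _get_default_scheduler_config(); the keys that matter for the fix
# are "interval" and "frequency"
def default_scheduler_config() -> dict:
    return {"scheduler": None, "interval": "epoch", "frequency": 1}


def swap_in_swa_scheduler(lr_schedulers: list, swa_scheduler) -> None:
    """Replace the whole config entry, not just the scheduler object, so a user's
    step-interval settings cannot leak onto the SWA scheduler."""
    config = default_scheduler_config()
    config["scheduler"] = swa_scheduler
    if lr_schedulers:
        lr_schedulers[0] = config
    else:
        lr_schedulers.append(config)
```
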
diff --git a/pytorch_lightning/info.py b/pytorch_lightning/info.py
index 5b383b78e8a41..fbabb2b0bb231 100644
--- a/pytorch_lightning/info.py
+++ b/pytorch_lightning/info.py
@@ -1,7 +1,7 @@
import time
_this_year = time.strftime("%Y")
-__version__ = '1.2.6'
+__version__ = '1.2.7'
__author__ = 'William Falcon et al.'
__author_email__ = 'waf2107@columbia.edu'
__license__ = 'Apache-2.0'
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 007f898a27cc7..e6ece8c8cffb1 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -257,12 +257,12 @@ def pre_dispatch(self):
self.dist.rank = self.global_rank
self.dist.device = self.root_device
- if self.sync_batchnorm:
- self.model = self.configure_sync_batchnorm(self.model)
-
# move the model to the correct device
self.model_to_device()
+ if self.sync_batchnorm:
+ self.model = self.configure_sync_batchnorm(self.model)
+
self.configure_ddp()
self.barrier()
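
For reference, the reordering restores the conventional native-PyTorch setup sequence: move the model to its device, convert BatchNorm layers, then wrap in DDP. A generic sketch using plain `torch.distributed` primitives rather than Lightning internals (assumes the process group is already initialized):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp_model(model: nn.Module, local_rank: int) -> nn.Module:
    # assumes torch.distributed.init_process_group() has already been called
    device = torch.device(f"cuda:{local_rank}")
    model = model.to(device)                                 # 1. move to device
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)   # 2. convert BatchNorm -> SyncBatchNorm
    return DDP(model, device_ids=[local_rank])               # 3. wrap in DDP
```
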
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index fdb88a3c5cba5..dcd6443b0e6fd 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -148,12 +148,12 @@ def new_process(self, process_idx, trainer, mp_queue):
self.dist.rank = self.global_rank
self.dist.device = self.root_device
- if self.sync_batchnorm:
- self.model = self.configure_sync_batchnorm(self.model)
-
# move the model to the correct device
self.model_to_device()
+ if self.sync_batchnorm:
+ self.model = self.configure_sync_batchnorm(self.model)
+
self.configure_ddp()
self.barrier()
diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py
index 3ddfd98128787..1ef2c7676ae72 100644
--- a/pytorch_lightning/plugins/training_type/single_tpu.py
+++ b/pytorch_lightning/plugins/training_type/single_tpu.py
@@ -1,12 +1,20 @@
-import os
-from typing import Optional, Union
-
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import torch
-from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device
if _TPU_AVAILABLE:
@@ -15,14 +23,15 @@
class SingleTPUPlugin(SingleDevicePlugin):
- def __init__(self, device: Union[torch.device, int]):
- if isinstance(device, int):
- device = xm.xla_device(device)
+ def __init__(self, device: int):
+
+ device = xm.xla_device(device)
super().__init__(device)
self.tpu_local_core_rank = 0
self.tpu_global_core_rank = 0
+ @property
def on_tpu(self) -> bool:
return True
@@ -31,6 +40,10 @@ def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self._model
+ @property
+ def is_distributed(self) -> bool:
+ return False
+
def model_to_device(self) -> None:
self._model.to(self.root_device)
@@ -41,21 +54,6 @@ def pre_dispatch(self) -> None:
self.tpu_local_core_rank = xm.get_local_ordinal()
self.tpu_global_core_rank = xm.get_ordinal()
- def post_dispatch(self) -> None:
- model = self.lightning_module
-
- if on_colab_kaggle():
- rank_zero_warn("cleaning up... please do not interrupt")
- self.save_spawn_weights(model)
-
- def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
- """
- Dump a temporary checkpoint after ddp ends to get weights out of the process
- """
- path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
- model.trainer.save_checkpoint(path)
- return path
-
def on_save(self, checkpoint: dict) -> dict:
"""
Move XLA tensors to CPU before saving
@@ -63,7 +61,3 @@ def on_save(self, checkpoint: dict) -> dict:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
"""
return move_data_to_device(checkpoint, torch.device("cpu"))
-
- @property
- def is_distributed(self):
- return False
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 09603f9a22bc2..0f55100bf1ab9 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -1,40 +1,35 @@
import io
import os
import re
+import time
from typing import Any, Dict, Iterable, List, Optional, Union
import torch
import torch.multiprocessing as mp
-from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
- import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.core.xla_model import rendezvous
- from torch_xla.distributed.parallel_loader import ParallelLoader
+ from torch_xla.distributed.parallel_loader import MpDeviceLoader
else:
- xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
+ xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
+
+if _OMEGACONF_AVAILABLE:
+ from omegaconf import DictConfig, ListConfig, OmegaConf
class TPUSpawnPlugin(DDPSpawnPlugin):
- def __init__(
- self,
- parallel_devices: Optional[List[torch.device]] = None,
- num_nodes: int = 1,
- **kwargs: Dict[str, Any]
- ) -> None:
- super().__init__(
- parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
- )
+ def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
+ super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None
@@ -56,10 +51,9 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1
- def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
+ def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
device = xm.xla_device()
- dataloader = xla_pl.ParallelLoader(dataloader, [device])
- dataloader = dataloader.per_device_loader(device)
+ dataloader = MpDeviceLoader(dataloader, device)
return dataloader
def configure_ddp(self) -> None:
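
`MpDeviceLoader` is the torch_xla wrapper that feeds batches to the XLA device as they are consumed, replacing the older `ParallelLoader(...).per_device_loader(device)` pattern. A minimal usage sketch, assuming a working TPU runtime with torch_xla installed:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

device = xm.xla_device()
loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# MpDeviceLoader moves each batch to the XLA device as it is consumed
device_loader = MpDeviceLoader(loader, device)
for (batch,) in device_loader:
    assert batch.device.type == "xla"
```
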
@@ -99,16 +93,14 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
results = trainer.train_or_test_or_predict()
- self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)
+ # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")
- def __save_end_of_training_weights(self, model: LightningModule) -> None:
- # when training ends on these platforms dump weights to get out of the main process
- if on_colab_kaggle():
- rank_zero_warn("cleaning up... please do not interrupt")
- self.save_spawn_weights(model)
+ # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
+ if self.global_rank == 0:
+ time.sleep(2)
def model_to_device(self) -> None:
self._model.to(xm.xla_device())
@@ -137,16 +129,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
self.mp_queue.put(results)
def save(self, state_dict: Dict, path: str) -> None:
- """
- Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
- The rendez-vous doesn't affect directly saving.
- We can ignore the ``RuntimeError`` to reduce friction with TPUs.
- """
- try:
- xm.save(state_dict, path)
- except RuntimeError as e:
- if "Failed to meet rendezvous" not in str(e):
- raise e
+ xm.save(state_dict, path)
def broadcast(self, obj: object, src: int = 0) -> object:
buffer = io.BytesIO()
@@ -158,37 +141,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
obj = torch.load(buffer)
return obj
- def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
- """
- Load the temp weights saved in the process
- To recover the trained model from the ddp process we load the saved weights
- """
-
- loaded_model = original_model
-
- if self.is_global_zero:
- # load weights saved in ddp
- path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
- loaded_model = original_model.__class__.load_from_checkpoint(path)
-
- # copy loaded weights to old model
- original_model.load_state_dict(loaded_model.state_dict())
-
- # remove ddp weights
- os.remove(path)
-
- return loaded_model
-
- def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
- """
- Dump a temporary checkpoint after ddp ends to get weights out of the process
- """
- if model.trainer.is_global_zero:
- path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
- model.trainer.save_checkpoint(path)
- return path
-
- def reduce_decision(self, decision: bool) -> bool:
+ def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.device)
decision = self.reduce(decision, "sum")
decision = bool(decision == self.world_size)
@@ -212,40 +165,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
return output
- def post_dispatch(self) -> None:
- # TODO: Check if trainer references can be resolved otherwise
- model = self.lightning_module
-
- # restore main state with best weights
- best_path = self.mp_queue.get()
- last_path = self.mp_queue.get()
- self._results = self.mp_queue.get()
-
- # transfer back the best path to the trainer
- if self.lightning_module.trainer.checkpoint_callback is not None:
- self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
- # todo, pass also bets score
-
- # load last weights
- if last_path and not self.lightning_module.trainer.testing:
- ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
- model.load_state_dict(ckpt)
-
- self._model = model
-
- # when training completes, load the weights back in main process
- self.__load_weights_on_main_process()
-
- def __load_weights_on_main_process(self) -> None:
- model = self.lightning_module
-
- # load weights if not interrupted
- # TODO: check for trainer reference
- if on_colab_kaggle() and not model.trainer.testing:
- self.load_spawn_weights(model)
-
- self._model = model
-
def _close_logger(self, trainer) -> None:
if trainer.logger is not None:
trainer.logger.finalize("success")
@@ -284,14 +203,15 @@ def test_step(self, *args, **kwargs):
def predict(self, *args, **kwargs):
return self.lightning_module.predict(*args, **kwargs)
- def save_checkpoint(self, filepath, weights_only: bool = False):
+ def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
-
Args:
filepath: write-target file's path
weights_only: saving model weights only
"""
# dump states as a checkpoint dictionary object
- _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
+ checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
- self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+ if _OMEGACONF_AVAILABLE:
+ checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+ self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
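
This OmegaConf handling is what the changelog's `xm.save` fix (#6741) refers to: `DictConfig`/`ListConfig` values in the checkpoint are converted to plain containers before serialization. The conversion in isolation (no TPU needed; the example checkpoint dict is made up):

```python
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.utilities.apply_func import apply_to_collection

checkpoint = {
    "epoch": 3,
    "hyper_parameters": OmegaConf.create({"lr": 0.1, "layers": [32, 32]}),
}

# recursively replace OmegaConf containers with plain dicts/lists so xm.save can pickle them
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
assert isinstance(checkpoint["hyper_parameters"], dict)
```
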
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 83eddfed6c4dc..c53db011d837a 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -251,7 +251,7 @@ def use_dp(self) -> bool:
def use_ddp(self) -> bool:
return self._distrib_type in (
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
- DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
+ DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
)
@property
@@ -291,7 +291,8 @@ def parallel_devices(self) -> Union[List[torch.device], int]:
elif self.on_tpu:
# explicitly don't make a tpu device here!
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
- devices = [i for i in self.parallel_device_ids]
+ if isinstance(self.tpu_cores, int):
+ devices = list(range(self.tpu_cores))
else:
devices = [torch.device("cpu")] * self.num_processes
return devices
@@ -369,6 +370,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
+ use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -379,7 +381,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
use_torchelastic_ddp = False
- if self.on_tpu:
+ if use_tpu_spawn:
ddp_plugin_cls = TPUSpawnPlugin
elif use_ddp_sharded:
ddp_plugin_cls = DDPShardedPlugin
@@ -402,11 +404,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
- elif self.on_tpu:
- if isinstance(self.tpu_cores, list):
- plugin = SingleTPUPlugin(self.tpu_id)
- else:
- plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
+ elif self.on_tpu and isinstance(self.tpu_cores, list):
+ plugin = SingleTPUPlugin(self.tpu_id)
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# special case with TPUs
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
+ if isinstance(self.tpu_cores, int):
+ self._distrib_type = DistributedType.TPU_SPAWN
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)
@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
if self.num_gpus > 0 and not _on_cpu:
self._device_type = DeviceType.GPU
- _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+ _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
- if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
+ if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
rank_zero_warn(
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 82bb858ef6c53..e90837bf980ed 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -898,7 +898,7 @@ def test(
self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
- if test_dataloaders and datamodule:
+ if test_dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
)
@@ -1008,7 +1008,7 @@ def predict(
self._set_running_stage(RunningStage.PREDICTING, model)
- if dataloaders and datamodule:
+ if dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)
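
The `is not None` guards matter because truthiness of a `DataLoader` falls back to `len()`, which raises for an `IterableDataset` that does not define `__len__` — the same situation as the #6828 fix above. A short demonstration:

```python
import pytest
from torch.utils.data import DataLoader, IterableDataset

class NoLenStream(IterableDataset):
    def __iter__(self):
        return iter(range(8))

loader = DataLoader(NoLenStream(), batch_size=4)

# `if loader and datamodule:` would call len(loader) under the hood and blow up
with pytest.raises(TypeError):
    bool(loader)

# `loader is not None` never touches __len__
assert loader is not None
```
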
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index c5256c6ddc65f..d46db5de1ddc8 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -60,7 +60,13 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):
# Run learning rate finder:
if self.trainer.auto_lr_find:
- self.lr_find(model, update_attr=True)
+ self.lr_find(
+ model,
+ update_attr=True,
+ train_dataloader=train_dataloader,
+ val_dataloaders=val_dataloaders,
+ datamodule=datamodule,
+ )
def scale_batch_size(
self,
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index e94934020107d..c179d0d0d0bf8 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -12,15 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import io
from distutils.version import LooseVersion
from pathlib import Path
from typing import IO, Union
import fsspec
+from fsspec.implementations.local import LocalFileSystem
+
import torch
+class _LightningLocalFileSystem(LocalFileSystem):
+ """Extension of ``fsspec.implementations.local.LocalFileSystem`` where ``LightningLocalFileSystem.isdir`` behaves
+ the same as ``os.isdir``.
+
+ To be removed when https://github.com/intake/filesystem_spec/issues/591 is fixed.
+ """
+
+ def isdir(self, path: str) -> bool:
+ return os.path.isdir(path) # follows symlinks
+
+
def load(path_or_url: Union[str, IO, Path], map_location=None):
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similiar
@@ -39,7 +53,7 @@ def get_filesystem(path: Union[str, Path]):
return fsspec.filesystem(path.split(":", 1)[0])
else:
# use local filesystem
- return fsspec.filesystem("file")
+ return _LightningLocalFileSystem()
def atomic_save(checkpoint, filepath: str):
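
The override exists because `os.path.isdir` follows symlinks while the upstream `LocalFileSystem.isdir` reportedly does not (see the linked fsspec issue). A quick way to reproduce the difference locally, assuming an affected fsspec version:

```python
import os
import tempfile

from fsspec.implementations.local import LocalFileSystem

with tempfile.TemporaryDirectory() as tmp:
    real = os.path.join(tmp, "lightning_logs")
    link = os.path.join(tmp, "sym_lightning_logs")
    os.makedirs(real)
    os.symlink(real, link)

    print(os.path.isdir(link))            # True -- follows the symlink
    print(LocalFileSystem().isdir(link))  # False on affected fsspec versions, which broke TensorBoardLogger
```
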
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 3e4add4fb68d1..eb912d1dc3aae 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""Enumerated utilities"""
from enum import Enum
-from typing import Union
+from typing import List, Union
class LightningEnum(str, Enum):
@@ -58,10 +58,23 @@ class DistributedType(LightningEnum):
>>> DistributedType.DDP2 in ('ddp2', )
True
"""
+
+ @staticmethod
+ def interactive_compatible_types() -> List['DistributedType']:
+ """Returns a list containing interactive compatible DistributeTypes"""
+ return [
+ DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
+ ]
+
+ def is_interactive_compatible(self) -> bool:
+ """Returns whether self is interactive compatible"""
+ return self in DistributedType.interactive_compatible_types()
+
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
+ TPU_SPAWN = 'tpu_spawn'
DEEPSPEED = 'deepspeed'
HOROVOD = 'horovod'
DDP_SHARDED = 'ddp_sharded'
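
A hedged usage sketch of the new helper; the notebook check itself is hypothetical, only `is_interactive_compatible` and `interactive_compatible_types` come from the change above:

```python
from pytorch_lightning.utilities.enums import DistributedType

# hypothetical check for an interactive (Jupyter/Colab) session
for backend in ("dp", "ddp", "ddp_spawn", "tpu_spawn"):
    dist_type = DistributedType(backend)
    print(backend, "->", dist_type.is_interactive_compatible())
# per interactive_compatible_types(): dp, ddp_spawn, ddp_sharded_spawn and tpu_spawn qualify; plain ddp does not
```
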
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 294d3d2c5ec40..49ec176d4cdbb 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -17,13 +17,10 @@
import traceback
from multiprocessing import Process, Queue
-import torch.multiprocessing as mp
-
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
- import torch_xla.distributed.xla_multiprocessing as xmp
#: define waiting time got checking TPU available in sec
TPU_CHECK_TIMEOUT = 25
@@ -64,23 +61,13 @@ class XLADeviceUtils:
@pl_multi_process
def _is_device_tpu() -> bool:
"""
- Check if device is TPU
+ Check if TPU devices are available
Return:
- A boolean value indicating if the xla device is a TPU device or not
+ A boolean value indicating if TPU devices are available
"""
- def _fn(_: int, mp_queue):
- try:
- device = xm.xla_device()
- mp_queue.put(device.type == 'xla')
- except Exception:
- mp_queue.put(False)
-
- smp = mp.get_context("spawn")
- queue = smp.SimpleQueue()
- xmp.spawn(_fn, args=(queue, ), nprocs=1)
- return queue.get()
+ return len(xm.get_xla_supported_devices("TPU")) > 0
@staticmethod
def xla_available() -> bool:
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 484b09e27bc0d..e995a7c658101 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -36,7 +36,7 @@ def __init__(self):
self.layer = Sequential(
OrderedDict([
("mlp_1", nn.Linear(32, 32)),
- ("mlp_2", nn.Linear(32, 32)),
+ ("mlp_2", nn.Linear(32, 32, bias=False)),
("mlp_3", nn.Linear(32, 2)),
])
)
@@ -85,7 +85,10 @@ def train_with_pruning_callback(
if parameters_to_prune:
pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
else:
- pruning_kwargs["parameter_names"] = ["weight"]
+ if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
+ pruning_kwargs["parameter_names"] = ["weight"]
+ else:
+ pruning_kwargs["parameter_names"] = ["weight", "bias"]
if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
pruning_kwargs["pruning_dim"] = 0
if pruning_fn == "ln_structured":
@@ -250,14 +253,14 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
- "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501
- "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501
+ "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 500 (48.83%)", # noqa: E501
+ "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 44 (68.75%)", # noqa: E501
"Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)",
- "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 506 (49.41%) -> 633 (61.82%)", # noqa: E501
- "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 38 (59.38%) -> 47 (73.44%)", # noqa: E501
+ "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 500 (48.83%) -> 635 (62.01%)", # noqa: E501
+ "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 44 (68.75%) -> 45 (70.31%)", # noqa: E501
"Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)",
- "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501
- "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501
+ "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 635 (62.01%) -> 830 (81.05%)", # noqa: E501
+ "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 45 (70.31%) -> 54 (84.38%)", # noqa: E501
]
filepath = str(tmpdir / "foo.ckpt")
diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_swa.py
index ea8e368e39542..eb4c8f1536a22 100644
--- a/tests/callbacks/test_swa.py
+++ b/tests/callbacks/test_swa.py
@@ -24,19 +24,22 @@
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
if _TORCH_GREATER_EQUAL_1_6:
from pytorch_lightning.callbacks import StochasticWeightAveraging
+ from torch.optim.swa_utils import SWALR
class SwaTestModel(BoringModel):
- def __init__(self, batchnorm: bool = True):
+ def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
layers.append(nn.BatchNorm1d(32))
layers += [nn.ReLU(), nn.Linear(32, 2)]
self.layer = nn.Sequential(*layers)
+ self.interval = interval
def training_step(self, batch, batch_idx):
output = self.forward(batch)
@@ -46,6 +49,14 @@ def training_step(self, batch, batch_idx):
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+ return {
+ "optimizer": optimizer,
+ "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+ "interval": self.interval,
+ }
+
class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
@@ -61,6 +72,10 @@ def transfer_weights(self, *args, **kwargs):
def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+ if self.swa_start <= trainer.current_epoch:
+ assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+ assert trainer.lr_schedulers[0]["interval"] == "epoch"
+ assert trainer.lr_schedulers[0]["frequency"] == 1
def on_train_epoch_end(self, trainer, *args):
super().on_train_epoch_end(trainer, *args)
@@ -89,8 +104,8 @@ def on_train_end(self, trainer, pl_module):
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
- model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+ model = SwaTestModel(batchnorm=batchnorm, interval=interval)
swa_start = 2
max_epochs = 5
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -147,7 +162,13 @@ def test_swa_callback(tmpdir, batchnorm):
train_with_swa(tmpdir, batchnorm=batchnorm)
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: str):
+ train_with_swa(tmpdir, interval=interval)
+
+
+@RunIf(min_torch="1.6.0")
def test_swa_raises():
with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index e5e3f231d3ac7..316390e61d9f2 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -303,3 +303,22 @@ def test_tensorboard_save_hparams_to_yaml_once(tmpdir):
hparams_file = "hparams.yaml"
assert os.path.isfile(os.path.join(trainer.log_dir, hparams_file))
assert not os.path.isfile(os.path.join(tmpdir, hparams_file))
+
+
+@mock.patch('pytorch_lightning.loggers.tensorboard.log')
+def test_tensorboard_with_symlink(log, tmpdir):
+ """
+ Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
+ relative paths.
+ """
+ os.chdir(tmpdir) # need to use relative paths
+ source = os.path.join('.', 'lightning_logs')
+ dest = os.path.join('.', 'sym_lightning_logs')
+
+ os.makedirs(source, exist_ok=True)
+ os.symlink(source, dest)
+
+ logger = TensorBoardLogger(save_dir=dest, name='')
+ _ = logger.version
+
+ log.warning.assert_not_called()
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 2befc5bd7dbd2..24c0b615b95bb 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -122,7 +122,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=2,
tpu_cores=1,
- limit_train_batches=8,
+ limit_train_batches=0.7,
limit_val_batches=2,
)
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
- limit_train_batches=4,
- limit_val_batches=4,
+ limit_train_batches=10,
+ limit_val_batches=10,
gradient_clip_val=0.5,
)
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 667e153a9edd4..6abcf17a04893 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -20,6 +20,7 @@
import pytest
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
+from tests.helpers.runif import RunIf
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
@@ -165,6 +166,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
+@RunIf(max_torch="1.8.1")
def test_advanced_profiler_describe(tmpdir, advanced_profiler):
"""
ensure the profiler won't fail when reporting the summary
diff --git a/tests/trainer/flags/test_check_val_every_n_epoch.py b/tests/trainer/flags/test_check_val_every_n_epoch.py
new file mode 100644
index 0000000000000..f7f1403ecdbfd
--- /dev/null
+++ b/tests/trainer/flags/test_check_val_every_n_epoch.py
@@ -0,0 +1,53 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.trainer.states import TrainerState
+from tests.helpers import BoringModel
+
+
+@pytest.mark.parametrize(
+ 'max_epochs,expected_val_loop_calls,expected_val_batches', [
+ (1, 0, [0]),
+ (4, 2, [0, 2, 0, 2]),
+ (5, 2, [0, 2, 0, 2, 0]),
+ ]
+)
+def test_check_val_every_n_epoch(tmpdir, max_epochs, expected_val_loop_calls, expected_val_batches):
+
+ class TestModel(BoringModel):
+ val_epoch_calls = 0
+ val_batches = []
+
+ def on_train_epoch_end(self, *args, **kwargs):
+ self.val_batches.append(self.trainer.progress_bar_callback.total_val_batches)
+
+ def on_validation_epoch_start(self) -> None:
+ self.val_epoch_calls += 1
+
+ model = TestModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ max_epochs=max_epochs,
+ num_sanity_val_steps=0,
+ limit_val_batches=2,
+ check_val_every_n_epoch=2,
+ logger=False,
+ )
+ trainer.fit(model)
+ assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+ assert model.val_epoch_calls == expected_val_loop_calls
+ assert model.val_batches == expected_val_batches
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index bca8e5dcc531b..69d199a76dfff 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -703,28 +703,39 @@ def test_warning_with_few_workers_multi_loader(mock, tmpdir, ckpt_path):
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
- model = EvalModelTemplate()
+ model = BoringModel()
original_dataset = model.train_dataloader().dataset
- class IterableWithLen(IterableDataset):
+ class IterableWithoutLen(IterableDataset):
def __iter__(self):
return iter(original_dataset)
+ class IterableWithLen(IterableWithoutLen):
+
def __len__(self):
return len(original_dataset)
+ # with __len__ defined
dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
- trainer = Trainer(
- default_root_dir=tmpdir,
- max_steps=3,
- )
+ trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])
+ with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+ trainer.predict(model, dataloaders=[dataloader])
+
+ # without __len__ defined
+ dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
+ assert not has_len(dataloader)
+ assert has_iterable_dataset(dataloader)
+ trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
+ trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+ trainer.test(model, test_dataloaders=dataloader)
+ trainer.predict(model, dataloaders=dataloader)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index fd2b48a3fa140..306d38d2d651b 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -42,6 +42,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
@pytest.fixture
@@ -1499,6 +1500,7 @@ def test_trainer_predict_ddp_cpu(tmpdir):
predict(tmpdir, "ddp_cpu", 0, 2)
+@RunIf(max_torch="1.8.1")
def test_pytorch_profiler_describe(pytorch_profiler):
"""Ensure the profiler won't fail when reporting the summary."""
with pytorch_profiler.profile("test_step"):