diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 7c8aae53ab06e..419580b71cd10 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -27,11 +27,13 @@ jobs: run: | conda info conda list + # adjust versions according to the installed Torch version + python ./requirements/adjust_versions.py requirements/extra.txt + python ./requirements/adjust_versions.py requirements/examples.txt pip install --requirement requirements/devel.txt --upgrade-strategy only-if-needed pip list - name: Pull checkpoints from S3 - # todo: consider adding coma caching, but ATM all models have less then 100KB run: | # enter legacy and update checkpoints from S3 cd legacy @@ -39,12 +41,6 @@ unzip -o checkpoints.zip ls -l checkpoints/ - # todo: require proper fix in docker image - - name: Hotfix dependency - run: | - pip install torchtext==0.6.0 -U - shell: bash - - name: Tests run: | # NOTE: run coverage on tests does not propagate failure status for Win, https://github.com/nedbat/coveragepy/issues/1003 diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index f08c277b71064..dd29777d9940c 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -104,20 +104,17 @@ jobs: HOROVOD_WITHOUT_MXNET: 1 HOROVOD_WITHOUT_TENSORFLOW: 1 run: | - # python -m pip install --upgrade --user pip - pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade - pip install --requirement ./requirements/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet --upgrade python --version pip --version + # python -m pip install --upgrade --user pip + pip install --requirement requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + # adjust versions according to the installed Torch version + python ./requirements/adjust_versions.py requirements/extra.txt + python ./requirements/adjust_versions.py requirements/examples.txt + pip install --requirement ./requirements/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade pip list shell: bash - # todo: require proper fix in docker image - - name: Hotfix dependency - run: | - pip install torchtext==0.6.0 -U - shell: bash - - name: Reinstall Horovod if necessary if: runner.os != 'windows' env: @@ -143,10 +140,9 @@ # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50 --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - # todo: put this back just when TorchVision can download datasets - #- name: Examples - # run: | - # python -m pytest pl_examples -v --durations=10 + - name: Examples + run: | + python -m pytest pl_examples -v --durations=10 - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/.github/workflows/docs-checks.yml b/.github/workflows/docs-checks.yml index 347f20196d974..5ee4f23b4b3cc 100644 --- a/.github/workflows/docs-checks.yml +++ b/.github/workflows/docs-checks.yml @@ -41,6 +41,8 @@ jobs: - name: Install dependencies run: | + python --version + pip --version # remove Horovod from requirements python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('horovod')] ; open(fname, 'w').writelines(lines)" # 
python -m pip install --upgrade --user pip @@ -48,8 +50,6 @@ pip install --requirement requirements/extra.txt pip install --requirement requirements/loggers.txt pip install --requirement requirements/docs.txt - python --version - pip --version pip list shell: bash @@ -84,12 +84,12 @@ jobs: - name: Install dependencies run: | - pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + python --version + pip --version + # pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet pip install --requirement requirements/docs.txt # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures - python --version - pip --version pip list shell: bash diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 809d75408d43e..24d8ce4002e5d 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -102,8 +102,6 @@ jobs: id: extend - name: Publish CUDA to Docker Hub - # ToDo: extend also building for Nightly from pip - if: matrix.pytorch_version < 1.8 # publish master/release uses: docker/build-push-action@v2 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index d1f5b60592006..7b2aa324beaf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [1.2.3] - 2021-03-09 + + +### Fixed + +- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073)) +- Fixed `_stable_1d_sort` to work when `n >= N` ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177)) +- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221)) +- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260)) +- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272)) +- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296)) +- Ensure we check deepspeed/sharded in multinode DDP ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)) +- Check `LightningOptimizer` doesn't delete optimizer hooks ([#6305](https://github.com/PyTorchLightning/pytorch-lightning/pull/6305)) +- Resolve memory leak for evaluation ([#6326](https://github.com/PyTorchLightning/pytorch-lightning/pull/6326)) +- Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330)) +- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) + + ## [1.2.2] - 2021-03-02 ### Added diff --git a/MANIFEST.in b/MANIFEST.in index 31e6c22ab953f..b1e7613831fe8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -46,7 +46,7 @@ recursive-include docs/source/_static/images/general/ pl_overview* tf_* tutorial # Include the
Requirements recursive-include requirements *.txt -recursive-exclude requirements *.sh +recursive-exclude requirements *.sh *.py include requirements.txt include pyproject.toml diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 6d67afc31f2e4..6dfddda0295fe 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -71,11 +71,6 @@ jobs: python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" displayName: 'Env details' - # todo: require proper fix in docker image - - bash: | - pip install torchtext==0.7 -U - displayName: 'HotFix' - - bash: | wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/ unzip -o legacy/checkpoints.zip -d legacy/ @@ -100,10 +95,12 @@ python -m pytest benchmarks -v --maxfail=2 --durations=0 displayName: 'Testing: benchmarks' - # todo: put this back just when TorchVision can download datasets - #- bash: | - # python -m pytest pl_examples -v --maxfail=2 --durations=0 - # python setup.py install --user --quiet - # bash pl_examples/run_ddp-example.sh - # pip uninstall -y pytorch-lightning - # displayName: 'Examples' + - bash: | + python -m pytest pl_examples -v --maxfail=2 --durations=0 + python setup.py install --user --quiet + bash pl_examples/run_ddp-example.sh + cd pl_examples/basic_examples + bash submit_ddp_job.sh + bash submit_ddp2_job.sh + pip uninstall -y pytorch-lightning + displayName: 'Examples' diff --git a/dockers/base-conda/Dockerfile b/dockers/base-conda/Dockerfile index b586aa8d8b293..585aa1768ffd7 100644 --- a/dockers/base-conda/Dockerfile +++ b/dockers/base-conda/Dockerfile @@ -98,10 +98,12 @@ ENV \ COPY ./requirements/extra.txt requirements-extra.txt COPY ./requirements/test.txt requirements-test.txt +COPY ./requirements/adjust_versions.py requirements_adjust_versions.py RUN \ pip list | grep torch && \ python -c "import torch; print(torch.__version__)" && \ + python requirements_adjust_versions.py requirements-extra.txt && \ # Install remaining requirements pip install -r requirements-extra.txt --no-cache-dir && \ pip install -r requirements-test.txt --no-cache-dir && \ diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index ee481670b1a5b..843e47ca91289 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -94,12 +94,14 @@ RUN \ # eventually use pre-release #pip install "torch==${PYTORCH_VERSION}.*" --pre && \ # set particular PyTorch version - python -c "import re ; fname = 'requirements.txt' ; req = re.sub(r'torch[>=]+[\d\.]+', 'torch==${PYTORCH_VERSION}.*', open(fname).read()) ; open(fname, 'w').write(req)" && \ + python ./requirements/adjust_versions.py requirements.txt ${PYTORCH_VERSION} && \ + python ./requirements/adjust_versions.py requirements/extra.txt ${PYTORCH_VERSION} && \ + python ./requirements/adjust_versions.py requirements/examples.txt ${PYTORCH_VERSION} && \ # Install all requirements # todo: find a way to install nightly PT version # --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${cuda_ver[0]}${cuda_ver[1]}/torch_nightly.html pip install -r requirements/devel.txt --no-cache-dir && \ - rm -rf requirements* + rm -rf requirements.* requirements/ RUN \ # install DALI, needed for examples @@ -113,7 +115,7 @@ RUN \ # install DeepSpeed from source. - # todo: swap to pypi release once DeepSpeed releases a new version
+ # todo: swap to pypi release once DeepSpeed releases a new version >= 0.3.10 pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb RUN \ diff --git a/dockers/base-xla/Dockerfile b/dockers/base-xla/Dockerfile index ce612f158753a..7f7e74bba75a6 100644 --- a/dockers/base-xla/Dockerfile +++ b/dockers/base-xla/Dockerfile @@ -104,6 +104,7 @@ RUN \ python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" && \ # drop TorchVision as it was installed with XLA python -c "fname = 'requirements/examples.txt' ; lines = [line for line in open(fname).readlines() if not line.startswith('torchvision')] ; open(fname, 'w').writelines(lines)" && \ + python ./requirements/adjust_versions.py ./requirements/extra.txt && \ pip install --requirement ./requirements/devel.txt --no-cache-dir && \ cd .. && \ rm -rf pytorch-lightning && \ diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index c0778f858830a..3584ee02746e3 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -27,14 +27,15 @@ COPY ./ ./pytorch-lightning/ RUN \ # Disable cache #conda install "pip>20.1" && \ - #pip config set global.cache-dir false && \ - if [ -z $LIGHTNING_VERSION ] ; then \ - pip install ./pytorch-lightning --no-cache-dir ; \ + if [ ! -z "$LIGHTNING_VERSION" ] ; then \ rm -rf pytorch-lightning ; \ - else \ - rm -rf pytorch-lightning ; \ - pip install https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --no-cache-dir ; \ - fi + wget https://github.com/PyTorchLightning/pytorch-lightning/archive/${LIGHTNING_VERSION}.zip --progress=bar:force:noscroll ; \ + unzip ${LIGHTNING_VERSION}.zip ; \ + mv pytorch-lightning-*/ pytorch-lightning ; \ + rm *.zip ; \ + fi && \ + pip install ./pytorch-lightning["extra"] --no-cache-dir && \ + rm -rf pytorch-lightning RUN python --version && \ pip --version && \ diff --git a/docs/source/advanced/multiple_loaders.rst b/docs/source/advanced/multiple_loaders.rst index fb1aa33f80462..2e3e3201b2181 100644 --- a/docs/source/advanced/multiple_loaders.rst +++ b/docs/source/advanced/multiple_loaders.rst @@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways. ---------- +.. _multiple-training-dataloaders: + Multiple training dataloaders ----------------------------- For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class @@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer return loaders +Furthermore, Lightning also supports returning nested lists and dicts of dataloaders +(or a combination of both) + +.. testcode:: + + class LitModel(LightningModule): + + def train_dataloader(self): + + loader_a = torch.utils.data.DataLoader(range(8), batch_size=4) + loader_b = torch.utils.data.DataLoader(range(16), batch_size=4) + loader_c = torch.utils.data.DataLoader(range(32), batch_size=4) + loader_d = torch.utils.data.DataLoader(range(64), batch_size=4) + + # pass loaders as a nested dict. This will create batches like this: + # {'loaders_a_b': {'a': batch from loader a, 'b': batch from loader b}, + # 'loaders_c_d': {'c': batch from loader c, 'd': batch from loader d}} + loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b}, + 'loaders_c_d': {'c': loader_c, 'd': loader_d}} + return loaders + ---------- Test/Val dataloaders diff --git a/notebooks/06-mnist-tpu-training.ipynb b/notebooks/06-mnist-tpu-training.ipynb index 359d262dfd880..67b8d331ccd24 100644 --- a/notebooks/06-mnist-tpu-training.ipynb +++ b/notebooks/06-mnist-tpu-training.ipynb @@ -80,7 +80,7 @@ "id": "AYGWh10lRaF1" }, "source": [ - "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl" + "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl" ], "execution_count": null, "outputs": [] diff --git a/pl_examples/__init__.py b/pl_examples/__init__.py index 6ad0a4dfc0624..ffd60f9ed71af 100644 --- a/pl_examples/__init__.py +++ b/pl_examples/__init__.py @@ -1,14 +1,30 @@ import os +from urllib.error import HTTPError + +from six.moves import urllib from pytorch_lightning.utilities import _module_available +# TorchVision hotfix https://github.com/pytorch/vision/issues/1938 +opener = urllib.request.build_opener() +opener.addheaders = [('User-agent', 'Mozilla/5.0')] +urllib.request.install_opener(opener) + _EXAMPLES_ROOT = os.path.dirname(__file__) _PACKAGE_ROOT = os.path.dirname(_EXAMPLES_ROOT) _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets') _TORCHVISION_AVAILABLE = _module_available("torchvision") +_TORCHVISION_MNIST_AVAILABLE = True _DALI_AVAILABLE = _module_available("nvidia.dali") +if _TORCHVISION_AVAILABLE: + try: + from torchvision.datasets.mnist import MNIST + MNIST(_DATASETS_PATH, download=True) + except HTTPError: + _TORCHVISION_MNIST_AVAILABLE = False + LIGHTNING_LOGO = """ #### ########### diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index c60c4faec4acd..b3188a21b7f04 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -20,9 +20,9 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index ad50da18ff3fd..01a5dca0de3c7 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -19,9 +19,9 @@ from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/dali_image_classifier.py
b/pl_examples/basic_examples/dali_image_classifier.py index d6e64d2b3de14..b4bf1407a9b26 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -23,9 +23,15 @@ from torch.utils.data import random_split import pytorch_lightning as pl -from pl_examples import _DALI_AVAILABLE, _DATASETS_PATH, _TORCHVISION_AVAILABLE, cli_lightning_logo - -if _TORCHVISION_AVAILABLE: +from pl_examples import ( + _DALI_AVAILABLE, + _DATASETS_PATH, + _TORCHVISION_AVAILABLE, + _TORCHVISION_MNIST_AVAILABLE, + cli_lightning_logo, +) + +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms from torchvision.datasets.mnist import MNIST else: diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 46acc5a3a2a14..a50f67cdab301 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -17,10 +17,10 @@ from torch.utils.data import DataLoader, random_split -from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE +from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE: +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST else: diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 2aa2a9f73db8b..285fba8b93f1b 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -26,15 +26,19 @@ import torch import torch.nn as nn import torch.nn.functional as F # noqa -import torchvision -import torchvision.transforms as transforms from torch.utils.data import DataLoader -from torchvision.datasets import MNIST -from pl_examples import cli_lightning_logo +from pl_examples import _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer +if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: + import torchvision + import torchvision.transforms as transforms + from torchvision.datasets import MNIST +else: + from tests.helpers.datasets import MNIST + class Generator(nn.Module): """ diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 715e00e7b9d9c..abd09e53980c9 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -5,7 +5,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.2.2' +__version__ = '1.2.3' __author__ = 'William Falcon et al.' 
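Editor's note on the `pl_examples/__init__.py` hunk above: probing the MNIST mirror once at import time and exposing the result as a module-level flag keeps every example script free of its own try/except. A minimal standalone sketch of the same guard pattern, assuming only the standard library, `six`, and torchvision (the `_MNIST_AVAILABLE` name and `./Datasets` path here are illustrative):

```python
from urllib.error import HTTPError

from six.moves import urllib

# some mirrors reject the default Python user agent, so install a browser-like one
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

_MNIST_AVAILABLE = True  # illustrative flag, mirroring _TORCHVISION_MNIST_AVAILABLE
try:
    from torchvision.datasets.mnist import MNIST
    MNIST("./Datasets", download=True)  # a single download attempt at import time
except HTTPError:
    _MNIST_AVAILABLE = False  # callers fall back to a bundled MNIST replacement
```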
__author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 8f63bc7b86b11..d285a197c49fe 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -28,7 +28,7 @@ def setup(self, trainer, model): return super().setup(trainer, model) def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs): - xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) + xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): """ @@ -40,4 +40,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s Return: A tensor of shape (world_size, batch, ...) """ - return xm.all_gather(tensor, group=group, sync_grads=sync_grads) + # todo: Add support for backward with all_gather + if torch.distributed.is_initialized(): + return xm.all_gather(tensor, group=group, sync_grads=sync_grads) + return tensor diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index ee130a700ae68..72c32a8f5b738 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -18,7 +18,7 @@ import inspect from copy import deepcopy from functools import partial -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn.utils.prune as pytorch_prune @@ -27,7 +27,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException _PYTORCH_PRUNING_FUNCTIONS = { @@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor def _wrap_pruning_fn(pruning_fn, **kwargs): return partial(pruning_fn, **kwargs) - def make_pruning_permanent(self): - """ Makes ``parameters_to_prune`` current pruning permanent. 
""" - for module, param_name in self._parameters_to_prune: - try: - pytorch_prune.remove(module, param_name) - except ValueError: - # pruning already made permanent - pass + def make_pruning_permanent(self, pl_module: LightningModule): + """ + Removes pruning buffers from any pruned modules + + Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 + """ + for _, module in pl_module.named_modules(): + for k in list(module._forward_pre_hooks): + hook = module._forward_pre_hooks[k] + if isinstance(hook, pytorch_prune.BasePruningMethod): + hook.remove(module) + del module._forward_pre_hooks[k] def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): trained = getattr(module, tensor_name) @@ -351,7 +355,7 @@ def _log_sparsity_stats( f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" ) - def on_before_accelerator_backend_setup(self, trainer, pl_module): + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule): parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names ) @@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []}) self._original_layers[id_]["names"].append((i, name)) - def on_train_epoch_end(self, trainer, pl_module, *args): + def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs): current_epoch = trainer.current_epoch prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount @@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args): ): self.apply_lottery_ticket_hypothesis() - def on_train_end(self, *args): + def on_train_end(self, trainer, pl_module: LightningModule): if self._make_pruning_permanent: - self.make_pruning_permanent() + rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.") + self.make_pruning_permanent(pl_module) - def on_save_checkpoint(self, *args): + def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]): if self._make_pruning_permanent: - self.make_pruning_permanent() + rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.") + prev_device = pl_module.device + # prune a copy so training can continue with the same buffers + copy = deepcopy(pl_module.to("cpu")) + self.make_pruning_permanent(copy) + checkpoint["state_dict"] = copy.state_dict() + pl_module.to(prev_device) @staticmethod def sanitize_parameters_to_prune( diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e0b33c1219e8b..604803365298c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -383,12 +383,14 @@ def prepare_data(self): model.test_dataloader() """ - def train_dataloader(self) -> DataLoader: + def train_dataloader(self) -> Any: """ - Implement a PyTorch DataLoader for training. + Implement one or more PyTorch DataLoaders for training. Return: - Single PyTorch :class:`~torch.utils.data.DataLoader`. + Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these + (list, dict, nested lists and dicts). 
In the case of multiple dataloaders, please see + this :ref:`page <multiple-training-dataloaders>` The dataloader you return will not be called every epoch unless you set :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``. @@ -414,6 +416,7 @@ def train_dataloader(self): Example:: + # single dataloader def train_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) @@ -426,6 +429,32 @@ def train_dataloader(self): ) return loader + # multiple dataloaders, return as list + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a list of tensors: [batch_mnist, batch_cifar] + return [mnist_loader, cifar_loader] + + # multiple dataloaders, return as dict + def train_dataloader(self): + mnist = MNIST(...) + cifar = CIFAR(...) + mnist_loader = torch.utils.data.DataLoader( + dataset=mnist, batch_size=self.batch_size, shuffle=True + ) + cifar_loader = torch.utils.data.DataLoader( + dataset=cifar, batch_size=self.batch_size, shuffle=True + ) + # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} + return {'mnist': mnist_loader, 'cifar': cifar_loader} + """ rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer") diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 8b6548f438756..162e17ca47bf5 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -38,7 +38,7 @@ class LightningOptimizer: def __init__(self, optimizer: Optimizer): - self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'} + self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")} # For Horovod if hasattr(optimizer, "skip_synchronize"): diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 7a4dc726ea555..7abf260a822ef 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -283,6 +283,7 @@ def _stable_1d_sort(x: torch, N: int = 2049): n = x.numel() if N - n > 0: x_max = x.max() - x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x_pad.sort() - return x_sort.values[:n], x_sort.indices[:n] + x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0) + x_sort = x.sort() + i = min(N, n) + return x_sort.values[:i], x_sort.indices[:i] diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index b3b01fc720d2b..8ade1396a174c 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -31,6 +31,9 @@ def __init__(self): super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + if clip_val <= 0: + return + optimizer = cast(OSS, optimizer) optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/training_type/horovod.py
b/pytorch_lightning/plugins/training_type/horovod.py index 351d945675a0c..13585f8f368f4 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -15,6 +15,7 @@ from typing import Any, List, Optional, Union import torch +import torch.distributed as torch_distrib from torch.optim.lr_scheduler import _LRScheduler, Optimizer from pytorch_lightning.core.optimizer import LightningOptimizer @@ -116,7 +117,8 @@ def start_predicting(self, trainer): hvd.join() def barrier(self, *args, **kwargs): - hvd.join() + if torch_distrib.is_initialized(): + hvd.join() def broadcast(self, obj: object, src: int = 0) -> object: obj = hvd.broadcast_object(obj, src) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 692a4426a6ad6..371649057909b 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Iterable, List, Optional, Union import torch +import torch.distributed as torch_distrib import torch.multiprocessing as mp from pytorch_lightning.core.lightning import LightningModule @@ -112,7 +113,8 @@ def model_to_device(self) -> None: self._model.to(xm.xla_device()) def barrier(self, name: Optional[str] = None) -> None: - rendezvous(f"pl.Trainer.{name}") + if torch_distrib.is_initialized(): + rendezvous(f"pl.Trainer.{name}") def transfer_distrib_spawn_state_on_fit_end(self, results): # TODO: is there a better way than accessing callback through model -> trainer -> callback? @@ -126,7 +128,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): # TODO: is there a better way than accessing trainer through model -> trainer? if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - xm.save(self.lightning_module.state_dict(), last_path) + self.save(self.lightning_module.state_dict(), last_path) if self.global_rank == 0: # todo, pass complete checkpoint as state dictionary @@ -134,6 +136,18 @@ def transfer_distrib_spawn_state_on_fit_end(self, results): self.mp_queue.put(last_path) self.mp_queue.put(results) + def save(self, state_dict: Dict, path: str) -> None: + """ + Saving with ``xm.save`` can be unstable and may miss the rendezvous after ``torch.save``. + The rendezvous does not directly affect saving, so + we can ignore the ``RuntimeError`` to reduce friction with TPUs.
+ """ + try: + xm.save(state_dict, path) + except RuntimeError as e: + if "Failed to meet rendezvous" not in str(e): + raise e + def broadcast(self, obj: object, src: int = 0) -> object: buffer = io.BytesIO() torch.save(obj, buffer) @@ -234,7 +248,7 @@ def __load_weights_on_main_process(self) -> None: self._model = model def _close_logger(self, trainer) -> None: - if hasattr(trainer, "logger"): + if trainer.logger is not None: trainer.logger.finalize("success") @property @@ -281,4 +295,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False): # dump states as a checkpoint dictionary object _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only) # Todo: TypeError: 'mappingproxy' object does not support item assignment - xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) + self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 546ed45e18263..e09a5ea11a084 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -194,13 +194,8 @@ def custom_processing_step(self, data): """ -from pytorch_lightning.profiler.profilers import ( - AdvancedProfiler, - BaseProfiler, - PassThroughProfiler, - PyTorchProfiler, - SimpleProfiler, -) +from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler +from pytorch_lightning.profiler.pytorch import PyTorchProfiler __all__ = [ 'BaseProfiler', diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index 2089943aa7593..24cf9af8e5802 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Profiler to check if there are any bottlenecks in your code.""" - import cProfile -import inspect import io import os import pstats @@ -22,16 +20,12 @@ from abc import ABC, abstractmethod from collections import defaultdict from contextlib import contextmanager -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np -import torch from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.distributed import rank_zero_warn -from pytorch_lightning.utilities.exceptions import MisconfigurationException class BaseProfiler(ABC): @@ -285,261 +279,3 @@ def __del__(self): """Close profiler's stream.""" if self.output_file: self.output_file.close() - - -class PyTorchProfiler(BaseProfiler): - - PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") - AVAILABLE_SORT_KEYS = ( - "cpu_time", - "cuda_time", - "cpu_time_total", - "cuda_time_total", - "cpu_memory_usage", - "cuda_memory_usage", - "self_cpu_memory_usage", - "self_cuda_memory_usage", - "count", - ) - - def __init__( - self, - output_filename: Optional[str] = None, - enabled: bool = True, - use_cuda: bool = False, - record_shapes: bool = False, - profile_memory: bool = False, - group_by_input_shapes: bool = False, - with_stack: bool = False, - use_kineto: bool = False, - use_cpu: bool = True, - emit_nvtx: bool = False, - export_to_chrome: bool = False, - path_to_export_trace: str = None, - row_limit: int = 20, - sort_by_key: Optional[str] = None, - profiled_functions: Optional[List] = None, - local_rank: Optional[int] = None, - ): - """ - This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of - different operators inside your model - both on the CPU and GPU - - Args: - - output_filename: optionally save profile results to file instead of printing - to std out when training is finished. When using ``ddp``, - each rank will stream the profiled operation to their own file - with the extension ``_{rank}.txt`` - - enabled: Setting this to False makes this context manager a no-op. - - use_cuda: Enables timing of CUDA events as well using the cudaEvent API. - Adds approximately 4us of overhead to each tensor operation. - - record_shapes: If shapes recording is set, information about input dimensions will be collected. - - profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) - - group_by_input_shapes: Include operator input shapes and group calls by shape. - - with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) - - use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) - - use_cpu: use_kineto=True and can be used to lower the overhead - for GPU-only profiling (Introduced in PyTorch 1.8.0) - - emit_nvtx: Context manager that makes every autograd operation emit an NVTX range - Run:: - - nvprof --profile-from-start off -o trace_name.prof -- - - To visualize, you can either use:: - - nvvp trace_name.prof - torch.autograd.profiler.load_nvprof(path) - - export_to_chrome: Wether to export the sequence of profiled operators for Chrome. - It will generate a ``.json`` file which can be read by Chrome. - - path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. - By default, it will be save where the file being is being run. 
- - row_limit: Limit the number of rows in a table, `0` is a special value that - removes the limit completely. - - sort_by_key: Keys to sort out profiled table - - profiled_functions: list of profiled functions which will create a context manager on. - Any other will be pass through. - - local_rank: When running in distributed setting, local_rank is used for each process - to write to their own file if `output_fname` is provided. - """ - - self.profiled_actions = {} - self.enabled = enabled - self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS - self.use_cuda = use_cuda - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") - self.with_stack = with_stack - self.group_by_input_shapes = group_by_input_shapes and record_shapes - self.use_kineto = use_kineto - self.use_cpu = use_cpu - self.row_limit = row_limit - self.emit_nvtx = emit_nvtx - self.export_to_chrome = export_to_chrome - self.path_to_export_trace = path_to_export_trace - - if export_to_chrome and path_to_export_trace is None: - rank_zero_warn( - "The exported trace would be save locally as `path_to_export_trace` is empty." - " Note: Each functions will generate its own traced file." - ) - - if self.sort_by_key not in self.AVAILABLE_SORT_KEYS: - raise MisconfigurationException( - f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " - ) - - self.profiled_actions = {} - self.context_names = {} - self.running_stack = [] - self.profiler = None - - self.output_fname = output_filename - self.output_file = None - if local_rank is not None: - self.on_train_start(local_rank=local_rank) - self.on_train_start = super().on_train_start - - def on_train_start(self, local_rank: Optional[str] = None): - self.local_rank = local_rank - - # when logging to `log.info`, only perform profiling on rank 0 - if local_rank != 0 and self.output_fname is None: - self.wrap_functions_into_rank_zero_only() - - if self.output_fname: - if local_rank is not None: - if '.txt' not in self.output_fname: - raise MisconfigurationException("Log file should be .txt file.") - - self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") - - fs = get_filesystem(self.output_fname) - self.output_file = fs.open(self.output_fname, "w") - - streaming_out = [self.output_file.write] if self.output_file else [log.info] - super().__init__(output_streams=streaming_out) - - def wrap_functions_into_rank_zero_only(self): - self.start = rank_zero_only(self.start) - self.stop = rank_zero_only(self.stop) - self.summary = rank_zero_only(self.summary) - self.describe = rank_zero_only(self.describe) - - def start(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return - - if len(self.running_stack) > 0: - self._stop(self.running_stack[-1]) - self.running_stack.append(action_name) - - self.context_names[action_name] = "/".join(self.running_stack) - - self._start(action_name) - - def _start(self, action_name: str) -> None: - if self.emit_nvtx: - self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False) - self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) - else: - self._create_profiler(action_name, torch.autograd.profiler.profile) - - def _create_profiler(self, action_name, profiler, enter=True): - init_args = inspect.signature(profiler.__init__).parameters - profiler_args = {k: v for k, v in vars(self).items() if k in init_args} - pr = 
profiler(**profiler_args) - if enter: - pr = pr.__enter__() - self.profiler = pr - - def _stop(self, action_name: str) -> None: - if self.profiler is None: - return - - self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) - - function_events = self.profiler.function_events - self.profiler = None - for name in self.running_stack: - if name not in self.profiled_actions: - self.profiled_actions[name] = function_events - else: - self.profiled_actions[name] += function_events - - def stop(self, action_name: str) -> None: - if action_name not in self.profiled_functions: - return - - if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: - raise ValueError( # pragma: no-cover - f"Attempting to stop recording an action ({action_name}) which was never started." - ) - self._stop(action_name) - self.running_stack.pop() - # restore running profiler - if len(self.running_stack) > 0: - self._start(self.running_stack[-1]) - - def summary(self) -> str: - recorded_stats = {} - output_string = '' - local_rank = '0' if self.local_rank is None else self.local_rank - - if not self.enabled: - return output_string - - for action_name, function_events in self.profiled_actions.items(): - - # next line is a workaround for a pytorch issue (fixed on master, still present - # on 1.7). Without it the code fails with `AssertionError: There is already a CPU - # parent event for detach` - function_events.populate_cpu_children = lambda: None - - if self.export_to_chrome: - filename = f"{action_name}_{local_rank}_trace.json" - path_to_trace = filename if self.path_to_export_trace is None \ - else os.path.join(self.path_to_export_trace, filename) - function_events.export_chrome_trace(path_to_trace) - - if self.emit_nvtx: - return output_string - - else: - data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) - table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) - recorded_stats[action_name] = table - - # log to standard out - output_string = f"{os.linesep}Profiler Report{os.linesep}" - for action, stats in recorded_stats.items(): - output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") - - return output_string - - def describe(self): - """Logs a profile report after the conclusion of the training run.""" - super().describe() - if self.output_file: - self.output_file.flush() - - def __del__(self): - """Close profiler's stream.""" - if self.output_file: - self.output_file.close() diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py new file mode 100644 index 0000000000000..88a33a3d367f8 --- /dev/null +++ b/pytorch_lightning/profiler/pytorch.py @@ -0,0 +1,303 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Profiler to check if there are any bottlenecks in your code.""" + +import inspect +import logging +import os +from typing import List, Optional + +import torch + +from pytorch_lightning.profiler.profilers import BaseProfiler +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +log = logging.getLogger(__name__) + + +class PyTorchProfiler(BaseProfiler): + + PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step") + AVAILABLE_SORT_KEYS = ( + "cpu_time", + "cuda_time", + "cpu_time_total", + "cuda_time_total", + "cpu_memory_usage", + "cuda_memory_usage", + "self_cpu_memory_usage", + "self_cuda_memory_usage", + "count", + ) + + def __init__( + self, + output_filename: Optional[str] = None, + enabled: bool = True, + use_cuda: bool = False, + record_shapes: bool = False, + profile_memory: bool = False, + group_by_input_shapes: bool = False, + with_stack: bool = False, + use_kineto: bool = False, + use_cpu: bool = True, + emit_nvtx: bool = False, + export_to_chrome: bool = False, + path_to_export_trace: str = None, + row_limit: int = 20, + sort_by_key: Optional[str] = None, + profiled_functions: Optional[List] = None, + local_rank: Optional[int] = None, + ): + """ + This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of + different operators inside your model - both on the CPU and GPU + + Args: + + output_filename: optionally save profile results to file instead of printing + to std out when training is finished. When using ``ddp``, + each rank will stream the profiled operation to their own file + with the extension ``_{rank}.txt`` + + enabled: Setting this to False makes this context manager a no-op. + + use_cuda: Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + + record_shapes: If shapes recording is set, information about input dimensions will be collected. + + profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0) + + group_by_input_shapes: Include operator input shapes and group calls by shape. + + with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0) + + use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0) + + use_cpu: use_kineto=True and can be used to lower the overhead + for GPU-only profiling (Introduced in PyTorch 1.8.0) + + emit_nvtx: Context manager that makes every autograd operation emit an NVTX range + Run:: + + nvprof --profile-from-start off -o trace_name.prof -- + + To visualize, you can either use:: + + nvvp trace_name.prof + torch.autograd.profiler.load_nvprof(path) + + export_to_chrome: Wether to export the sequence of profiled operators for Chrome. + It will generate a ``.json`` file which can be read by Chrome. + + path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``. + By default, it will be save where the file being is being run. + + row_limit: Limit the number of rows in a table, `0` is a special value that + removes the limit completely. + + sort_by_key: Keys to sort out profiled table + + profiled_functions: list of profiled functions which will create a context manager on. + Any other will be pass through. 
+ + local_rank: When running in distributed setting, local_rank is used for each process + to write to their own file if `output_fname` is provided. + + Raises: + MisconfigurationException: + If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or + if log file is not a ``.txt`` file. + ValueError: + If you attempt to stop recording an action which was never started. + """ + + self.profiled_actions = {} + self.enabled = enabled + self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS + self.use_cuda = use_cuda + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total") + self.with_stack = with_stack + self.group_by_input_shapes = group_by_input_shapes and record_shapes + self.use_kineto = use_kineto + self.use_cpu = use_cpu + self.row_limit = row_limit + self.emit_nvtx = emit_nvtx + self.export_to_chrome = export_to_chrome + self.path_to_export_trace = path_to_export_trace + + if export_to_chrome and path_to_export_trace is None: + rank_zero_warn( + "The exported trace will be saved locally as `path_to_export_trace` is empty." + " Note: Each function will generate its own traced file." + ) + + if self.sort_by_key not in self.AVAILABLE_SORT_KEYS: + raise MisconfigurationException( + f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " + ) + + self.profiled_actions = {} + self.context_names = {} + self.running_stack = [] + self.profiler = None + + self.output_fname = output_filename + self.output_file = None + if local_rank is not None: + self.on_train_start(local_rank=local_rank) + self.on_train_start = super().on_train_start + + def on_train_start(self, local_rank: Optional[str] = None): + self.local_rank = local_rank + + # when logging to `log.info`, only perform profiling on rank 0 + if local_rank != 0 and self.output_fname is None: + self.wrap_functions_into_rank_zero_only() + + if self.output_fname: + if local_rank is not None: + if '.txt' not in self.output_fname: + raise MisconfigurationException("Log file should be .txt file.") + + self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt") + + fs = get_filesystem(self.output_fname) + self.output_file = fs.open(self.output_fname, "w") + + streaming_out = [self.output_file.write] if self.output_file else [log.info] + super().__init__(output_streams=streaming_out) + + def wrap_functions_into_rank_zero_only(self): + self.start = rank_zero_only(self.start) + self.stop = rank_zero_only(self.stop) + self.summary = rank_zero_only(self.summary) + self.describe = rank_zero_only(self.describe) + + def start(self, action_name: str) -> None: + if action_name not in self.profiled_functions: + return + + if len(self.running_stack) > 0: + self._stop(self.running_stack[-1]) + self.running_stack.append(action_name) + + self.context_names[action_name] = "/".join(self.running_stack) + + self._start(action_name) + + def _start(self, action_name: str) -> None: + if self.emit_nvtx: + self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True) + self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx) + else: + self._create_profiler(action_name, torch.autograd.profiler.profile) + + def _create_profiler(self, action_name, profiler, enter=True): + init_args = inspect.signature(profiler.__init__).parameters + profiler_args = {k: v for k, v in vars(self).items() if k in init_args} + pr = profiler(**profiler_args) + if 
enter: + out_pr = pr.__enter__() + if out_pr is not None: + pr = out_pr + self.profiler = pr + return self.profiler + + def _stop(self, action_name: str) -> None: + if self.profiler is None: + return + + self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None) + + if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx): + # when running ``emit_nvtx``, PyTorch requires two context managers. + # The parent profiler is closed here as well. + self._parent_profiler.__exit__(None, None, None) + return + + function_events = self.profiler.function_events + self.profiler = None + for name in self.running_stack: + if name not in self.profiled_actions: + self.profiled_actions[name] = function_events + else: + self.profiled_actions[name] += function_events + + def stop(self, action_name: str) -> None: + if action_name not in self.profiled_functions: + return + + if len(self.running_stack) == 0 or self.running_stack[-1] != action_name: + raise ValueError( # pragma: no-cover + f"Attempting to stop recording an action ({action_name}) which was never started." + ) + self._stop(action_name) + self.running_stack.pop() + # restore running profiler + if len(self.running_stack) > 0: + self._start(self.running_stack[-1]) + + def summary(self) -> str: + recorded_stats = {} + output_string = '' + local_rank = '0' if self.local_rank is None else self.local_rank + + if not self.enabled: + return output_string + + for action_name, function_events in self.profiled_actions.items(): + + # next line is a workaround for a pytorch issue (fixed on master, still present + # on 1.7). Without it the code fails with `AssertionError: There is already a CPU + # parent event for detach` + function_events.populate_cpu_children = lambda: None + + if self.export_to_chrome: + filename = f"{action_name}_{local_rank}_trace.json" + path_to_trace = filename if self.path_to_export_trace is None \ + else os.path.join(self.path_to_export_trace, filename) + function_events.export_chrome_trace(path_to_trace) + + if self.emit_nvtx: + return output_string + + else: + data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes) + table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit) + recorded_stats[action_name] = table + + # log to standard out + output_string = f"{os.linesep}Profiler Report{os.linesep}" + for action, stats in recorded_stats.items(): + output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}") + + return output_string + + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index fb6a1d4ab8442..59d406b0479c6 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -494,7 +494,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None): # define the max CPU available self.num_processes = os.cpu_count() # special case with TPUs - elif self.distributed_backend == 'tpu': + elif self.distributed_backend == 'tpu' or self.tpu_cores is not None: self._device_type = DeviceType.TPU elif self.distributed_backend and self._distrib_type is None: self._distrib_type = 
DistributedType(self.distributed_backend) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 087741aa69c2b..e1b3688ef36e6 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -203,6 +203,10 @@ def __run_eval_epoch_end(self, num_dataloaders): # with a single dataloader don't pass an array outputs = self.outputs + + # free memory + self.outputs = [] + eval_results = outputs if num_dataloaders == 1: eval_results = outputs[0] diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index ea881b796e825..a247fb92cd22f 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC -from typing import List, Optional, Tuple, Dict, Any +from typing import Any, Dict, List, Optional, Tuple import torch from torch import optim @@ -27,7 +27,10 @@ class TrainerOptimizersMixin(ABC): + _lightning_optimizers: Optional[List[LightningOptimizer]] + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: rank_zero_warn( diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 58ce7f09ea2a2..e123c1af5a5d0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -15,7 +15,7 @@ import warnings from itertools import count from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch from torch.utils.data import DataLoader @@ -56,7 +56,7 @@ from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities import DeviceType, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.enums import LightningEnum @@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule): def fit( self, model: LightningModule, - train_dataloader: Optional[DataLoader] = None, + train_dataloader: Any = None, val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, datamodule: Optional[LightningDataModule] = None, ): @@ -437,8 +437,9 @@ def fit( model: Model to fit. - train_dataloader: A Pytorch DataLoader with training samples. If the model has - a predefined train_dataloader method this will be skipped. + train_dataloader: Either a single PyTorch DataLoader or a collection of these + (list, dict, nested lists and dicts). In the case of multiple dataloaders, please + see this :ref:`page <multiple-training-dataloaders>` val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
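Editor's note: with the widened `fit` signature above, a collection of loaders can now be passed straight to `Trainer.fit`. A minimal sketch (the toy `LitModel` is illustrative; the point is that `training_step` must unpack the dict batch itself):

```python
import torch
from pytorch_lightning import LightningModule, Trainer

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        # with a dict of loaders, each batch arrives as {'a': batch_a, 'b': batch_b}
        x = torch.cat([batch["a"], batch["b"]]).float().unsqueeze(1)
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

loaders = {
    "a": torch.utils.data.DataLoader(range(8), batch_size=4),
    "b": torch.utils.data.DataLoader(range(16), batch_size=4),
}
trainer = Trainer(max_epochs=1)
trainer.fit(LitModel(), train_dataloader=loaders)
```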
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index ea881b796e825..a247fb92cd22f 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from abc import ABC
-from typing import List, Optional, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import optim
@@ -27,7 +27,10 @@
 class TrainerOptimizersMixin(ABC):
+    _lightning_optimizers: Optional[List[LightningOptimizer]]
+
     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
+        self._lightning_optimizers = None
         optim_conf = model.configure_optimizers()
         if optim_conf is None:
             rank_zero_warn(
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 58ce7f09ea2a2..e123c1af5a5d0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,7 +15,7 @@
 import warnings
 from itertools import count
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
 from torch.utils.data import DataLoader
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.enums import LightningEnum
@@ -425,7 +425,7 @@ def setup_trainer(self, model: LightningModule):
     def fit(
         self,
         model: LightningModule,
-        train_dataloader: Optional[DataLoader] = None,
+        train_dataloader: Any = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         datamodule: Optional[LightningDataModule] = None,
     ):
@@ -437,8 +437,9 @@ def fit(
 
             model: Model to fit.
 
-            train_dataloader: A Pytorch DataLoader with training samples. If the model has
-                a predefined train_dataloader method this will be skipped.
+            train_dataloader: Either a single PyTorch DataLoader or a collection of these
+                (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
+                see this :ref:`page `
 
             val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                 If the model has a predefined val_dataloaders method this will be skipped
@@ -948,8 +949,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                 f'specify a path for a checkpoint .test(ckpt_path=PATH)'
             )
             return {}
-        if not self._device_type == DeviceType.TPU:
-            self.accelerator.barrier()
+
+        self.training_type_plugin.barrier()
 
         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 2f7425bf3beb0..9fd42008b9d8d 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import operator
 from abc import ABC
 from collections.abc import Mapping, Sequence
 from copy import copy
@@ -22,10 +22,13 @@
 import torch
 
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
+from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_AVAILABLE
 
 if _TORCHTEXT_AVAILABLE:
-    from torchtext.data import Batch
+    if _compare_version("torchtext", operator.ge, "0.9.0"):
+        from torchtext.legacy.data import Batch
+    else:
+        from torchtext.data import Batch
 else:
     Batch = type(None)
diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py
new file mode 100644
index 0000000000000..c1499cd4ea5ee
--- /dev/null
+++ b/requirements/adjust_versions.py
@@ -0,0 +1,49 @@
+import os
+import re
+import sys
+from typing import Any, Dict
+
+VERSIONS_LUT: Dict[str, Dict[str, Any]] = {
+    "1.4.0": dict(torchvision="0.5.0", torchtext="0.5"),
+    "1.5.0": dict(torchvision="0.6.0", torchtext="0.6"),
+    "1.5.1": dict(torchvision="0.6.1", torchtext="0.6"),
+    "1.6.0": dict(torchvision="0.7.0", torchtext="0.7"),
+    "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"),
+    "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"),
+    "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"),
+}
+
+
+def find_latest(ver: str, versions_all: list) -> str:
+    # drop all except the semantic version
+    ver = re.search(r'([\.\d]+)', ver).groups()[0]
+    # find candidates whose version starts with the given pattern
+    options = [v for v in versions_all if v.startswith(ver)]
+    assert options, f"missing {ver} among {versions_all}"
+    # take the last one...
+    return sorted(options)[-1]
+
+
+def main(path_req: str, torch_version: str = None) -> None:
+    with open(path_req, "r") as fp:
+        req = fp.read()
+
+    if not torch_version:
+        import torch
+        torch_version = torch.__version__
+    assert torch_version, f"invalid/missing Torch: {torch_version}"
+
+    torch_version = find_latest(torch_version, list(VERSIONS_LUT.keys()))
+    dep_versions = VERSIONS_LUT[torch_version]
+    dep_versions["torch"] = torch_version
+    for lib in dep_versions:
+        version = dep_versions[lib]
+        replace = f"{lib}=={version}\n"
+        req = re.sub(rf"{lib}[>=]*[\d\.]*{os.linesep}", replace, req)
+
+    with open(path_req, "w") as fp:
+        fp.write(req)
+
+
+if __name__ == "__main__":
+    main(*sys.argv[1:])
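To make the substitution in `adjust_versions.py` concrete, here is a hedged, self-contained demo of the same regex rewrite (a literal `"\n"` stands in for `os.linesep`, and the pins shown assume the LUT entry for torch 1.7.1):

```python
import re

req = "torchvision>=0.4.1\ntorchtext>=0.3.1\n"
pins = {"torchvision": "0.8.2", "torchtext": "0.8.1"}  # LUT entry for torch 1.7.1

# each loose specifier is replaced with an exact pin matching the installed torch
for lib, version in pins.items():
    req = re.sub(rf"{lib}[>=]*[\d\.]*\n", f"{lib}=={version}\n", req)

print(req)
# torchvision==0.8.2
# torchtext==0.8.1
```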
diff --git a/setup.cfg b/setup.cfg
index 121716ec725f6..fc64e5d948ffd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -134,10 +134,6 @@ warn_unused_ignores = True
 [mypy-pytorch_lightning.accelerators.tpu.*]
 ignore_errors = True
 
-# todo: add proper typing to this module...
-[mypy-pytorch_lightning.accelerators.*]
-ignore_errors = True
-
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.callbacks.*]
 ignore_errors = True
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 62b0d3a8f3bb3..484b09e27bc0d 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -15,7 +15,6 @@
 import platform
 from collections import OrderedDict
 from logging import INFO
-from unittest import mock
 
 import pytest
 import torch
@@ -24,7 +23,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
@@ -42,6 +41,10 @@ def __init__(self):
             ])
         )
 
+    def training_step(self, batch, batch_idx):
+        self.log("test", -batch_idx)
+        return super().training_step(batch, batch_idx)
+
 
 class TestPruningMethod(pytorch_prune.BasePruningMethod):
     PRUNING_TYPE = "unstructured"
@@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):
 
 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-@mock.patch.dict(os.environ, {}, clear=True)
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     seed_everything(0)
     model = TestModel()
@@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     with caplog.at_level(INFO):
         trainer.fit(model)
 
-    actual = [m.strip() for m in caplog.messages[-9:]]
-    expected = [
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
@@ -256,7 +259,6 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
     ]
-    assert actual == expected
 
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
@@ -264,3 +266,47 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     model.load_from_checkpoint(filepath, strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
+
+
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+    """
+    When a model is saved multiple times and `make_pruning_permanent=True`, we need to
+    make sure a copy of the model is pruned, not the model being trained, so that
+    training can continue with the same pruning buffers.
+    """
+    seed_everything(0)
+
+    class TestPruning(ModelPruning):
+
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            super().on_save_checkpoint(trainer, pl_module, checkpoint)
+            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
+            assert hasattr(pl_module.layer.mlp_3, "weight_orig")
+
+    model = TestModel()
+    pruning_callback = TestPruning(
+        "random_unstructured",
+        parameters_to_prune=[(model.layer.mlp_3, "weight")],
+        verbose=1,
+        make_pruning_permanent=True
+    )
+    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
+    trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
+        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
+        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
+        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
+    ]
+
+    # the re-parametrization was removed in `on_train_end`
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+    model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+    model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
+    assert not hasattr(model.layer.mlp_3, "weight_orig")
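The `weight_orig` assertions above come straight from the re-parametrization done by `torch.nn.utils.prune`; a quick standalone illustration:

```python
import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(3, 3)
prune.random_unstructured(layer, name="weight", amount=0.5)
# while pruning is active, the original tensor lives in `weight_orig`
# and `weight` is recomputed from it and `weight_mask`
assert hasattr(layer, "weight_orig")

prune.remove(layer, "weight")  # make the pruning permanent
assert not hasattr(layer, "weight_orig")
```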
diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 7b1a7facbb3fe..b5d22372ff15f 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -52,6 +52,9 @@
         "1.1.6",
         "1.1.7",
         "1.1.8",
+        "1.2.0",
+        "1.2.1",
+        "1.2.2",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version):
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index fd37f812c2337..3c6e34df8d5e3 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch, DEFAULT
+import gc
+from typing import Any
+from unittest.mock import DEFAULT, patch
 
 import torch
 from torch.optim import Adam, Optimizer, SGD
@@ -188,6 +190,7 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir):
     """
     Test overriding zero_grad works in automatic_optimization
     """
+
     class TestModel(BoringModel):
 
         def training_step(self, batch, batch_idx, optimizer_idx=None):
@@ -281,7 +284,9 @@ def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir):
     Test zero_grad is called the same number of times as LBFGS requires
     for reevaluation of the loss in automatic_optimization.
     """
+
     class TestModel(BoringModel):
+
         def configure_optimizers(self):
             return torch.optim.LBFGS(self.parameters())
 
@@ -300,3 +305,78 @@ def configure_optimizers(self):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hook in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves the input of the layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves the gradient at the output of the layer;
+        the grad is scaled with batch_size since the gradient is spread over the samples in the mini-batch
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
diff --git a/tests/helpers/imports.py b/tests/helpers/imports.py
new file mode 100644
index 0000000000000..4db9c00d45eab
--- /dev/null
+++ b/tests/helpers/imports.py
@@ -0,0 +1,8 @@
+import operator
+
+from pytorch_lightning.utilities.imports import _compare_version
+
+if _compare_version("torchtext", operator.ge, "0.9.0"):
+    from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField  # noqa: F401
+else:
+    from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField  # noqa: F401
diff --git a/tests/metrics/classification/test_auc.py b/tests/metrics/classification/test_auc.py
index 70d61b696711f..e902151ecffce 100644
--- a/tests/metrics/classification/test_auc.py
+++ b/tests/metrics/classification/test_auc.py
@@ -61,4 +61,4 @@ def test_auc_functional(self, x, y):
 ])
 def test_auc(x, y, expected):
     # Test Area Under Curve (AUC) computation
-    assert auc(torch.tensor(x), torch.tensor(y)) == expected
+    assert auc(torch.tensor(x), torch.tensor(y), reorder=True) == expected
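The `reorder=True` flag matters because trapezoidal AUC assumes monotonic x values; with reordering, unsorted inputs are sorted first. A hedged sketch (the import path is assumed to match what `test_auc.py` uses):

```python
import torch
from pytorch_lightning.metrics.functional.classification import auc  # assumed import path

x = torch.tensor([2.0, 0.0, 1.0, 3.0])  # deliberately unsorted
y = torch.tensor([2.0, 0.0, 1.0, 3.0])
# points are sorted by x before integration: area under y=x over [0, 3] is 4.5
assert auc(x, y, reorder=True) == 4.5
```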
diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index f30f12009450e..fe5f507fbacb4 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -16,7 +16,6 @@
 
 import pytest
 import torch
-from torchtext.data import Batch, Dataset, Example, Field, LabelField
 
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
@@ -24,6 +23,7 @@
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
+from tests.helpers.imports import Batch, Dataset, Example, Field, LabelField
 
 PRETEND_N_OF_GPUS = 16
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index bfa8f2432e3a2..db96e6854db90 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
 
-    # todo: Test on 8 cores - hanging.
-
     class CustomBoringModel(BoringModel):
 
         def validation_step(self, *args, **kwargs):
@@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs):
         max_epochs=2,
         limit_train_batches=2,
         limit_val_batches=2,
-        tpu_cores=[1],
+        tpu_cores=8,
     )
     trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
 
 
 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index f3683ffcba252..d431779dddb4e 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -1,5 +1,6 @@
 import os
 import platform
+from unittest import mock
 
 import pytest
 import torch
@@ -12,6 +13,24 @@
 from tests.helpers.boring_model import BoringModel
 
 
+@pytest.mark.parametrize("clip_val", [0, 10])
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
+def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that gradient clipping is only called when the value is greater than 0.
+    """
+    model = BoringModel()
+    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
+    trainer.fit(model)
+    if clip_val > 0:
+        mock_oss_clip_grad_norm.assert_called()
+    else:
+        mock_oss_clip_grad_norm.assert_not_called()
+
+
 @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_sharded_ddp_choice(tmpdir, accelerator):
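The mocked `fairscale.optim.oss.OSS.clip_grad_norm` call is gated by the trainer's `gradient_clip_val`. A simplified sketch of the guard being verified (names hypothetical; the real check lives in the sharded precision plugin):

```python
def clip_gradients(optimizer, clip_val):
    # a zero (disabled) clip value must never reach fairscale's OSS.clip_grad_norm
    if clip_val <= 0:
        return
    optimizer.clip_grad_norm(clip_val)
```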
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 472f7afda5e9e..a2373d05a42ef 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -35,3 +35,4 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_
 python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
 python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
 python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model
+nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx
diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
index 46890f6801711..765fab229f6cf 100644
--- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -126,6 +126,7 @@ def validation_step_end(self, acc):
     def validation_epoch_end(self, outputs):
         self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
         self.validation_epoch_end_called = True
+        assert len(self.trainer.evaluation_loop.outputs) == 0
 
     def backward(self, loss, optimizer, optimizer_idx):
         return LightningModule.backward(self, loss, optimizer, optimizer_idx)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 23861cdd2563b..59f3c2b54c13c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -34,7 +34,7 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
+from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
@@ -220,8 +220,14 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
 @pytest.mark.parametrize(
     ["accumulate_grad_batches", "limit_train_batches"],
     [
-        ({1: 2, 3: 4}, 1.0),
-        ({1: 2, 3: 4}, 0.5),  # not to be divisible by accumulate_grad_batches on purpose
+        ({
+            1: 2,
+            3: 4
+        }, 1.0),
+        ({
+            1: 2,
+            3: 4
+        }, 0.5),  # not to be divisible by accumulate_grad_batches on purpose
         (3, 1.0),
         (3, 0.8),  # not to be divisible by accumulate_grad_batches on purpose
         (4, 1.0),
@@ -239,9 +245,7 @@ def on_batch_start(self, *_):
     def on_batch_end(self, outputs, batch, batch_idx, *_):
         self.on_train_batch_start_end_dict = self.state_dict()
         for key in self.on_train_batch_start_end_dict.keys():
-            equal = torch.equal(
-                self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key]
-            )
+            equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
         if (batch_idx + 1) == self.trainer.num_training_batches:
             assert equal
         else:
@@ -1582,11 +1586,30 @@ def test_pytorch_profiler_nested(tmpdir):
     for n in ('a', 'b', 'c'):
         pa[n] = [e.name for e in pa[n]]
-        if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"):
+        if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
             pa[n] = [e.replace("aten::", "") for e in pa[n]]
         assert pa[n] == expected_[n]
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+@pytest.mark.skipif(
+    not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
+)
+def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
+    """
+    This test checks that emit_nvtx is correctly supported.
+    """
+    profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        profiler=profiler,
+        gpus=1,
+    )
+    trainer.fit(model)
+
+
 @pytest.mark.parametrize(
     ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
     [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
@@ -1738,6 +1761,7 @@ def test_train_loop_system(tmpdir):
     )
 
     class TestOptimizer(SGD):
+
         def step(self, *args, **kwargs):
             called_methods.append("step")
             return super().step(*args, **kwargs)
@@ -1747,6 +1771,7 @@ def zero_grad(self, *args, **kwargs):
             return super().zero_grad(*args, **kwargs)
 
     class TestModel(BoringModel):
+
         def configure_optimizers(self):
             return TestOptimizer(self.parameters(), lr=0.1)
@@ -1800,3 +1825,28 @@ def backward(self, *args, **kwargs):
         "training_step",
         "backward",
     ]
+
+
+def test_init_optimizers_resets_lightning_optimizers(tmpdir):
+    """ Test that the Trainer resets the `lightning_optimizers` list every time new optimizers get initialized. """
+
+    def compare_optimizers():
+        assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]
+
+    model = BoringModel()
+    model.lr = 0.2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        auto_lr_find=True,
+    )
+
+    trainer.tune(model)
+    compare_optimizers()
+
+    trainer.fit(model)
+    compare_optimizers()
+
+    trainer.max_epochs = 2  # simulate multiple fit calls
+    trainer.fit(model)
+    compare_optimizers()
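As the new test relies on, `trainer.lightning_optimizers` holds thin `LightningOptimizer` wrappers whose `.optimizer` attribute points at the live `torch.optim` instances. A quick hedged usage sketch, mirroring the assertions above:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

trainer = Trainer(fast_dev_run=True)
trainer.fit(BoringModel())

# after every (re-)initialization, the wrapper must track the fresh optimizer
wrapper = trainer.lightning_optimizers[0]
assert wrapper.optimizer is trainer.optimizers[0]
```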
diff --git a/tests/utilities/test_apply_func_torchtext.py b/tests/utilities/test_apply_func_torchtext.py
index c7fec954fdb2f..c5b80ebbb14ee 100644
--- a/tests/utilities/test_apply_func_torchtext.py
+++ b/tests/utilities/test_apply_func_torchtext.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 import pytest
 import torch
-import torchtext
-from torchtext.data.example import Example
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device
+from tests.helpers.imports import Dataset, Example, Field, Iterator
 
 
 def _get_torchtext_data_iterator(include_lengths=False):
-    text_field = torchtext.data.Field(
+    text_field = Field(
         sequential=True,
         pad_first=False,  # nosec
         init_token="",
@@ -32,13 +31,13 @@
     example2 = Example.fromdict({"text": "b c a a"}, {"text": ("text", text_field)})
     example3 = Example.fromdict({"text": "c b a"}, {"text": ("text", text_field)})
 
-    dataset = torchtext.data.Dataset(
+    dataset = Dataset(
         [example1, example2, example3],
         {"text": text_field},
     )
     text_field.build_vocab(dataset)
 
-    iterator = torchtext.data.Iterator(
+    iterator = Iterator(
         dataset,
         batch_size=3,
         sort_key=None,
diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py
index 42edb8e48f336..7087d183906a1 100644
--- a/tests/utilities/test_parsing.py
+++ b/tests/utilities/test_parsing.py
@@ -1,22 +1,27 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+import inspect
+
 import pytest
 
-from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
+from pytorch_lightning.utilities.parsing import (
+    AttributeDict,
+    clean_namespace,
+    collect_init_args,
+    flatten_dict,
+    get_init_args,
+    is_picklable,
+    lightning_getattr,
+    lightning_hasattr,
+    lightning_setattr,
+    parse_class_init_keys,
+    str_to_bool,
+    str_to_bool_or_str,
+)
+
+unpicklable_function = lambda: None
 
 
-def _get_test_cases():
+@pytest.fixture(scope="module")
+def model_cases():
     class TestHparamsNamespace:
         learning_rate = 1
@@ -74,9 +79,9 @@ class TestModel7:  # test for datamodule w/ hparams w/ attribute (should use dat
     return model1, model2, model3, model4, model5, model6, model7
 
 
-def test_lightning_hasattr(tmpdir):
+def test_lightning_hasattr(tmpdir, model_cases):
     """Test that the lightning_hasattr works in all cases"""
-    model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases()
+    model1, model2, model3, model4, model5, model6, model7 = models = model_cases
     assert lightning_hasattr(model1, 'learning_rate'), \
         'lightning_hasattr failed to find namespace variable'
     assert lightning_hasattr(model2, 'learning_rate'), \
@@ -96,9 +101,9 @@ def test_lightning_hasattr(tmpdir):
         assert not lightning_hasattr(m, "this_attr_not_exist")
 
 
-def test_lightning_getattr(tmpdir):
+def test_lightning_getattr(tmpdir, model_cases):
     """Test that the lightning_getattr works in all cases"""
-    models = _get_test_cases()
+    models = model_cases
     for i, m in enumerate(models[:3]):
         value = lightning_getattr(m, 'learning_rate')
         assert value == i, 'attribute not correctly extracted'
@@ -113,15 +118,15 @@ def test_lightning_getattr(tmpdir):
     for m in models:
         with pytest.raises(
-                AttributeError,
-                match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
+            AttributeError,
+            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
         ):
             lightning_getattr(m, "this_attr_not_exist")
 
 
-def test_lightning_setattr(tmpdir):
+def test_lightning_setattr(tmpdir, model_cases):
     """Test that the lightning_setattr works in all cases"""
-    models = _get_test_cases()
+    models = model_cases
     for m in models[:3]:
         lightning_setattr(m, 'learning_rate', 10)
         assert lightning_getattr(m, 'learning_rate') == 10, \
@@ -140,7 +145,161 @@ def test_lightning_setattr(tmpdir):
     for m in models:
         with pytest.raises(
-                AttributeError,
-                match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
+            AttributeError,
+            match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
         ):
             lightning_setattr(m, "this_attr_not_exist", None)
+
+
+def test_str_to_bool_or_str(tmpdir):
+    true_cases = ['y', 'yes', 't', 'true', 'on', '1']
+    false_cases = ['n', 'no', 'f', 'false', 'off', '0']
+    other_cases = ['yyeess', 'noooo', 'lightning']
+
+    for case in true_cases:
+        assert str_to_bool_or_str(case) is True
+
+    for case in false_cases:
+        assert str_to_bool_or_str(case) is False
+
+    for case in other_cases:
+        assert str_to_bool_or_str(case) == case
+
+
+def test_str_to_bool(tmpdir):
+    true_cases = ['y', 'yes', 't', 'true', 'on', '1']
+    false_cases = ['n', 'no', 'f', 'false', 'off', '0']
+    other_cases = ['yyeess', 'noooo', 'lightning']
+
+    for case in true_cases:
+        assert str_to_bool(case) is True
+
+    for case in false_cases:
+        assert str_to_bool(case) is False
+
+    for case in other_cases:
+        with pytest.raises(ValueError):
+            str_to_bool(case)
+
+
+def test_is_picklable(tmpdir):
+    # See the full list of picklable types at
+    # https://docs.python.org/3/library/pickle.html#pickle-picklable
+    class UnpicklableClass:
+        # Only classes defined at the top level of a module are picklable.
+        pass
+
+    true_cases = [None, True, 123, "str", (123, "str"), max]
+    false_cases = [unpicklable_function, UnpicklableClass]
+
+    for case in true_cases:
+        assert is_picklable(case) is True
+
+    for case in false_cases:
+        assert is_picklable(case) is False
+
+
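The relationship between these two helpers, sketched for clarity: `clean_namespace` strips, in place, every entry that fails `is_picklable`, exactly as the test below exercises.

```python
from pytorch_lightning.utilities.parsing import clean_namespace

# mirrors the test below: a dict with one unpicklable entry (a lambda)
hparams = {"lr": 0.1, "fn": lambda x: x}
clean_namespace(hparams)  # warns about and drops the unpicklable "fn" entry
assert hparams == {"lr": 0.1}
```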
+def test_clean_namespace(tmpdir):
+    # See the full list of picklable types at
+    # https://docs.python.org/3/library/pickle.html#pickle-picklable
+    class UnpicklableClass:
+        # Only classes defined at the top level of a module are picklable.
+        pass
+
+    test_case = {
+        "1": None,
+        "2": True,
+        "3": 123,
+        "4": unpicklable_function,
+        "5": UnpicklableClass,
+    }
+
+    clean_namespace(test_case)
+
+    assert test_case == {"1": None, "2": True, "3": 123}
+
+
+def test_parse_class_init_keys(tmpdir):
+
+    class Class:
+
+        def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
+            pass
+
+    assert parse_class_init_keys(Class) == ("self", "my_args", "my_kwargs")
+
+
+def test_get_init_args(tmpdir):
+
+    class AutomaticArgsModel:
+
+        def __init__(self, anyarg, anykw=42, **kwargs):
+            super().__init__()
+
+            self.get_init_args_wrapper()
+
+        def get_init_args_wrapper(self):
+            frame = inspect.currentframe().f_back
+            self.result = get_init_args(frame)
+
+    my_class = AutomaticArgsModel("test", anykw=32, otherkw=123)
+    assert my_class.result == {"anyarg": "test", "anykw": 32, "otherkw": 123}
+
+    my_class.get_init_args_wrapper()
+    assert my_class.result == {}
+
+
+def test_collect_init_args():
+
+    class AutomaticArgsParent:
+
+        def __init__(self, anyarg, anykw=42, **kwargs):
+            super().__init__()
+            self.get_init_args_wrapper()
+
+        def get_init_args_wrapper(self):
+            frame = inspect.currentframe()
+            self.result = collect_init_args(frame, [])
+
+    class AutomaticArgsChild(AutomaticArgsParent):
+
+        def __init__(self, anyarg, childarg, anykw=42, childkw=42, **kwargs):
+            super().__init__(anyarg, anykw=anykw, **kwargs)
+
+    my_class = AutomaticArgsChild("test1", "test2", anykw=32, childkw=22, otherkw=123)
+    assert my_class.result[0] == {"anyarg": "test1", "anykw": 32, "otherkw": 123}
+    assert my_class.result[1] == {"anyarg": "test1", "childarg": "test2", "anykw": 32, "childkw": 22, "otherkw": 123}
+
+
+def test_attribute_dict(tmpdir):
+    # Test initialization
+    inputs = {
+        'key1': 1,
+        'key2': 'abc',
+    }
+    ad = AttributeDict(inputs)
+    for key, value in inputs.items():
+        assert getattr(ad, key) == value
+
+    # Test adding new items
+    ad = AttributeDict()
+    ad.update({'key1': 1})
+    assert ad.key1 == 1
+
+    # Test updating existing items
+    ad = AttributeDict({'key1': 1})
+    ad.key1 = 123
+    assert ad.key1 == 123
+
+
+def test_flatten_dict(tmpdir):
+    d = {'1': 1, '_': {'2': 2, '_': {'3': 3, '4': 4}}}
+
+    expected = {
+        '1': 1,
+        '2': 2,
+        '3': 3,
+        '4': 4,
+    }
+
+    assert flatten_dict(d) == expected