diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index ed8a2e30949b7..0e84642e2f810 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -51,9 +51,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade --user pip - pip install --requirement ./requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade - pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed - # pip install tox coverage + pip install --requirement ./requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip install "pytest>6.0" "pytest-cov>2.10" --upgrade-strategy only-if-needed python --version pip --version pip list @@ -69,7 +68,7 @@ jobs: - name: Test Package [only] run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pytorch_lightning -m pytest pytorch_lightning -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml + python -m pytest pytorch_lightning -v --cov=pytorch_lightning --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml - name: Upload pytest test results uses: actions/upload-artifact@v2 diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index dd29777d9940c..ff8fba06adee6 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -17,10 +17,6 @@ jobs: os: [ubuntu-18.04, windows-2019, macOS-10.15] python-version: [3.6, 3.7, 3.8] requires: ['minimal', 'latest'] - exclude: - # # todo: segmentation fault for minimal and hanging for latest - - python-version: 3.8 - os: ubuntu-18.04 # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 # TODO: the macOS is taking too long, probably caching did not work... diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b2aa324beaf3..42100f29ee6a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,145 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
-## [1.2.3] - 2021-03-09 +## [UnReleased] - 2021-MM-DD + +### Added + +- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + +- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + +- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) + + +- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948)) + + +- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) + + +- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) + + +- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) + + +- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + +### Changed + +- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) + + +- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +### Deprecated + +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + + +- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) + + +### Removed + +- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164)) + + +- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139)) + + +- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166)) + + +- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163)) +- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161)) + * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve` + * from 
`pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` + + +- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) + + +- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167)) + + +- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207)) + + +- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + +### Fixed + +- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) + + +- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070)) + + +- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109)) + + +- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) + + +- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136)) + + +- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275)) + + +- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386)) + + +- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) + + +## [1.2.4] - 2021-03-16 + +### Changed + +- Changed the default of `find_unused_parameters` back to `True` in DDP and DDP Spawn ([#6438](https://github.com/PyTorchLightning/pytorch-lightning/pull/6438)) + +### Fixed + +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) +- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324)) +- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688)) +- Fixed broadcast to use PyTorch `broadcast_object_list` and add `reduce_decision` ([#6410](https://github.com/PyTorchLightning/pytorch-lightning/pull/6410)) +- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380)) +- Fixed DeepSpeed additional memory use on rank 0 when default device not set early enough ([#6460](https://github.com/PyTorchLightning/pytorch-lightning/pull/6460)) +- Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398)) +- Fixed an issue with `Tuner.scale_batch_size` not finding the batch size attribute in the 
datamodule ([#5968](https://github.com/PyTorchLightning/pytorch-lightning/pull/5968)) +- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511)) +- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541)) + + +## [1.2.3] - 2021-03-09 + ### Fixed - Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073)) @@ -53,6 +189,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) +- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093)) + + +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) + + ## [1.2.0] - 2021-02-18 ### Added diff --git a/docs/source/benchmarking/performance.rst b/docs/source/benchmarking/performance.rst index 5f89c759e49bc..d1bc2c9ebc009 100644 --- a/docs/source/benchmarking/performance.rst +++ b/docs/source/benchmarking/performance.rst @@ -94,6 +94,21 @@ DP performs three GPU transfers for EVERY batch: Whereas DDP only performs 1 transfer to sync gradients. Because of this, DDP is MUCH faster than DP. +When using DDP set find_unused_parameters=False +----------------------------------------------- + +By default, ``find_unused_parameters`` is set to ``True`` for compatibility with issues that have arisen in the past (see the `discussion `_ for more information). +This default comes with a performance hit and can be disabled in most cases. + +.. 
code-block:: python + + from pytorch_lightning.plugins import DDPPlugin + + trainer = pl.Trainer( + gpus=2, + plugins=DDPPlugin(find_unused_parameters=False), + ) + ---------- 16-bit precision diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index b3188a21b7f04..a2010a89f4461 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -22,9 +22,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 01a5dca0de3c7..3546bee9ad129 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -21,9 +21,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index b4bf1407a9b26..da5b1e4fd9e9c 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -31,9 +31,10 @@ cli_lightning_logo, ) -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index a50f67cdab301..a6d59c64d9aa0 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -20,8 +20,9 @@ from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 285fba8b93f1b..e65ede17dac7a 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -32,9 +32,10 @@ from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: import torchvision - import torchvision.transforms as transforms + from torchvision import 
transforms +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index abd09e53980c9..fcde9037aee72 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -5,7 +5,7 @@ import time _this_year = time.strftime("%Y") -__version__ = '1.2.3' +__version__ = '1.2.4' __author__ = 'William Falcon et al.' __author_email__ = 'waf2107@columbia.edu' __license__ = 'Apache-2.0' diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 84d53b5addd6b..f9ccc7a42fa06 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -21,7 +21,6 @@ from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -396,7 +395,7 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s Return: A tensor of shape (world_size, batch, ...) """ - return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: """Wraps the dataloader if necessary diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index 53f9388d83597..af9ce25f902b3 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,40 +1,59 @@ import logging import os +from typing import TYPE_CHECKING, Any import torch from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.plugins import DataParallelPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +if TYPE_CHECKING: + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.trainer.trainer import Trainer + _log = logging.getLogger(__name__) class GPUAccelerator(Accelerator): - def setup(self, trainer, model): + def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None: + """ + Raises: + MisconfigurationException: + If the selected device is not GPU. 
+ """ if "cuda" not in str(self.root_device): raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead") self.set_nvidia_flags() torch.cuda.set_device(self.root_device) return super().setup(trainer, model) - def on_train_start(self): + def on_train_start(self) -> None: # clear cache before training # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(self.root_device): torch.cuda.empty_cache() - def on_train_end(self): + def on_train_end(self) -> None: # clean up memory self.model.cpu() with torch.cuda.device(self.root_device): torch.cuda.empty_cache() @staticmethod - def set_nvidia_flags(): + def set_nvidia_flags() -> None: # set the correct cuda visible devices (using pci order) os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]") + + def to_device(self, batch: Any) -> Any: + # no need to transfer batch to device in DP mode + # TODO: Add support to allow batch transfer to device in Lightning for DP mode. + if not isinstance(self.training_type_plugin, DataParallelPlugin): + batch = super().to_device(batch) + + return batch diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d114641f1af72..2dfd0afb02634 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -190,4 +190,4 @@ def _run_early_stopping_check(self, trainer, pl_module): trainer.should_stop = True # stop every ddp process if any world process decides to stop - trainer.should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(trainer.should_stop) + trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5b5a851a922b7..383e1caa6a7dc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -336,7 +336,7 @@ def _save_model(self, filepath: str, trainer, pl_module): else: raise ValueError(".save_function() not set") - def check_monitor_top_k(self, current) -> bool: + def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool: if current is None: return False @@ -356,7 +356,12 @@ def check_monitor_top_k(self, current) -> bool: current = torch.tensor(current) monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] - return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item() + should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) + + # If using multiple devices, make sure all processes are unanimous on the decision. + should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save) + + return should_update_best_and_save @classmethod def _format_checkpoint_name( @@ -554,15 +559,7 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics): epoch = metrics.get("epoch") step = metrics.get("step") - # when `val_loss` is being logged and no ModelCheckpoint is being provided - # `val_loss` will be selected for monitor and need to be reduced to - # prevent processes divergence - # TODO: Move this logic to logger_connector. 
This also needs to be fixed for any - # other monitor logged value which aren't produced from a Metric. - if self.monitor == "val_loss": - current = trainer.training_type_plugin.reduce(current, reduce_op="mean") - - if self.check_monitor_top_k(current): + if self.check_monitor_top_k(trainer, current): self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics) elif self.verbose: rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}") @@ -627,5 +624,4 @@ def file_exists(self, filepath: Union[str, Path], trainer) -> bool: the internal state to diverge between ranks. """ exists = self._fs.exists(filepath) - exists = trainer.training_type_plugin.broadcast(exists) - return exists + return trainer.training_type_plugin.broadcast(exists) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 604803365298c..2e1ea31871e03 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = Note: This hook only runs on single GPU training and DDP (no data-parallel). - If you need multi-GPU support for your custom batch objects, you need to define your custom - :class:`~torch.nn.parallel.DistributedDataParallel` or - :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and - override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. + Data-Parallel support will come in near future. Args: batch: A batch of data that needs to be transferred to a new device. @@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device): batch = super().transfer_batch_to_device(data, device) return batch + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. + See Also: - :meth:`move_data_to_device` - :meth:`apply_to_collection` @@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx): """ Override to alter or apply batch augmentations to your batch before it is transferred to the device. - .. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future. + .. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future. Note: This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. Args: batch: A batch of data that needs to be altered or augmented. @@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx): batch['x'] = transforms(batch['x']) return batch + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. + See Also: - :meth:`on_after_batch_transfer` - :meth:`transfer_batch_to_device` @@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): Note: This hook only runs on single GPU training and DDP (no data-parallel). + Data-Parallel support will come in near future. Args: batch: A batch of data that needs to be altered or augmented. @@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx): batch['x'] = gpu_transforms(batch['x']) return batch + Raises: + MisconfigurationException: + If using data-parallel, ``Trainer(accelerator='dp')``. 
+ See Also: - :meth:`on_before_batch_transfer` - :meth:`transfer_batch_to_device` diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 00611d87d7f35..6a83b7b1f8637 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -105,6 +105,7 @@ def __init__(self, *args, **kwargs): self._current_dataloader_idx = None self.running_stage = None self._automatic_optimization: bool = True + self._param_requires_grad_state = dict() def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -1295,7 +1296,7 @@ def untoggle_optimizer(self, optimizer_idx: int): if param in self._param_requires_grad_state: param.requires_grad = self._param_requires_grad_state[param] # save memory - del self._param_requires_grad_state + self._param_requires_grad_state = dict() def optimizer_step( self, diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index afb64535d1470..a3eab728f8ea8 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -16,7 +16,7 @@ import shutil import subprocess from collections import OrderedDict -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -71,14 +71,15 @@ def __init__(self, module: nn.Module): def __del__(self): self.detach_hook() - def _register_hook(self) -> RemovableHandle: + def _register_hook(self) -> Optional[RemovableHandle]: """ Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. + Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. Return: - A handle for the installed hook. + A handle for the installed hook, or ``None`` if registering the hook is not possible. """ def hook(module, inp, out): @@ -88,7 +89,10 @@ def hook(module, inp, out): self._out_size = parse_batch_shape(out) self._hook_handle.remove() - return self._module.register_forward_hook(hook) + handle = None + if not isinstance(self._module, torch.jit.ScriptModule): + handle = self._module.register_forward_hook(hook) + return handle def detach_hook(self): """ diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 5da7dfa86084d..37ac5d8b13462 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -11,18 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import io from typing import Any -import torch -from torch import distributed as torch_distrib - -from pytorch_lightning.utilities import _GROUP_AVAILABLE - -WORLD = None -if _GROUP_AVAILABLE: - from torch.distributed import group - WORLD = group.WORLD +from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from pytorch_lightning.utilities.distributed import group as _group class LightningDistributed: @@ -31,32 +23,13 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any, group=WORLD): - if self.rank == 0: - self._emit(obj, group) - else: - obj = self._receive(group) - return obj - - def _broadcast(self, tensor, src=0, group=WORLD): - if group is None: - return torch_distrib.broadcast(tensor, src=src) - return torch_distrib.broadcast(tensor, src=0, group=group) - - def _emit(self, obj: Any, group=WORLD): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.ByteTensor(data).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - - def _receive(self, group=WORLD): - length_tensor = torch.tensor([0]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - buffer = io.BytesIO(data_tensor.cpu().numpy()) - obj = torch.load(buffer) - return obj + def broadcast(self, obj: Any, group=_group.WORLD): + # always wrap into a list so list can be brodcasted. + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + broadcast_object_list(obj, 0, group=group or _group.WORLD) + + return obj[0] diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py new file mode 100644 index 0000000000000..67b64c046dc18 --- /dev/null +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -0,0 +1,94 @@ +import logging +import pickle + +import torch + +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7 + +log = logging.getLogger(__name__) + +if torch.distributed.is_available(): + from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember + +# The code underneath is taken from PyTorch ``torch/distributed/distributed_c10d.py`` +# and enable broadcasting for PyTorch 1.6 and lower. + + +# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 +def _rank_not_in_group(group): + """ + Helper that checks if the current process's rank is not in a given group. 
+ """ + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164 +def _object_to_tensor(obj): + buffer = pickle.dumps(obj) + byte_storage = torch.ByteStorage.from_buffer(buffer) # type: ignore[attr-defined] + byte_tensor = torch.ByteTensor(byte_storage) + local_size = torch.LongTensor([byte_tensor.numel()]) + return byte_tensor, local_size + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py +def _tensor_to_object(tensor, tensor_size): + buf = tensor.numpy().tobytes()[:tensor_size] + out = pickle.loads(buf) + return out + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 +def _broadcast_object_list(object_list, src=0, group=None): + if _rank_not_in_group(group): + return + + my_rank = get_rank() + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.LongTensor(len(object_list)) + + group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + broadcast(object_tensor, src=src, group=group) + + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) + + +if _TORCH_GREATER_EQUAL_1_7 and torch.distributed.is_available(): + from torch.distributed.distributed_c10d import broadcast_object_list +else: + broadcast_object_list = _broadcast_object_list diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index a1f33b9931cf5..b600eca5e6bc2 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, List, Tuple +from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING import torch -from torch.optim import Optimizer from pytorch_lightning.core import LightningModule from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin @@ -23,24 +22,28 @@ if _APEX_AVAILABLE: from apex import amp +if TYPE_CHECKING: + from torch.optim import Optimizer + class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" - def __init__(self, amp_level: str): + def __init__(self, amp_level: str = "O2") -> None: self.backend = AMPType.APEX self.amp_level = amp_level - def master_params(self, optimizer: torch.optim.Optimizer): + def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]: return amp.master_params(optimizer) - def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'], + lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]: """Connects the precision plugin to the training process, configures apex and reinits the schedulers """ if model.device.type != "cuda": return model, optimizers, lr_schedulers - model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level) + model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level) self.reinit_scheduler_properties(optimizers, lr_schedulers) return model, optimizers, lr_schedulers @@ -48,12 +51,12 @@ def backward( self, model: LightningModule, closure_loss: torch.Tensor, - optimizer: torch.optim.Optimizer, + optimizer: 'Optimizer', opt_idx: int, should_accumulate: bool, - *args, - **kwargs, - ): + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: """performs the actual backpropagation Args: @@ -94,11 +97,11 @@ def backward( def configure_apex( self, - amp: object, + amp: Type, model: LightningModule, - optimizers: List[Optimizer], + optimizers: List['Optimizer'], amp_level: str, - ) -> Tuple[LightningModule, List[Optimizer]]: + ) -> Tuple[LightningModule, List['Optimizer']]: r""" Override to init AMP your own way. Must return a model and list of optimizers. @@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers @staticmethod - def reinit_scheduler_properties(optimizers: list, schedulers: list): + def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None: """Reinitializes schedulers with correct properties""" # Reinitialize optimizer.step properties added by schedulers for scheduler in schedulers: @@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list): break def pre_optimizer_step( - self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs + self, + pl_module: LightningModule, + optimizer: 'Optimizer', + optimizer_idx: int, + lambda_closure: Callable, + **kwargs: Any, ) -> bool: """ always called before the optimizer step. 
@@ -160,6 +168,5 @@ def pre_optimizer_step( if not pl_module.automatic_optimization: pl_module.trainer.call_hook("on_after_backward") - optimizer.step() - + optimizer.step(**kwargs) return False diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 6a4e948e899cf..e0a52fc7609d6 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -182,6 +182,13 @@ def set_world_ranks(self): self.world_size = self.num_nodes * self.num_processes def pre_configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( + "find_unused_parameters", True + ) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 66d3cb7bf4619..cde2f3dea711c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -171,6 +171,13 @@ def post_dispatch(self): self.__recover_child_process_weights(best_path, last_path) def pre_configure_ddp(self): + # if unset, default `find_unused_parameters` `True` + # Many models require setting this parameter to True, as there are corner cases + # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. + # This flag does come with a performance hit, so it is suggested to disable in cases where it is possible. + self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get( + "find_unused_parameters", True + ) # todo: PyTorch 1.7.0 DDP introduces ``self.reducer._rebuild_buckets()`` breaking manual_optimization if _TORCH_GREATER_EQUAL_1_7 and not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get( "find_unused_parameters", False diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 06cea848ce1dc..58fd1304209bb 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -79,6 +79,11 @@ def __init__( num_nodes: int = 1, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + loss_scale: float = 0, + initial_scale_power: int = 32, + loss_scale_window: int = 1000, + hysteresis: int = 2, + min_loss_scale: int = 1 ) -> None: """ @@ -127,6 +132,18 @@ def __init__( logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) + loss_scale: Loss scaling value for FP16 training. + 0.0 results in dynamic loss scaling, otherwise static (Default: 0) + + initial_scale_power: Power of the initial dynamic loss scale value. 
Loss scale is computed + by ``2^initial_scale_power`` (Default: 32) + + loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000) + + hysteresis: FP16 Delay shift in Dynamic Loss scaling (Default: 2) + + min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000) + """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -154,6 +171,13 @@ def __init__( self._config_initialized = False deepspeed.utils.logging.logger.setLevel(logging_level) + # default FP16 parameters. + self.loss_scale = loss_scale + self.initial_scale_power = initial_scale_power + self.loss_scale_window = loss_scale_window + self.hysteresis = hysteresis + self.min_loss_scale = min_loss_scale + def _load_config(self, config): if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") @@ -207,6 +231,8 @@ def _init_scheduler_optimizer(self): return optimizer, scheduler, optimizer_frequencies def _initialize_deepspeed_train(self, model): + if self.on_gpu: + torch.cuda.set_device(self.root_device) optimizer, lightning_scheduler, optimizer_frequencies = None, None, None if "optimizer" not in self.config: rank_zero_info( @@ -297,9 +323,19 @@ def _format_precision_config(self): amp_level = self.lightning_module.trainer.accelerator_connector.amp_level precision = self.lightning_module.trainer.accelerator_connector.precision if precision == 16: - if "amp" not in self.config and amp_type == AMPType.NATIVE: - self.config["fp16"] = {"enabled": True} - elif "apex" not in self.config and amp_type == AMPType.APEX: + if "fp16" not in self.config and amp_type == AMPType.NATIVE: + # FP16 is a DeepSpeed standalone AMP implementation + rank_zero_info("Enabling DeepSpeed FP16.") + self.config["fp16"] = { + "enabled": True, + "loss_scale": self.loss_scale, + "initial_scale_power": self.initial_scale_power, + "loss_scale_window": self.loss_scale_window, + "hysteresis": self.hysteresis, + "min_loss_scale": self.min_loss_scale + } + elif "amp" not in self.config and amp_type == AMPType.APEX: + rank_zero_only("Enabling DeepSpeed APEX Implementation.") self.config["amp"] = { "enabled": True, "opt_level": amp_level, diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index e1002faf8a3b4..7a1f7ac1201c0 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -19,6 +19,7 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection class DataParallelPlugin(ParallelPlugin): @@ -31,14 +32,30 @@ def setup(self, model): model.to(self.root_device) self._model = DataParallel(LightningParallelModule(model), self.parallel_devices) - def reduce(self, output, *args, **kwargs): - if isinstance(output, Result): - output.dp_reduce() + def reduce(self, tensor, *args, **kwargs): + """ + Reduces a tensor from all parallel processes to one aggregated tensor. 
- elif isinstance(output, torch.Tensor): - output = output.mean() + Args: + tensor: the tensor to sync and reduce + *args: ignored for DP + **kwargs: ignored for DP - return output + Return: + reduced value, except when the input was not a tensor the output remains is unchanged + """ + if isinstance(tensor, Result): + tensor.dp_reduce() + + else: + + def _reduce(tensor: torch.Tensor): + dtype_tensor = tensor.dtype + return tensor.float().mean().type(dtype_tensor) + + tensor = apply_to_collection(tensor, torch.Tensor, _reduce) + + return tensor @property def root_device(self): @@ -54,8 +71,8 @@ def barrier(self, *args, **kwargs): def broadcast(self, obj: object, src: int = 0) -> object: return obj - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - return should_stop + def reduce_boolean_decision(self, decision: bool) -> bool: + return decision def training_step(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 13585f8f368f4..27ae26d67b493 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp if _HOROVOD_AVAILABLE: import horovod.torch as hvd @@ -147,8 +147,13 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ hvd.join() return hvd.allreduce(output, op=reduce_op) - def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): - if group is not None: + def all_gather( + self, + result: Union[torch.Tensor], + group: Optional[Any] = group.WORLD, + sync_grads: bool = False + ) -> torch.Tensor: + if group is not None and group != group.WORLD: raise ValueError( "Horovod does not support allgather using a subcommunicator at this time. " "Unset `group`." diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index f3c825fe9cd7a..9809443aff3fb 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import io import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import List, Optional +from typing import Any, List, Optional import torch from torch.nn.parallel import DistributedDataParallel @@ -36,9 +35,10 @@ def __init__( ): super().__init__() self.parallel_devices = parallel_devices + self.cluster_environment = cluster_environment + self.global_rank = 0 self.world_size = 1 self.local_rank = 0 - self.cluster_environment = cluster_environment @property def cluster_local_rank(self): @@ -77,11 +77,15 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) return distributed_sampler_kwargs - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - should_stop = self.reduce(should_stop, reduce_op=ReduceOp.SUM) - should_stop = bool(should_stop == self.world_size) - return should_stop + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) + + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op=ReduceOp.SUM) + decision = bool(decision == self.world_size) + return decision @property def torch_distributed_backend(self): @@ -119,13 +123,3 @@ def block_backward_sync(self): yield None else: yield None - - def broadcast(self, obj: object, src: int) -> object: - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - data_tensor = torch.tensor(data).to(self.root_device, dtype=torch.float) - data = all_gather_ddp_if_available(data_tensor) - buffer = io.BytesIO(data.cpu().byte().numpy()) - obj = torch.load(buffer) - return obj diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 4b1d24301b8a0..39fe06e1d46f2 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -1,4 +1,17 @@ -from typing import Any, Union +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Optional, Union import torch @@ -10,6 +23,9 @@ class SingleDevicePlugin(TrainingTypePlugin): def __init__(self, device: torch.device): super().__init__() self.device: torch.device = device + self.global_rank = 0 + self.local_rank = 0 + self.world_size = 1 @property def on_tpu(self) -> bool: @@ -20,8 +36,24 @@ def on_gpu(self) -> bool: return self.device.type == "cuda" and torch.cuda.is_available() def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: + """ + Reduces output from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + output: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ return output + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + return tensor + @property def root_device(self) -> torch.device: return self.device diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 371649057909b..1e951329b22cc 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -188,12 +188,11 @@ def save_spawn_weights(self, model: LightningModule) -> Optional[str]: model.trainer.save_checkpoint(path) return path - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device) - stop = xm.mesh_reduce('stop_signal', should_stop, sum) - rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check") - should_stop = int(stop.item()) == self.world_size - return should_stop + def reduce_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.device) + decision = self.reduce(decision, "sum") + decision = bool(decision == self.world_size) + return decision def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d7c3b4d4d77e1..b3a6c36bfbbf6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -33,7 +33,6 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None - self.global_rank = 0 @property @abstractmethod @@ -66,9 +65,13 @@ def barrier(self, name: Optional[str] = None) -> None: def broadcast(self, obj: object, src: int = 0) -> object: """Broadcasts an object to all processes""" - def reduce_early_stopping_decision(self, should_stop: bool) -> bool: - """Reduce the early stopping decision across all possibly spawned processes""" - return should_stop + @abstractmethod + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """Perform a all_gather on all processes """ + + def reduce_boolean_decision(self, decision: bool) -> bool: + """Reduce the early stopping decision across all processes""" + return decision def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, 
optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 9cb22f39b7228..a5c3a8d04a1dd 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -30,7 +30,9 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if not self.trainer.testing: + if self.trainer.predicting: + self.__verify_predict_loop_configuration(model) + elif not self.trainer.testing: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'validation') else: @@ -98,3 +100,9 @@ def __verify_eval_loop_configuration(self, model, eval_loop_name): rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop') if has_step and not has_loader: rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop') + + def __verify_predict_loop_configuration(self, model): + + has_predict_dataloader = is_overridden('predict_dataloader', model) + if not has_predict_dataloader: + raise MisconfigurationException('Dataloader not found for `Trainer.predict`') diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 59d406b0479c6..e5c17614474ee 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -266,6 +266,10 @@ def use_deepspeed(self) -> bool: @property def is_distributed(self) -> bool: + # Used for custom plugins. + # Custom plugins should implement is_distributed property. + if hasattr(self.training_type_plugin, 'is_distributed') and not self.on_tpu: + return self.training_type_plugin.is_distributed is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod if self.on_tpu: is_distributed |= self.training_type_plugin.is_distributed diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 6ff35aadc36a3..6bd6b7bec5068 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -82,6 +82,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): # set up the passed in dataloaders (if needed) self.attach_dataloaders(model, train_dataloader, val_dataloaders) self.attach_datamodule(model, datamodule, 'fit') + self._validate_data_hooks(model) def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): # If you supply a datamodule you can't supply train_dataloader or val_dataloaders @@ -90,6 +91,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' ) + def _validate_data_hooks(self, model): + # Raise Misconfiguration exception since these hooks are not supported in DP mode + # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. 
+ batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model): + raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.') + def attach_dataloaders( self, model, @@ -122,22 +131,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st if datamodule: # Override loader hooks - if is_overridden('train_dataloader', datamodule): - model.train_dataloader = datamodule.train_dataloader - if is_overridden('val_dataloader', datamodule): - model.val_dataloader = datamodule.val_dataloader - if is_overridden('test_dataloader', datamodule): - model.test_dataloader = datamodule.test_dataloader - if is_overridden('predict_dataloader', datamodule): - model.predict_dataloader = datamodule.predict_dataloader + dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader') + for method in dl_methods: + if is_overridden(method, datamodule): + setattr(model, method, getattr(datamodule, method)) # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule - if is_overridden('on_before_batch_transfer', datamodule): - model.on_before_batch_transfer = datamodule.on_before_batch_transfer - if is_overridden('transfer_batch_to_device', datamodule): - model.transfer_batch_to_device = datamodule.transfer_batch_to_device - if is_overridden('on_after_batch_transfer', datamodule): - model.on_after_batch_transfer = datamodule.on_after_batch_transfer + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if is_overridden(hook, datamodule): + setattr(model, hook, getattr(datamodule, hook)) self.trainer.datamodule = datamodule datamodule.trainer = self.trainer diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 2e788c256af0d..f4209f40d002e 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -18,27 +18,24 @@ from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables -def overwrite_by_env_vars(fn: Callable) -> Callable: +def _defaults_from_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. 
- """ - @wraps(fn) - def overwrite_by_env_vars(self, *args, **kwargs): - # get the class - cls = self.__class__ + def insert_env_defaults(self, *args, **kwargs): + cls = self.__class__ # get the class if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] # convert args to kwargs kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + env_variables = vars(parse_env_variables(cls)) # update the kwargs by env variables - # todo: maybe add a warning that some init args were overwritten by Env arguments - kwargs.update(vars(parse_env_variables(cls))) + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) # all args were already moved to kwargs return fn(self, **kwargs) - return overwrite_by_env_vars + return insert_env_defaults diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a547144c8a6f3..b40d87c2d9664 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -16,6 +16,7 @@ import torch +import pytorch_lightning as pl from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DistributedType, LightningEnum @@ -50,7 +51,7 @@ class HookResultStore: Those data structures enables us to reduce properly Result object when batch loop is finished. """ - def __init__(self, fx_name): + def __init__(self, fx_name: str) -> None: self._fx_name = fx_name self._internals = {} self._internals_reduced = {} @@ -104,6 +105,7 @@ def get_batch_log_metrics(self, *args, **kwargs): def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: if not isinstance(opt_metric, Result): raise Exception("The provided opt_metric should be a Result Object. 
Something is wrong") + func = getattr(opt_metric, func_name) metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) results.append(metrics_to_log) @@ -222,7 +224,7 @@ class EpochResultStore: ``` """ - def __init__(self, trainer, stage): + def __init__(self, trainer: 'pl.Trainer', stage): self.trainer = trainer self._stage = stage self.reset() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e123c1af5a5d0..cffb1914c69f9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -37,7 +37,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars +from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -83,7 +83,7 @@ class Trainer( DeprecatedTrainerAttributes, ): - @overwrite_by_env_vars + @_defaults_from_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, @@ -407,21 +407,6 @@ def __init__( # Callback system self.on_init_end() - def setup_trainer(self, model: LightningModule): - """ - Sanity check a few things before starting actual training or testing. - - Args: - model: The model to run sanity test on. - """ - - # log hyper-parameters - if self.logger is not None: - # save exp to get started (this is where the first experiment logs are written) - self.logger.log_hyperparams(model.hparams_initial) - self.logger.log_graph(model) - self.logger.save() - def fit( self, model: LightningModule, @@ -471,8 +456,7 @@ def fit( # ---------------------------- self.call_setup_hook(model) self.call_hook("on_before_accelerator_backend_setup", model) - self.accelerator.setup(self, model) - self.setup_trainer(model) + self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- # INSPECT THE CORE LOOPS @@ -539,6 +523,13 @@ def fit( def pre_dispatch(self): self.accelerator.pre_dispatch() + # log hyper-parameters + if self.logger is not None: + # save exp to get started (this is where the first experiment logs are written) + self.logger.log_hyperparams(self.lightning_module.hparams_initial) + self.logger.log_graph(self.lightning_module) + self.logger.save() + def post_dispatch(self): self.accelerator.post_dispatch() self.accelerator.teardown() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index de3eabe606a59..c3afe14285d9f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -108,7 +108,7 @@ def on_train_start(self): # provide rank to profiler self.trainer.profile_connector.on_train_start(self.trainer) - def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None): # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) diff --git a/pytorch_lightning/tuner/lr_finder.py 
b/pytorch_lightning/tuner/lr_finder.py index cf29799a05a5b..0975fdcbb6a79 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -412,11 +412,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data self.progress_bar.update() current_loss = trainer.train_loop.running_loss.last().item() - current_step = trainer.global_step + 1 # remove the +1 in 1.0 + current_step = trainer.global_step # Avg loss (loss with momentum) + smoothing self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss - smoothed_loss = self.avg_loss / (1 - self.beta**current_step) + smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1)) # Check if we diverging if self.early_stop_threshold is not None: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 06475547b03f2..c5256c6ddc65f 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -32,13 +32,20 @@ def on_trainer_init(self, auto_lr_find, auto_scale_batch_size): self.trainer.auto_lr_find = auto_lr_find self.trainer.auto_scale_batch_size = auto_scale_batch_size - def tune(self, model, train_dataloader, val_dataloaders, datamodule): + def setup_trainer( + self, + model: LightningModule, + train_dataloader: Optional[DataLoader] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + datamodule: LightningDataModule = None, + ): + self.trainer.model_connector.copy_trainer_model_properties(model) # setup data, etc... self.trainer.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook self.trainer.data_connector.prepare_data(model) + def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run auto batch size scaling if self.trainer.auto_scale_batch_size: if isinstance(self.trainer.auto_scale_batch_size, bool): @@ -101,6 +108,7 @@ def scale_batch_size( or datamodule. """ + self.setup_trainer(model, **fit_kwargs) return scale_batch_size( self.trainer, model, @@ -125,6 +133,7 @@ def lr_find( datamodule: Optional[LightningDataModule] = None, update_attr: bool = False, ): + self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule) return lr_find( self.trainer, model, diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 62626d1b5bcc8..2533dbc425948 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -108,7 +108,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. + r"""Scans the class signature and returns argument names, types and default values. 
Returns: List with tuples of 3 values: @@ -120,11 +120,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: >>> args = get_init_arguments_and_types(Trainer) """ - trainer_default_params = inspect.signature(cls).parameters + cls_default_params = inspect.signature(cls).parameters name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default + for arg in cls_default_params: + arg_type = cls_default_params[arg].annotation + arg_default = cls_default_params[arg].default try: arg_types = tuple(arg_type.__args__) except AttributeError: diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index f283497e5e5a1..61f581a5b5571 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -23,6 +23,7 @@ if torch.distributed.is_available(): from torch.distributed import group, ReduceOp + else: class ReduceOp: diff --git a/requirements/extra.txt b/requirements/extra.txt index 0e7dffbcb39b0..a05c4971ac450 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -3,8 +3,8 @@ matplotlib>3.1 horovod>=0.21.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.1 -torchtext>=0.5, <0.7 # TODO: temporary fix fix for compatibility -onnx>=1.7.0 +torchtext>=0.5 +# onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip diff --git a/requirements/test.txt b/requirements/test.txt index 2d47143ca58d4..84ddb2f981b54 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,11 +1,11 @@ -coverage>=5.0 +coverage>=5.2 codecov>=2.1 -pytest>=5.0 -# pytest-cov +pytest>=6.0 +pytest-cov>2.10 +# pytest-xdist flake8>=3.6 check-manifest twine==3.2 -# scipy>=0.13.3 scikit-learn>=0.22.2 scikit-image>=0.17.2 isort>=5.6.4 diff --git a/setup.cfg b/setup.cfg index fc64e5d948ffd..4c478dccb709e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,15 +39,10 @@ exclude_lines = pass rank_zero_warn raise NotImplementedError - # TODO: figure out how to get codecov to pick up the test results on these backends # The actual coverage for each is 90%+ # *metrics (94%+) are temporarily removed from testing while tests speed up omit = - pytorch_lightning/accelerators/ddp_*.py - pytorch_lightning/accelerators/ddp2_*.py - pytorch_lightning/accelerators/dp_*.py - pytorch_lightning/accelerators/tpu_*.py pytorch_lightning/cluster_environments/*.py pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index 6e826719b5b98..15faf98d94d57 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -13,15 +13,45 @@ # limitations under the License. 
import pytest import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader import pytorch_lightning as pl import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core import memory +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.datamodules import ClassifDataModule +from tests.helpers.simple_models import ClassificationModel from tests.base import EvalModelTemplate -PRETEND_N_OF_GPUS = 16 + +class CustomClassificationModelDP(ClassificationModel): + + def _step(self, batch, batch_idx): + x, y = batch + logits = self(x) + return {'logits': logits, 'y': y} + + def training_step(self, batch, batch_idx): + out = self._step(batch, batch_idx) + loss = F.cross_entropy(out['logits'], out['y']) + return loss + + def validation_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def test_step(self, batch, batch_idx): + return self._step(batch, batch_idx) + + def validation_step_end(self, outputs): + self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y'])) + + def test_step_end(self, outputs): + self.log('test_acc', self.test_acc(outputs['logits'], outputs['y'])) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -29,9 +59,12 @@ def test_multi_gpu_early_stop_dp(tmpdir): """Make sure DDP works. with early stopping""" tutils.set_random_master_port() + dm = ClassifDataModule() + model = CustomClassificationModelDP() + trainer_options = dict( default_root_dir=tmpdir, - callbacks=[EarlyStopping()], + callbacks=[EarlyStopping(monitor='val_acc')], max_epochs=50, limit_train_batches=10, limit_val_batches=10, @@ -39,8 +72,7 @@ def test_multi_gpu_early_stop_dp(tmpdir): accelerator='dp', ) - model = EvalModelTemplate() - tpipes.run_model_test(trainer_options, model) + tpipes.run_model_test(trainer_options, model, dm) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -57,7 +89,7 @@ def test_multi_gpu_model_dp(tmpdir): progress_bar_refresh_rate=0, ) - model = EvalModelTemplate() + model = BoringModel() tpipes.run_model_test(trainer_options, model) @@ -65,6 +97,114 @@ def test_multi_gpu_model_dp(tmpdir): memory.get_memory_profile('min_max') +class ReductionTestModel(BoringModel): + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def add_outputs(self, output, device): + output.update({ + "reduce_int": torch.tensor(device.index, dtype=torch.int, device=device), + "reduce_float": torch.tensor(device.index, dtype=torch.float, device=device), + }) + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def validation_step(self, batch, batch_idx): + output = super().validation_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def test_step(self, batch, batch_idx): + output = super().test_step(batch, batch_idx) + self.add_outputs(output, batch.device) + return output + + def training_epoch_end(self, outputs): + assert outputs[0]["loss"].shape == torch.Size([]) + assert 
outputs[0]["reduce_int"].item() == 0 # mean([0, 1]) = 0 + assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5 + + +def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): + """ + Test that an exception is raised when overriding batch_transfer_hooks in DP model. + """ + monkeypatch.setattr("torch.cuda.device_count", lambda: 2) + + class CustomModel(BoringModel): + + def transfer_batch_to_device(self, batch, device): + batch = batch.to(device) + return batch + + trainer_options = dict( + default_root_dir=tmpdir, + max_steps=7, + gpus=[0, 1], + accelerator='dp', + ) + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_before_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'): + trainer.fit(model) + + class CustomModel(BoringModel): + + def on_after_batch_transfer(self, batch, dataloader_idx): + batch += 1 + return batch + + trainer = Trainer(**trainer_options) + model = CustomModel() + + with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'): + trainer.fit(model) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_dp_training_step_dict(tmpdir): + """ This test verifies that dp properly reduces dictionaries """ + model = ReductionTestModel() + model.training_step_end = None + model.validation_step_end = None + model.test_step_end = None + + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + gpus=2, + accelerator='dp', + ) + trainer.fit(model) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_dp_test(tmpdir): tutils.set_random_master_port() diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 6ce1938d3990f..9bfd378aedc16 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -98,3 +98,45 @@ def training_step(self, batch, batch_idx): # make sure types are correct assert save_mock.call_count == expected + + +@mock.patch('torch.save') +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +@pytest.mark.parametrize(['k', 'epochs', 'val_check_interval', 'expected'], [(1, 1, 1.0, 1), (2, 2, 0.3, 5)]) +def test_top_k_ddp(save_mock, tmpdir, k, epochs, val_check_interval, expected): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + local_rank = int(os.getenv("LOCAL_RANK")) + self.log('my_loss', batch_idx * (1 + local_rank), on_epoch=True) + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + data = str(self.global_rank) + obj = [[data], (data, ), set(data)] + out = self.trainer.training_type_plugin.broadcast(obj) + assert obj == [[str(self.global_rank)], (str(self.global_rank), ), set(str(self.global_rank))] + assert out 
== [['0'], ('0', ), set('0')] + + model = TestModel() + trainer = Trainer( + callbacks=[callbacks.ModelCheckpoint(dirpath=tmpdir, monitor='my_loss_step', save_top_k=k, mode="max")], + default_root_dir=tmpdir, + max_epochs=epochs, + weights_summary=None, + val_check_interval=val_check_interval, + accelerator="ddp", + gpus=2, + limit_train_batches=64, + limit_val_batches=32, + ) + if os.getenv("LOCAL_RANK") == "0": + with pytest.raises(UserWarning, match="The value associated to the key my_loss_epoch: [15.5, 31.0]"): + trainer.fit(model) + assert save_mock.call_count == expected + else: + trainer.fit(model) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 06a114ca15eb9..29eaebc031e3c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -347,7 +347,7 @@ def on_train_start(self, trainer, pl_module): torch.save = Mock(wraps=torch.save) def on_save_checkpoint(self, trainer, pl_module, checkpoint): - # expect all ranks to run but only rank 0 will actually write the checkpoint file + # only rank 0 will call ``torch.save`` super().on_save_checkpoint(trainer, pl_module, checkpoint) self.on_save_checkpoint_count += 1 @@ -357,8 +357,7 @@ def on_train_end(self, trainer, pl_module): assert self.best_model_score assert self.on_save_checkpoint_count == self.expected_count if trainer.is_global_zero: - # twice the calls expected because ddp broadcast also uses torch.save - assert torch.save.call_count == self.expected_count * 2 + assert torch.save.call_count == self.expected_count else: assert torch.save.call_count == 0 diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py index 1db6981064c6c..b8cf33ab6afcf 100644 --- a/tests/core/test_memory.py +++ b/tests/core/test_memory.py @@ -88,6 +88,19 @@ def forward(self, x): return self.reduce(self.embed(x)) +class PartialScriptModel(LightningModule): + """ A model which contains scripted layers. """ + + def __init__(self): + super().__init__() + self.layer1 = torch.jit.script(nn.Linear(5, 3)) + self.layer2 = nn.Linear(3, 2) + self.example_input_array = torch.rand(2, 5) + + def forward(self, x): + return self.layer2(self.layer1(x)) + + def test_invalid_weights_summmary(): """ Test that invalid value for weights_summary raises an error. """ with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'): @@ -97,11 +110,8 @@ def test_invalid_weights_summmary(): Trainer(weights_summary='temp') -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) -def test_empty_model_summary_shapes(mode): +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_empty_model_summary_shapes(mode: ModelSummary): """ Test that the summary works for models that have no submodules. 
""" model = EmptyModule() summary = model.summarize(mode=mode) @@ -110,10 +120,7 @@ def test_empty_model_summary_shapes(mode): assert summary.param_nums == [] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) @pytest.mark.parametrize(['device'], [ pytest.param(torch.device('cpu')), pytest.param(torch.device('cuda', 0)), @@ -157,10 +164,7 @@ def test_mixed_dtype_model_summary(): ] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_hooks_removed_after_summarize(mode): """ Test that all hooks were properly removed after summary, even ones that were not run. """ model = UnorderedModel() @@ -171,10 +175,7 @@ def test_hooks_removed_after_summarize(mode): assert handle.id not in handle.hooks_dict_ref() -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_rnn_summary_shapes(mode): """ Test that the model summary works for RNNs. """ model = ParityModuleRNN() @@ -198,10 +199,7 @@ def test_rnn_summary_shapes(mode): ] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_summary_parameter_count(mode): """ Test that the summary counts the number of parameters in every submodule. """ model = UnorderedModel() @@ -215,10 +213,7 @@ def test_summary_parameter_count(mode): ] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_summary_layer_types(mode): """ Test that the summary displays the layer names correctly. """ model = UnorderedModel() @@ -232,10 +227,16 @@ def test_summary_layer_types(mode): ] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +def test_summary_with_scripted_modules(mode): + model = PartialScriptModel() + summary = model.summarize(mode=mode) + assert summary.layer_types == ["RecursiveScriptModule", "Linear"] + assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]] + assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] + + +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) @pytest.mark.parametrize(['example_input', 'expected_size'], [ pytest.param([], UNKNOWN_SIZE), pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3), @@ -269,10 +270,7 @@ def forward(self, *args, **kwargs): assert summary.in_sizes == [expected_size] -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_model_size(mode): """ Test model size is calculated correctly. 
""" model = PreCalculatedModel() @@ -280,10 +278,7 @@ def test_model_size(mode): assert model.pre_calculated_model_size == summary.model_size -@pytest.mark.parametrize(['mode'], [ - pytest.param(ModelSummary.MODE_FULL), - pytest.param(ModelSummary.MODE_TOP), -]) +@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) def test_empty_model_size(mode): """ Test empty model size is zero. """ model = EmptyModule() @@ -293,15 +288,9 @@ def test_empty_model_size(mode): @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.") -@pytest.mark.parametrize( - 'precision', [ - pytest.param(16, marks=pytest.mark.skip(reason="no longer valid, because 16 can mean mixed precision")), - pytest.param(32), - ] -) -def test_model_size_precision(monkeypatch, tmpdir, precision): +def test_model_size_precision(tmpdir): """ Test model size for half and full precision. """ - model = PreCalculatedModel(precision) + model = PreCalculatedModel() # fit model trainer = Trainer( @@ -309,7 +298,7 @@ def test_model_size_precision(monkeypatch, tmpdir, precision): gpus=1, max_steps=1, max_epochs=1, - precision=precision, + precision=32, ) trainer.fit(model) summary = model.summarize() diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 5af3fbfbc4a11..e7bdad0f1538c 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -69,6 +69,7 @@ def __init__( train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, + **kwargs, ): super().__init__() self.root = root diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 93a637dda1071..1ef55e729912b 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,6 +21,8 @@ import os import sys +import torch + # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do PYTHONPATH = os.getenv('PYTHONPATH', '') if ':' in PYTHONPATH: @@ -53,8 +55,13 @@ def run_test_from_config(trainer_options): ckpt_path = trainer_options['weights_save_path'] trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)]) - model = BoringModel() + class TestModel(BoringModel): + + def training_epoch_end(self, outputs) -> None: + res = self.trainer.training_type_plugin.reduce(torch.tensor(1., device=self.device), reduce_op="sum") + assert res.sum() == self.trainer.training_type_plugin.world_size + model = TestModel() trainer = Trainer(**trainer_options) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 0b89c3b06c041..8a1260251eb14 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -17,11 +17,13 @@ import shlex import subprocess import sys +from unittest.mock import patch import numpy as np import pytest import torch from sklearn.metrics import accuracy_score +from torch import optim import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils @@ -55,6 +57,9 @@ def _run_horovod(trainer_options, on_gpu=False): # for Horovod, we interpret `gpus` to be set per worker trainer_options.update(gpus=1 if on_gpu else None) tutils.reset_seed() + # todo: Find why coverage breaks CI. 
+ # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 + # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ 'horovodrun', '-np', str(num_processes), sys.executable, TEST_SCRIPT, '--trainer-options', @@ -119,6 +124,8 @@ def test_horovod_multi_gpu(tmpdir): _run_horovod(trainer_options, on_gpu=True) +# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 +# Check with (tgaddair) on Horovod issues if this feature is needed @pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @@ -167,6 +174,27 @@ def test_horovod_amp(tmpdir): _run_horovod(trainer_options, on_gpu=True) +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp") +def test_horovod_gather(tmpdir): + """Test Horovod with multi-GPU support using native amp.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") @pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @@ -198,6 +226,7 @@ def validation_step(self, batch, *args, **kwargs): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") def test_horovod_multi_optimizer(tmpdir): model = BasicGAN() @@ -230,7 +259,7 @@ def get_optimizer_params(optimizer): # TODO: unclear Horovod failure... -@pytest.mark.skip(reason="unclear Horovod failure...") +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") @pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_result_reduce_horovod(tmpdir): @@ -273,6 +302,7 @@ def training_epoch_end(self, outputs) -> None: max_epochs=1, log_every_n_steps=1, weights_summary=None, + logger=False ) trainer.fit(model) @@ -281,7 +311,7 @@ def training_epoch_end(self, outputs) -> None: # TODO: unclear Horovod failure... 
-@pytest.mark.skip(reason="unclear Horovod failure...") +@pytest.mark.skipif(reason="CI agent.jobstatus=Succeeded: Permission denied") @pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_accuracy_metric_horovod(): @@ -298,10 +328,7 @@ def sk_metric(preds, target): target = torch.randint(high=2, size=(num_batches, batch_size)) def _compute_batch(): - trainer = Trainer( - fast_dev_run=True, - accelerator='horovod', - ) + trainer = Trainer(fast_dev_run=True, accelerator='horovod', logger=False) assert isinstance(trainer.accelerator, CPUAccelerator) # TODO: test that we selected the correct training_type_plugin based on horovod flags @@ -309,7 +336,7 @@ def _compute_batch(): metric = Accuracy( compute_on_step=True, dist_sync_on_step=True, - dist_sync_fn=trainer.training_type_plugin.gather_all_tensors, + dist_sync_fn=trainer.training_type_plugin.all_gather, threshold=threshold ) @@ -334,33 +361,46 @@ def _compute_batch(): horovod.run(_compute_batch, np=2) -# @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): -# model = BoringModel() -# model.configure_optimizers = model.configure_optimizers__multiple_schedulers -# -# num_workers = 8 -# init_lr = hparams.get('learning_rate') * num_workers -# -# with patch('pytorch_lightning.accelerators.legacy.horovod_backend.hvd.size') as mock_hvd_size: -# mock_hvd_size.return_value = 8 -# -# # fit model -# trainer = Trainer( -# default_root_dir=tmpdir, -# max_epochs=1, -# limit_val_batches=0.5, -# limit_train_batches=0.2, -# distributed_backend='horovod' -# ) -# results = trainer.fit(model) -# assert results == 1 -# -# adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] -# adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] -# -# # Called ones after end of epoch with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr1 -# -# # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 -# assert pytest.approx(init_lr * 0.1) == adjusted_lr2 +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable") +def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.1) + optimizer2 = optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] + + model = TestModel() + model.training_epoch_end = None + + num_workers = 8 + init_lr = 0.1 * num_workers + + with patch('horovod.torch.size', return_value=8): + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.5, + limit_train_batches=0.2, + accelerator='horovod' + ) + results = trainer.fit(model) + assert results == 1 + + adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0] + adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0] + + # Called ones after end of epoch with 
gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr1 + + # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times with gamma=0.1 + assert pytest.approx(init_lr * 0.1) == adjusted_lr2 diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 9c9c5c097b4c5..8c4c7873681ad 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -34,6 +34,11 @@ def deepspeed_config(): } +@pytest.fixture +def deepspeed_zero_config(deepspeed_config): + return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}} + + @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") def test_deepspeed_plugin_string(tmpdir): """ @@ -165,9 +170,6 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_warn_deepspeed_override_backward(tmpdir): """ Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning. @@ -191,9 +193,6 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_deepspeed_run_configure_optimizers(tmpdir): """ Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), @@ -223,9 +222,6 @@ def on_train_start(self) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_deepspeed_config(tmpdir, deepspeed_config): """ Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers @@ -255,6 +251,58 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +def test_deepspeed_custom_precision_params(tmpdir): + """ + Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes. 
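+    The FP16 arguments accepted by ``DeepSpeedPlugin`` should map one-to-one onto the ``fp16`` section of the generated DeepSpeed config, which is what the assertions below check.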
+ """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['fp16']['loss_scale'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['initial_scale_power'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['loss_scale_window'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['hysteresis'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['min_loss_scale'] == 10 + raise SystemExit() + + model = TestModel() + trainer = Trainer( + plugins=[ + DeepSpeedPlugin( + loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10 + ) + ], + precision=16, + gpus=1 + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config): + """ + Ensure if we use a config and turn off cpu_offload, that this is set to False within the config. + """ + + deepspeed_zero_config['zero_optimization']['cpu_offload'] = False + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['zero_optimization']['cpu_offload'] is False + raise SystemExit() + + model = TestModel() + trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1) + with pytest.raises(SystemExit): + trainer.fit(model) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") diff --git a/tests/special_tests.sh b/tests/special_tests.sh index a2373d05a42ef..43658721e9226 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,9 +17,6 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual @@ -35,4 +32,5 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model +python ${DEFAULTS} tests/checkpointing/test_checkpoint_callback_frequency.py::test_top_k_ddp nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} 
tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index ba76820d15ee8..65b251a6633b5 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock from pytorch_lightning import Trainer -def test_passing_env_variables(tmpdir): +def test_passing_no_env_variables(): """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is not None @@ -25,17 +26,29 @@ def test_passing_env_variables(tmpdir): assert trainer.logger is None assert trainer.max_steps == 42 - os.environ['PL_TRAINER_LOGGER'] = 'False' - os.environ['PL_TRAINER_MAX_STEPS'] = '7' + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_only(): + """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is None assert trainer.max_steps == 7 - os.environ['PL_TRAINER_LOGGER'] = 'True' + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_defaults(): + """Testing overwriting trainer arguments """ trainer = Trainer(False, max_steps=42) - assert trainer.logger is not None - assert trainer.max_steps == 7 + assert trainer.logger is None + assert trainer.max_steps == 42 + - # this has to be cleaned - del os.environ['PL_TRAINER_LOGGER'] - del os.environ['PL_TRAINER_MAX_STEPS'] +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.gpus == 2 + trainer = Trainer(gpus=1) + assert trainer.gpus == 1 diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py index dffb511614cf6..de51de74cc355 100644 --- a/tests/trainer/logging_/test_distributed_logging.py +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os import platform from unittest import mock +from unittest.mock import Mock import pytest import torch -from pytorch_lightning import Trainer +from pytorch_lightning import Callback, Trainer from tests.helpers import BoringModel @@ -69,3 +70,39 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir): weights_summary=None, ) trainer.fit(model) + + +def test_first_logger_call_in_subprocess(tmpdir): + """ + Test that the Trainer does not call the logger too early. Only when the worker processes are initialized + do we have access to the rank and know which one is the main process. 
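+    The logger is replaced with a ``Mock`` so the test can assert that ``log_hyperparams`` and ``log_graph`` are called only once training has actually started.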
+ """ + + class LoggerCallsObserver(Callback): + + def on_fit_start(self, trainer, pl_module): + # this hook is executed directly before Trainer.pre_dispatch + # logger should not write any logs until this point + assert not trainer.logger.method_calls + assert not os.listdir(trainer.logger.save_dir) + + def on_train_start(self, trainer, pl_module): + assert trainer.logger.method_call + trainer.logger.log_hyperparams.assert_called_once() + trainer.logger.log_graph.assert_called_once() + + logger = Mock() + logger.version = "0" + logger.name = "name" + logger.save_dir = tmpdir + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + logger=logger, + callbacks=[LoggerCallsObserver()] + ) + trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 59f3c2b54c13c..6966edc3cbf70 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1850,3 +1850,35 @@ def compare_optimizers(): trainer.max_epochs = 2 # simulate multiple fit calls trainer.fit(model) compare_optimizers() + + +@pytest.mark.parametrize("use_datamodule", [False, True]) +def test_trainer_predict_verify_config(tmpdir, use_datamodule): + + class TestModel(LightningModule): + + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir) + + if use_datamodule: + datamodule = TestLightningDataModule(dataloaders) + results = trainer.predict(model, datamodule=datamodule) + else: + results = trainer.predict(model, dataloaders=dataloaders) + + assert len(results) == 2 + assert results[0][0].shape == torch.Size([1, 2]) + + model.predict_dataloader = None + + with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"): + trainer.predict(model) diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py new file mode 100644 index 0000000000000..ad7fc57092f32 --- /dev/null +++ b/tests/tuner/test_scale_batch_size.py @@ -0,0 +1,65 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from torch.utils.data import DataLoader + +from pytorch_lightning import Trainer +from pytorch_lightning.tuner.tuning import Tuner +from tests.helpers import BoringDataModule, BoringModel + + +class BatchSizeDataModule(BoringDataModule): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader(self.random_train, batch_size=getattr(self, "batch_size", 1)) + + +class BatchSizeModel(BoringModel): + + def __init__(self, batch_size=None): + super().__init__() + if batch_size is not None: + self.batch_size = batch_size + + +@pytest.mark.parametrize( + "model,datamodule", [ + (BatchSizeModel(2), None), + (BatchSizeModel(2), BatchSizeDataModule(2)), + (BatchSizeModel(2), BatchSizeDataModule(None)), + (BatchSizeModel(None), BatchSizeDataModule(2)), + ] +) +def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model, datamodule): + """ Test the tuner method `Tuner.scale_batch_size` with a datamodule. """ + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=1, + ) + tuner = Tuner(trainer) + new_batch_size = tuner.scale_batch_size( + model=model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule + ) + assert new_batch_size == 16 + if hasattr(model, "batch_size"): + assert model.batch_size == 16 + if datamodule is not None and hasattr(datamodule, "batch_size"): + assert datamodule.batch_size == 16
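A minimal usage sketch of the `Tuner.scale_batch_size` API exercised by the new test above (illustrative only, not part of the patch; `TunableModel` is a hypothetical stand-in for any LightningModule that exposes a `batch_size` attribute, as `BatchSizeModel` does):

from pytorch_lightning import Trainer
from pytorch_lightning.tuner.tuning import Tuner
from tests.helpers import BoringModel


class TunableModel(BoringModel):

    def __init__(self, batch_size=2):
        super().__init__()
        # the tuner reads and updates this attribute on the model (or on the datamodule)
        self.batch_size = batch_size


trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
tuner = Tuner(trainer)
# grow the batch size from init_val, stopping after max_trials successful trials
new_batch_size = tuner.scale_batch_size(TunableModel(), mode="binsearch", init_val=4, max_trials=2)

With `init_val=4` and `max_trials=2` the returned value is 16, which is the value the parametrized test asserts on both the model and the datamodule.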