diff --git a/CHANGELOG.md b/CHANGELOG.md index 57105e252dfb0..d3baec790195d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed distributed setting and `ddp_cpu` only with `num_processes>1` ([#5297](https://github.com/PyTorchLightning/pytorch-lightning/pull/5297)) + ## [1.1.0] - 2020-12-09 diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index bd7f335a03720..0f58cb882bcf9 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -1,7 +1,7 @@ import os import platform import time -from typing import Union +from typing import Type, Union import pytest import torch @@ -14,35 +14,21 @@ from tests.base.boring_model import BoringModel, RandomDataset -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_one_device(): - plugin_parity_test( - accelerator='ddp_cpu', - max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls - plugin=DDPShardedPlugin(), - model_cls=SeedTrainLoaderModel - ) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_one_gpu(): plugin_parity_test( gpus=1, accelerator='ddp_spawn', plugin=DDPShardedPlugin(), - model_cls=SeedTrainLoaderModel + model_cls=SeedTrainLoaderModel, ) @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_one_gpu(): plugin_parity_test( @@ -50,14 +36,13 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu(): precision=16, accelerator='ddp_spawn', plugin=DDPShardedPlugin(), - model_cls=SeedTrainLoaderModel + model_cls=SeedTrainLoaderModel, ) @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu(): plugin_parity_test( @@ -65,13 +50,12 @@ def test_ddp_sharded_plugin_correctness_multi_gpu(): accelerator='ddp_spawn', plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.25 + max_percent_speed_diff=0.25, ) @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") -@pytest.mark.skipif(platform.system() == 
"Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): @@ -81,13 +65,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): accelerator='ddp_spawn', plugin=DDPShardedPlugin(), model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.25 + max_percent_speed_diff=0.25, ) @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu(): @@ -97,7 +80,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu(): accelerator='ddp_spawn', plugin='ddp_sharded', model_cls=SeedTrainLoaderModel, - max_percent_speed_diff=0.25 + max_percent_speed_diff=0.25, ) @@ -133,8 +116,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): """ @@ -145,14 +127,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderMultipleOptimizersModel, - max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers + max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers ) @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): """ @@ -163,7 +144,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, - max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers + max_percent_speed_diff=0.25, # Increase speed diff since only 2 GPUs sharding 2 optimizers ) @@ -259,13 +240,14 @@ def record_ddp_fit_model_stats(trainer, model, use_cuda): def plugin_parity_test( - model_cls: SeedTrainLoaderModel, + model_cls: Type[SeedTrainLoaderModel], plugin: Union[str, DDPPlugin], seed: 
int = 42, accelerator: str = 'ddp_spawn', gpus: int = 0, precision: int = 32, - max_percent_speed_diff: float = 0.1): + max_percent_speed_diff: float = 0.1, +): """ Ensures that the trained model is identical to the standard DDP implementation. Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index ce2a418cf2fa5..7417f889dd808 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -15,7 +15,7 @@ import torch -from pytorch_lightning.utilities import _HOROVOD_AVAILABLE +from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, DeviceType, DistributedType from pytorch_lightning import _logger as log from pytorch_lightning import accelerators from pytorch_lightning.accelerators.accelerator import Accelerator @@ -81,10 +81,7 @@ def on_trainer_init( # sync-bn backend self.trainer.sync_batchnorm = sync_batchnorm - self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) - self.trainer.on_tpu = self.trainer.tpu_cores is not None - - self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None + self._parse_tpu_device_details(tpu_cores) if num_processes != 1 and distributed_backend != "ddp_cpu": rank_zero_warn("num_processes is only used for `accelerator='ddp_cpu'`. Ignoring it.") @@ -100,23 +97,10 @@ def on_trainer_init( self.trainer.data_parallel_device_ids = device_parser.parse_gpu_ids(self.trainer.gpus) self.trainer.root_gpu = device_parser.determine_root_gpu_device(self.trainer.data_parallel_device_ids) - self.trainer.root_device = torch.device("cpu") - - self.trainer.on_gpu = True if (self.trainer.data_parallel_device_ids and torch.cuda.is_available()) else False - - # tpu state flags - self.trainer.use_tpu = False - self.trainer.tpu_local_core_rank = None - self.trainer.tpu_global_core_rank = None # distributed backend choice self.set_distributed_mode() - # override dist backend when using tpus - if self.trainer.on_tpu: - self.trainer.distributed_backend = "tpu" - self.trainer.use_tpu = True - # init flags for SLURM+DDP to work self.trainer.world_size = 1 self.trainer.interactive_ddp_procs = [] @@ -135,10 +119,29 @@ def on_trainer_init( self.trainer.replace_sampler_ddp = replace_sampler_ddp + def _parse_tpu_device_details(self, tpu_cores): + self.trainer.tpu_cores = device_parser.parse_tpu_cores(tpu_cores) + if self.trainer.tpu_cores is not None: + if _TPU_AVAILABLE: + self.trainer._device_type = DeviceType.TPU + self.trainer.distributed_backend = "tpu" + else: + raise MisconfigurationException( + f"You have requested {self.trainer.tpu_cores} TPU cores but none is available." + ) + + self.trainer.tpu_id = self.trainer.tpu_cores[0] if isinstance(self.trainer.tpu_cores, list) else None + + # tpu state flags + self.trainer.tpu_local_core_rank = None + self.trainer.tpu_global_core_rank = None + def _map_deprecated_dist_backend(self, accelerator, distributed_backend): if distributed_backend is not None: - rank_zero_warn(DeprecationWarning('distributed_backend has been renamed to accelerator. ' - 'Deprecated in 1.0.0, will be removed in 1.2.0')) + rank_zero_warn( + '`distributed_backend` has been renamed to accelerator. 
Deprecated in 1.0.0, will be removed in 1.2.0', + DeprecationWarning + ) # temporary mapping until we remove all the distributed_backend references if accelerator is not None: @@ -276,71 +279,75 @@ def select_accelerator(self): accelerator_backend = accelerators.CPUAccelerator(self.trainer, cluster_env) else: raise MisconfigurationException( - f'Trainer(accelerator={self.trainer.distributed_backend} is not a supported backend' + f'`Trainer(accelerator={self.trainer.distributed_backend}, num_nodes={self.trainer.num_nodes},' + f' num_processes={self.trainer.num_processes}, ...)` is not a supported backend for' + f' num_gpus={self.trainer.num_gpus}' ) return accelerator_backend def set_distributed_mode(self): - self.trainer.use_dp = False - self.trainer.use_ddp = False - self.trainer.use_ddp2 = False - self.trainer.use_horovod = False - self.trainer.use_single_gpu = False if self.trainer.distributed_backend is None: if self.has_horovodrun(): self._set_horovod_backend() - elif self.trainer.num_gpus == 0: - if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1: - self.trainer.use_ddp = True # ddp_cpu - elif self.trainer.num_gpus == 1: - self.trainer.use_single_gpu = True + elif self.trainer.num_gpus == 0 and (self.trainer.num_nodes > 1 or self.trainer.num_processes > 1): + self.trainer._distrib_type = DistributedType.DDP elif self.trainer.num_gpus > 1: rank_zero_warn( 'You requested multiple GPUs but did not specify a backend, e.g.' - ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`.' - ' Setting `accelerator="ddp_spawn"` for you.' + ' `Trainer(accelerator="dp"|"ddp"|"ddp2")`. Setting `accelerator="ddp_spawn"` for you.' ) self.trainer.distributed_backend = "ddp_spawn" - if self.trainer.distributed_backend == "dp": - # do nothing if num_gpus == 0 - if self.trainer.num_gpus == 1: - self.trainer.use_single_gpu = True - self.trainer.use_dp = True - elif self.trainer.num_gpus > 1: - self.trainer.use_dp = True - - elif self.trainer.distributed_backend in ("ddp", "ddp_spawn"): - if self.trainer.num_gpus == 0: - if self.trainer.num_nodes > 1 or self.trainer.num_processes > 1: - self.trainer.use_ddp = True # ddp_cpu - elif self.trainer.num_gpus == 1: - self.trainer.use_single_gpu = True - self.trainer.use_ddp = True - elif self.trainer.num_gpus > 1: - self.trainer.use_ddp = True - self.trainer.num_processes = self.trainer.num_gpus - - elif self.trainer.distributed_backend == "ddp2": - # do nothing if num_gpus == 0 - if self.trainer.num_gpus >= 1: - self.trainer.use_ddp2 = True - elif self.trainer.distributed_backend == "ddp_cpu": + # special case with DDP on CPUs + if self.trainer.distributed_backend == "ddp_cpu": + self.trainer._distrib_type = DistributedType.DDP + self.trainer.data_parallel_device_ids = None if self.trainer.num_gpus > 0: rank_zero_warn( 'You requested one or more GPUs, but set the backend to `ddp_cpu`. Training will not use GPUs.' ) - self.trainer.use_ddp = True - self.trainer.data_parallel_device_ids = None - self.trainer.on_gpu = False - self.trainer.on_cpu = True - elif self.trainer.distributed_backend == "horovod": + if self.trainer.num_processes is None: + # define the max CPU available + self.trainer.num_processes = os.cpu_count() + # special case with TPUs + elif self.trainer.distributed_backend == 'tpu': + self.trainer._device_type = DeviceType.TPU + # set all other requested distrib. 
types if it was not already set + elif self.trainer.distributed_backend and self.trainer._distrib_type is None: + self.trainer._distrib_type = DistributedType(self.trainer.distributed_backend) + + # unless you explicitly request CPU, use GPUs whenever they are available + _on_cpu = self.trainer.distributed_backend and 'cpu' in self.trainer.distributed_backend + if (self.trainer.num_gpus > 0 and not _on_cpu): + self.trainer._device_type = DeviceType.GPU + + _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2) + # DP and DDP2 cannot run without GPU + if (self.trainer.num_gpus == 0 and self.trainer._distrib_type in _distrib_types): + rank_zero_warn( + 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.' + ) + # todo: in some cases this yields a comparison between None and int + if ((self.trainer.num_nodes and self.trainer.num_nodes > 1) + or (self.trainer.num_processes and self.trainer.num_processes > 1)): + self.trainer._distrib_type = DistributedType.DDP + else: + rank_zero_warn('You are running on a single node with no parallelization, so distributed has no effect.') + self.trainer._distrib_type = None + + # for DDP, overwrite the number of processes with the number of requested GPUs + if (self.trainer._device_type == DeviceType.GPU + and self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)): + self.trainer.num_processes = self.trainer.num_gpus + + # Horovod is an extra case... + if self.trainer.distributed_backend == "horovod": + self._set_horovod_backend() # throw error to force user ddp or ddp2 choice - if self.trainer.num_nodes > 1 and not (self.trainer.use_ddp2 or self.trainer.use_ddp): + if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP): raise MisconfigurationException( 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. ' 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`' @@ -350,12 +357,12 @@ def set_distributed_mode(self): num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0 rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores') - if torch.cuda.is_available() and not self.trainer.on_gpu: + if torch.cuda.is_available() and self.trainer._device_type != DeviceType.GPU: rank_zero_warn('GPU available but not used. 
Set the --gpus flag when calling the script.') def _set_horovod_backend(self): - self.check_horovod() - self.trainer.use_horovod = True + self._check_horovod() + self.trainer._distrib_type = DistributedType.HOROVOD # Initialize Horovod to get rank / size info hvd.init() @@ -363,7 +370,7 @@ def _set_horovod_backend(self): # Horovod assigns one local GPU per process self.trainer.root_gpu = hvd.local_rank() - def check_horovod(self): + def _check_horovod(self): """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod.""" if not _HOROVOD_AVAILABLE: raise MisconfigurationException( diff --git a/pytorch_lightning/plugins/plugin_connector.py b/pytorch_lightning/plugins/plugin_connector.py index d66c25173cc77..ccd128d87a26a 100644 --- a/pytorch_lightning/plugins/plugin_connector.py +++ b/pytorch_lightning/plugins/plugin_connector.py @@ -31,8 +31,6 @@ def __init__(self, trainer): self.plugins = [] self.ddp_plugin = DDPPlugin() self.cloud_environment = None - self.amp_plugin = NativeAMPPlugin(trainer) - self.apex_plugin = ApexPlugin(trainer) def on_trainer_init(self, plugins: Optional[Union[str, list]]): self.plugins = plugins diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 7b4de47a1be2c..2c8377d2936c9 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities import DistributedType, DeviceType +from pytorch_lightning.utilities import DistributedType, DeviceType, rank_zero_warn class DeprecatedDistDeviceAttributes: @@ -28,7 +28,7 @@ def on_cpu(self) -> bool: @on_cpu.setter def on_cpu(self, val: bool) -> None: - # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `on_cpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._device_type = DeviceType.CPU @@ -39,8 +39,7 @@ def on_tpu(self) -> bool: @on_tpu.setter def on_tpu(self, val: bool) -> None: - # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) - # todo add logic that it cannot be set if TPU is missing + rank_zero_warn("Internal: `on_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._device_type = DeviceType.TPU @@ -51,7 +50,7 @@ def use_tpu(self) -> bool: @use_tpu.setter def use_tpu(self, val: bool) -> None: - # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) self.on_tpu = val @property @@ -61,8 +60,7 @@ def on_gpu(self) -> bool: @on_gpu.setter def on_gpu(self, val: bool) -> None: - # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) - # todo add logic that it cannot be set if GPU is missing + rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._device_type = DeviceType.GPU @@ -73,7 +71,7 @@ def use_dp(self) -> bool: @use_dp.setter def use_dp(self, val: bool) -> None: - # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_dp` is deprecated in v1.2 and 
will be removed in v1.4.", DeprecationWarning) if val: self._distrib_type = DistributedType.DP @@ -84,7 +82,7 @@ def use_ddp(self) -> bool: @use_ddp.setter def use_ddp(self, val: bool) -> None: - # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_ddp` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._distrib_type = DistributedType.DDP @@ -95,7 +93,7 @@ def use_ddp2(self) -> bool: @use_ddp2.setter def use_ddp2(self, val: bool) -> None: - # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning) + rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._distrib_type = DistributedType.DDP2 @@ -108,9 +106,9 @@ def use_horovod(self) -> bool: @use_horovod.setter def use_horovod(self, val: bool) -> None: - # rank_zero_warn( - # "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning - # ) + rank_zero_warn( + "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning + ) if val: self._distrib_type = DistributedType.HOROVOD @@ -126,8 +124,8 @@ def use_single_gpu(self) -> bool: @use_single_gpu.setter def use_single_gpu(self, val: bool) -> None: - # rank_zero_warn( - # "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning, - # ) + rank_zero_warn( + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning, + ) if val: self._device_type = DeviceType.GPU diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index b9b4263d0cf50..dc8bf338d3eb3 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -50,6 +50,7 @@ def on_fit_start(self, trainer, pl_module): trainer = Trainer( fast_dev_run=True, accelerator='ddp_cpu', + num_processes=2, callbacks=[CB()], ) @@ -242,7 +243,7 @@ def on_fit_start(self, trainer, pl_module): trainer = Trainer( fast_dev_run=True, accelerator='ddp_cpu', - num_processes=1, + num_processes=2, callbacks=[CB()], ) @@ -251,7 +252,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch.dict(os.environ, { - "SLURM_NTASKS": "1", + "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", @@ -270,7 +271,7 @@ def on_fit_start(self, trainer, pl_module): trainer = Trainer( fast_dev_run=True, accelerator='ddp_cpu', - num_processes=1, + num_processes=2, callbacks=[CB()], ) @@ -279,7 +280,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch.dict(os.environ, { - "SLURM_NTASKS": "1", + "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", @@ -307,7 +308,7 @@ def on_fit_start(self, trainer, pl_module): plugins=[CustomCluster()], fast_dev_run=True, accelerator='ddp_cpu', - num_processes=1, + num_processes=2, callbacks=[CB()], ) @@ -316,7 +317,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch.dict(os.environ, { - "SLURM_NTASKS": "1", + "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", @@ -341,7 +342,7 @@ def on_fit_start(self, trainer, pl_module): trainer = Trainer( fast_dev_run=True, accelerator=Accel(), - num_processes=1, + num_processes=2, callbacks=[CB()] ) @@ -350,7 +351,7 @@ def on_fit_start(self, trainer, pl_module): 
@mock.patch.dict(os.environ, { - "SLURM_NTASKS": "1", + "SLURM_NTASKS": "2", "SLURM_JOB_NAME": "SOME_NAME", "SLURM_NODEID": "0", "LOCAL_RANK": "0", @@ -367,8 +368,8 @@ def on_fit_start(self, trainer, pl_module): trainer = Trainer( fast_dev_run=True, accelerator='ddp_cpu', - num_processes=1, - callbacks=[CB()] + num_processes=2, + callbacks=[CB()], ) with pytest.raises(SystemExit): diff --git a/tests/checkpointing/test_torch_saving.py b/tests/checkpointing/test_torch_saving.py index 493aa0dabe126..a15d425f5a0e7 100644 --- a/tests/checkpointing/test_torch_saving.py +++ b/tests/checkpointing/test_torch_saving.py @@ -43,8 +43,7 @@ def test_model_torch_save(tmpdir, enable_pl_optimizer): assert is_lightning_optimizer if enable_pl_optimizer else not is_lightning_optimizer -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_model_torch_save_ddp_cpu(tmpdir): """Test to ensure torch save does not fail for model and trainer using cpu ddp.""" model = BoringModel() diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 9a7a970aecaf7..2d2d59be0f797 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -37,35 +37,42 @@ def test_v1_4_0_deprecated_imports(): from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401 -# todo: later add also checking deprecated warnings def test_v1_4_0_deprecated_trainer_attributes(): """Test that Trainer attributes works fine.""" trainer = Trainer() trainer._distrib_type = None trainer._device_type = None - trainer.on_cpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_cpu = True assert trainer.on_cpu - trainer.on_gpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_gpu = True assert trainer.on_gpu - trainer.on_tpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.on_tpu = True assert trainer.on_tpu trainer._device_type = None - trainer.use_tpu = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_tpu = True assert trainer.use_tpu - trainer.use_dp = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_dp = True assert trainer.use_dp - trainer.use_ddp = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_ddp = True assert trainer.use_ddp - trainer.use_ddp2 = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_ddp2 = True assert trainer.use_ddp2 - trainer.use_horovod = True + with pytest.deprecated_call(match='deprecated in v1.2 and will be removed in v1.4'): + trainer.use_horovod = True assert trainer.use_horovod diff --git a/tests/plugins/test_amp_plugin.py b/tests/plugins/test_amp_plugin.py index 6c5a7b052d0d1..1e98740f99d62 100644 --- a/tests/plugins/test_amp_plugin.py +++ b/tests/plugins/test_amp_plugin.py @@ -21,8 +21,10 @@ "SLURM_LOCALID": "0" }) @mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) +@pytest.mark.parametrize( + 
['ddp_backend', 'gpus', 'num_processes'], + [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], +) def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): @@ -55,8 +57,10 @@ def on_fit_start(self, trainer, pl_module): "SLURM_LOCALID": "0" }) @mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) +@pytest.mark.parametrize( + ['ddp_backend', 'gpus', 'num_processes'], + [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], +) def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class MyNativeAMP(NativeAMPPlugin): pass diff --git a/tests/plugins/test_apex_plugin.py b/tests/plugins/test_apex_plugin.py index bfed1aefec0a1..c4198b97446c3 100644 --- a/tests/plugins/test_apex_plugin.py +++ b/tests/plugins/test_apex_plugin.py @@ -18,8 +18,10 @@ "SLURM_LOCALID": "0" }) @mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) +@pytest.mark.parametrize( + ['ddp_backend', 'gpus', 'num_processes'], + [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], +) def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): @@ -52,8 +54,10 @@ def on_fit_start(self, trainer, pl_module): "SLURM_LOCALID": "0" }) @mock.patch('torch.cuda.device_count', return_value=2) -@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'], - [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)]) +@pytest.mark.parametrize( + ['ddp_backend', 'gpus', 'num_processes'], + [('ddp_cpu', None, 2), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)], +) def test_amp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class MyApexPlugin(ApexPlugin): pass diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py index 4e51fc7c5ac21..fe8fc555ba06c 100644 --- a/tests/plugins/test_ddp_plugin.py +++ b/tests/plugins/test_ddp_plugin.py @@ -27,7 +27,7 @@ @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class CB(Callback): @@ -62,7 +62,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): class MyDDP(DDPPlugin): @@ -101,7 +101,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on 
Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -139,7 +139,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_ddp_invalid_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes): with pytest.raises(MisconfigurationException, match='not a supported lightning custom plugin'): @@ -166,7 +166,7 @@ def test_ddp_invalid_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_proces @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -202,7 +202,7 @@ class MyDDP(DDPPlugin): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_ddp_choice_custom_ddp_cpu_custom_args( tmpdir, ddp_backend, gpus, num_processes diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index be9d95f09f03f..05789596879b4 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -38,7 +38,7 @@ @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_custom_required_plugins(tmpdir, ddp_backend, gpus, num_processes): """ @@ -92,7 +92,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) def test_invalid_custom_required_plugins(tmpdir, ddp_backend, gpus, num_processes): """ diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index 87d64a7b8c686..a28cd4b50e4f4 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -26,7 +26,7 @@ @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index c0761b7e03fcb..d8334e24e0e83 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -28,7 +28,7 @@ @mock.patch("torch.cuda.device_count", return_value=2) 
@pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes): @@ -89,7 +89,7 @@ def test_invalid_apex_sharded(tmpdir): @mock.patch("torch.cuda.device_count", return_value=2) @pytest.mark.parametrize( ["ddp_backend", "gpus", "num_processes"], - [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], + [("ddp_cpu", None, 2), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], ) @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @@ -129,6 +129,7 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): model = BoringModel() trainer = Trainer( accelerator='ddp_cpu', + num_processes=2, plugins=[DDPShardedPlugin()], fast_dev_run=True, ) @@ -208,6 +209,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): model = BoringModel() trainer = Trainer( accelerator='ddp_cpu', + num_processes=2, plugins=[DDPShardedPlugin()], fast_dev_run=True, ) @@ -221,6 +223,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): trainer = Trainer( accelerator='ddp_cpu', + num_processes=2, plugins=[DDPShardedPlugin()], fast_dev_run=True, resume_from_checkpoint=checkpoint_path @@ -291,6 +294,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): trainer = Trainer( plugins=[DDPShardedPlugin()], accelerator='ddp_cpu', + num_processes=2, fast_dev_run=True, resume_from_checkpoint=checkpoint_path ) @@ -308,6 +312,7 @@ def test_ddp_sharded_plugin_test(tmpdir): model = BoringModel() trainer = Trainer( accelerator='ddp_cpu', + num_processes=2, plugins=[DDPShardedPlugin()], fast_dev_run=True, ) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index ca1301fb0dec6..16434f390b90a 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -61,7 +61,7 @@ def test_get_model_ddp_cpu(tmpdir): limit_val_batches=2, max_epochs=1, accelerator='ddp_cpu', - num_processes=2 + num_processes=2, ) trainer.fit(model)
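
Below is a minimal, illustrative sketch (not part of the patch) of the configuration this change targets: running CPU-based DDP requires `num_processes>1`, mirroring the updated tests above. It assumes the repository's `BoringModel` test helper and the `Trainer` arguments used elsewhere in this diff.

```python
from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel  # test helper from this repo

model = BoringModel()
trainer = Trainer(
    accelerator='ddp_cpu',  # distributed training on CPU processes
    num_processes=2,        # ddp_cpu only distributes work with more than one process
    fast_dev_run=True,      # quick sanity run, as in the tests above
)
trainer.fit(model)
```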