diff --git a/CHANGELOG.md b/CHANGELOG.md
index c77a3409ddb2b..d351fe2e97f47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -267,6 +267,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated passing `progress_bar_refresh_rate` to the `Trainer` constructor in favor of adding the `ProgressBar` callback with `refresh_rate` directly to the list of callbacks ([#9616](https://github.com/PyTorchLightning/pytorch-lightning/pull/9616))
 
 
+- Deprecated `LightningDistributed` and moved the broadcast logic to `DDPPlugin` and `DDPSpawnPlugin` directly ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
+
+
 ### Removed
 
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
@@ -395,6 +398,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `trainer.accumulate_grad_batches` to be an int on init. Default value for it is now `None` inside Trainer ([#9652](https://github.com/PyTorchLightning/pytorch-lightning/pull/9652))
 
 
+- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))
+
+
 ## [1.4.8] - 2021-09-22
 
 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py
index d4e41f6e7cc4d..082e0c617a5f7 100644
--- a/pytorch_lightning/distributed/dist.py
+++ b/pytorch_lightning/distributed/dist.py
@@ -14,11 +14,22 @@
 from typing import Any
 
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.distributed import group as _group
 
 
 class LightningDistributed:
+    """
+    .. deprecated:: v1.5
+        This class is deprecated in v1.5 and will be removed in v1.7.
+        The broadcast logic has been moved to the :class:`DDPPlugin` and :class:`DDPSpawnPlugin` classes.
+    """
+
     def __init__(self, rank=None, device=None):
+        rank_zero_deprecation(
+            "LightningDistributed is deprecated in v1.5 and will be removed in v1.7. "
+            "Broadcast logic is implemented directly in the :class:`TrainingTypePlugin` implementations."
+        )
         self.rank = rank
         self.device = device
 
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index df0f658bf712a..a26b63151f5a8 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -31,9 +31,9 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ def __init__(
             " Notice that it will be overriden by the trainer setting."
         )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -269,10 +264,6 @@ def setup_distributed(self):
         # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)
 
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
     def _check_can_spawn_children(self):
         if self.local_rank != 0:
             raise RuntimeError(
@@ -403,7 +394,11 @@ def barrier(self, *args, **kwargs) -> None:
             torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 5f493001341d6..eb1acaec4100b 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -24,9 +24,9 @@
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ def __init__(
         )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue):
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)
 
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         # move the model to the correct device
         self.model_to_device()
 
@@ -324,7 +315,11 @@ def barrier(self, *args, **kwargs) -> None:
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != src:
+            obj = [None]
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]
 
     def model_to_device(self):
         if self.root_device.type == "cuda":
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index cb3b007b712ff..978152506d0e3 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -342,9 +342,6 @@ def setup_distributed(self):
 
         self._init_deepspeed_distributed()
 
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True
diff --git a/setup.cfg b/setup.cfg
index 86890f08e2c68..99f3a513b0914 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -46,6 +46,7 @@ omit =
     pytorch_lightning/cluster_environments/*.py
     pytorch_lightning/utilities/distributed.py
     pytorch_lightning/tuner/auto_gpu_select.py
+    pytorch_lightning/distributed/dist.py
 
 
 [flake8]
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 2fa8f96e77148..dbbba95a6de4c 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -243,3 +243,10 @@ def test_v1_7_0_lightning_logger_base_close(tmpdir):
     ):
         logger = LoggerCollection([logger])
         logger.close()
+
+
+def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
+    with pytest.deprecated_call(match="LightningDistributed is deprecated in v1.5 and will be removed in v1.7."):
+        from pytorch_lightning.distributed.dist import LightningDistributed
+
+        _ = LightningDistributed()
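---

Reviewer note (illustrative sketch, not part of the patch): the inlined `broadcast` above wraps
the payload in a one-element list, swaps in a `[None]` placeholder on every rank other than `src`,
and lets `broadcast_object_list` overwrite the placeholder in place. That is why the new version
respects `src`, whereas `LightningDistributed.broadcast` always broadcast from rank 0. A minimal
standalone reproduction of the same pattern, using plain `torch.distributed.broadcast_object_list`
(PyTorch >= 1.8) rather than Lightning's patched copy; the helper names `broadcast_from` and
`worker` are hypothetical, and a 2-process gloo group on localhost is assumed:

    import os

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def broadcast_from(obj, src, rank):
        # Same shape as the new DDPPlugin.broadcast: every rank passes a
        # one-element list; non-src ranks pass a placeholder that
        # broadcast_object_list overwrites in place with src's object.
        buf = [obj] if rank == src else [None]
        dist.broadcast_object_list(buf, src)
        return buf[0]


    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # Before this fix, `src` was effectively ignored and rank 0 was
        # always the source; here every rank receives rank 1's payload.
        out = broadcast_from(f"payload-from-rank-{rank}", src=1, rank=rank)
        assert out == "payload-from-rank-1"
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)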