From ad77ff2f95875f8cdf0050f2805d7601b71a10ed Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 1 Mar 2021 23:00:29 +0530 Subject: [PATCH 01/43] feat: add smddp plugin & environment --- .../environments/smdist_environment.py | 54 ++++ .../plugins/training_type/smddp.py | 283 ++++++++++++++++++ pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/enums.py | 1 + pytorch_lightning/utilities/imports.py | 1 + 5 files changed, 340 insertions(+) create mode 100644 pytorch_lightning/plugins/environments/smdist_environment.py create mode 100644 pytorch_lightning/plugins/training_type/smddp.py diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py new file mode 100644 index 0000000000000..0698a7a16cfc9 --- /dev/null +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from pytorch_lightning import _logger as log +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.utilities import _SMDIST_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _SMDIST_AVAILABLE: + import smdistributed.dataparallel.torch.distributed as dist + + +class SMDistributedEnvironment(ClusterEnvironment): + + def __init__(self): + if not _SMDIST_AVAILABLE: + raise MisconfigurationException("`smdistributed` module is not available.") + super().__init__() + + def master_address(self): + master_address = os.environ['SM_CURRENT_HOST'] + log.debug(f"MASTER_ADDR: {master_address}") + return master_address + + def master_port(self): + if "MASTER_PORT" not in os.environ: + rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") + os.environ["MASTER_PORT"] = "12910" + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + port = os.environ.get('MASTER_PORT') + return port + + def world_size(self): + return len(os.environ['SM_HOSTS']) + + def local_rank(self): + return int(dist.get_local_rank()) + + def node_rank(self) -> int: + return dist.get_rank() diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py new file mode 100644 index 0000000000000..67cfcc7a90293 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -0,0 +1,283 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import os +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning import _logger as log +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.overrides.distributed import prepare_for_backward +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.utilities import _GROUP_AVAILABLE, _SMDIST_AVAILABLE +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import seed_everything + +WORLD = None +if _GROUP_AVAILABLE: + from torch.distributed import group + WORLD = group.WORLD + +if _SMDIST_AVAILABLE: + import smdistributed.dataparallel.torch.distributed as dist + from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel + + +class SMDDPPlugin(TrainingTypePlugin): + + distributed_backend = "smddp" + + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: Optional[ClusterEnvironment] = None, + sync_batchnorm: bool = False, + **kwargs: Union[Any, Dict[str, Any]], + ): + if not _SMDIST_AVAILABLE: + raise MisconfigurationException("`smdistributed` module is not available.") + super().__init__() + parallel_device_ids = list(range(torch.cuda.device_count())) + self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] + self.sync_batchnorm = sync_batchnorm + self.dist = SMLightningDistributed() + self.num_nodes = len(os.environ['SM_HOSTS']) + self._ddp_kwargs = kwargs + self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices + + @property + def root_device(self): + return self.parallel_devices[self.local_rank] + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) + return distributed_sampler_kwargs + + def barrier(self, *args, **kwargs) -> None: + if dist.is_initialized(): + dist.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + """Run before precision plugin executes backward""" + if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: + prepare_for_backward(self.model, closure_loss) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + tensor: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ + if isinstance(tensor, torch.Tensor): + tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor + + @property + def lightning_module(self): + return self.unwrap_lightning_module() + + def setup(self, model): + self._model = model + + self.node_rank = self.cluster_environment.node_rank() + self.local_rank = self.cluster_environment.local_rank() + self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.world_size = self.cluster_environment.world_size() + + rank_zero_only.rank = self.global_rank + self.model_to_device() + + def pre_dispatch(self): + # TODO: check if needed + seed = os.environ.get("PL_GLOBAL_SEED") + if seed is not None: + seed_everything(int(seed)) + + # set warning rank + rank_zero_only.rank = self.global_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + self.init_ddp_connection(self.global_rank, self.world_size) + + # TODO: we moved it to the trainer.fit after calling pre_dispatch + # ... need to double check that it is the correct place + # self.trainer.call_setup_hook(self.model) + + # on world_size=0 let everyone know training is starting + if self.is_global_zero and not dist.is_initialized(): + print("===" * 10, "Inside the loop") + log.info("-" * 100) + log.info(f"distributed_backend={self.distributed_backend}") + log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") + log.info("-" * 100) + + # # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device + + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # move the model to the correct device + self.model_to_device() + + self.configure_ddp() + + self.barrier() + + def model_to_device(self): + if self.on_gpu: + torch.cuda.set_device(self.root_device) + self.model.to(self.root_device) + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + + os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + + if not dist.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) + + def configure_ddp(self): + # self.pre_configure_ddp() + # print("=Device IDs=" * 5, self.determine_ddp_device_ids()) + print("=Local Device IDs=" * 5, dist.get_local_rank()) + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=[dist.get_local_rank()], + # **self._ddp_kwargs, + ) + + def sync_ddp_if_available( + self, + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None + ) -> torch.Tensor: + """ + Function to reduce a tensor across worker processes during distributed training + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + if dist.is_available() and dist.is_initialized(): + return self.sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + def sync_ddp( + self, + result: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None + ) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + + Return: + reduced value + """ + return result + + def training_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def predict(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def post_training_step(self): + if not self.lightning_module.automatic_optimization: + self.model.require_backward_grad_sync = True + + def unwrap_lightning_module(self) -> LightningModule: + model = self._model + if isinstance(model, (DistributedDataParallel)): + model = model.module + if isinstance(model, _LightningModuleWrapperBase): + model = model.module + return model + + +class SMLightningDistributed: + + def __init__(self, rank=None, device=None): + self.rank = rank + self.device = device + + def broadcast(self, obj: Any, group=WORLD): + if self.rank == 0: + self._emit(obj, group) + else: + obj = self._receive(group) + return obj + + def _broadcast(self, tensor, src=0, group=WORLD): + if group is None: + return dist.broadcast(tensor, src=src) + return dist.broadcast(tensor, src=0, group=group) + + def _emit(self, obj: Any, group=WORLD): + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.tensor([len(data)]).long().to(self.device) + self._broadcast(length_tensor, src=0, group=group) + data_tensor = torch.ByteTensor(data).to(self.device) + self._broadcast(data_tensor, src=0, group=group) + + def _receive(self, group=WORLD): + length_tensor = torch.tensor([0]).long().to(self.device) + self._broadcast(length_tensor, src=0, group=group) + data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) + self._broadcast(data_tensor, src=0, group=group) + buffer = io.BytesIO(data_tensor.cpu().numpy()) + obj = torch.load(buffer) + return obj diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 3e2ee3e51efe1..9fc1fe5036d94 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -37,6 +37,7 @@ _NATIVE_AMP_AVAILABLE, _OMEGACONF_AVAILABLE, _RPC_AVAILABLE, + _SMDIST_AVAILABLE, _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, _TORCH_LOWER_EQUAL_1_4, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 169481fa63e67..4f9571544b1c0 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -76,6 +76,7 @@ def is_interactive_compatible(self) -> bool: HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' + SMDDP = 'smddp' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 41a13d6c678a0..66c6214764601 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -71,3 +71,4 @@ def _compare_version(package: str, op, version) -> bool: _TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHVISION_AVAILABLE = _module_available('torchvision') _XLA_AVAILABLE = _module_available("torch_xla") +_SMDIST_AVAILABLE = _module_available("smdistributed") From 270e4dfdbfd78514e012186d3ecb465eb545ea36 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 1 Mar 2021 23:26:18 +0530 Subject: [PATCH 02/43] update accelerator connector --- pytorch_lightning/plugins/__init__.py | 2 ++ .../plugins/environments/__init__.py | 1 + .../plugins/training_type/__init__.py | 1 + .../trainer/connectors/accelerator_connector.py | 16 +++++++++++++++- 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index dec672d025294..53bd76ee3f62c 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -18,6 +18,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 @@ -44,4 +45,5 @@ 'Plugin', 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', + 'SMDDPPlugin', ] diff --git a/pytorch_lightning/plugins/environments/__init__.py b/pytorch_lightning/plugins/environments/__init__.py index 10d9bf50a4b84..715390ed24c3c 100644 --- a/pytorch_lightning/plugins/environments/__init__.py +++ b/pytorch_lightning/plugins/environments/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment # noqa: F401 +from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment # noqa: F401 from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 30723d67da3f4..d5c809d186082 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -11,5 +11,6 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 4f942f9b35e5d..9e491c0c407f9 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -38,11 +38,17 @@ ShardedNativeMixedPrecisionPlugin, SingleDevicePlugin, SingleTPUPlugin, + SMDDPPlugin, TPUHalfPrecisionPlugin, TPUSpawnPlugin, TrainingTypePlugin, ) -from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment +from pytorch_lightning.plugins.environments import ( + ClusterEnvironment, + SLURMEnvironment, + SMDistributedEnvironment, + TorchElasticEnvironment, +) from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus from pytorch_lightning.utilities import ( _APEX_AVAILABLE, @@ -299,6 +305,10 @@ def is_using_torchelastic(self) -> bool: te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ) return te_flags_passed + @property + def use_smdistributed(self) -> bool: + return self.distributed_backend == DistributedType.SMDDP + def select_precision_plugin(self) -> PrecisionPlugin: # set precision type self.amp_type = AMPType.from_str(self.amp_type) @@ -396,6 +406,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: plugin = DataParallelPlugin(parallel_devices=self.parallel_devices) elif self.use_horovod: plugin = HorovodPlugin(parallel_devices=self.parallel_devices) + elif self.use_smdistributed: + plugin = SMDDPPlugin(cluster_environment=self.cluster_environment, sync_batchnorm=self.sync_batchnorm) elif self.on_tpu: if isinstance(self.tpu_cores, list): plugin = SingleTPUPlugin(self.tpu_id) @@ -457,6 +469,8 @@ def select_cluster_environment(self) -> ClusterEnvironment: # TODO: decouple DDP from TE # refactor and let generic cluster env hold the information about who spawns the processes os.environ["PL_IN_DDP_SUBPROCESS"] = "1" + elif self.use_smdistributed: + env = SMDistributedEnvironment() else: # TODO: maybe introduce a DefaultEnvironment? env = TorchElasticEnvironment() From 33c989167d28dbfddc035989369f376b10f72ed4 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:02:11 +0530 Subject: [PATCH 03/43] update smddp plugin --- .../plugins/training_type/smddp.py | 60 +++++++++++-------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 67cfcc7a90293..d33e54a0f856a 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -46,20 +46,19 @@ class SMDDPPlugin(TrainingTypePlugin): def __init__( self, - parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, sync_batchnorm: bool = False, **kwargs: Union[Any, Dict[str, Any]], ): if not _SMDIST_AVAILABLE: raise MisconfigurationException("`smdistributed` module is not available.") + super().__init__() parallel_device_ids = list(range(torch.cuda.device_count())) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] self.sync_batchnorm = sync_batchnorm self.dist = SMLightningDistributed() self.num_nodes = len(os.environ['SM_HOSTS']) - self._ddp_kwargs = kwargs self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices @property @@ -71,6 +70,22 @@ def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs + def training_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def predict(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def post_training_step(self): + if not self.lightning_module.automatic_optimization: + self.model.require_backward_grad_sync = True + def barrier(self, *args, **kwargs) -> None: if dist.is_initialized(): dist.barrier() @@ -127,7 +142,7 @@ def pre_dispatch(self): # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table - self.init_ddp_connection(self.global_rank, self.world_size) + self.init_smddp_connection(self.global_rank, self.world_size) # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... need to double check that it is the correct place @@ -151,7 +166,7 @@ def pre_dispatch(self): # move the model to the correct device self.model_to_device() - self.configure_ddp() + self.configure_smddp() self.barrier() @@ -160,24 +175,16 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def init_ddp_connection(self, global_rank: int, world_size: int) -> None: - - os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address()) - os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size()) + def init_smdddp_connection(self, global_rank: int, world_size: int) -> None: if not dist.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) - def configure_ddp(self): - # self.pre_configure_ddp() - # print("=Device IDs=" * 5, self.determine_ddp_device_ids()) - print("=Local Device IDs=" * 5, dist.get_local_rank()) + def configure_smddp(self): self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=[dist.get_local_rank()], - # **self._ddp_kwargs, ) def sync_ddp_if_available( @@ -219,23 +226,24 @@ def sync_ddp( Return: reduced value """ - return result + divide_by_world_size = False - def training_step(self, *args, **kwargs): - return self.model(*args, **kwargs) + if group is None: + group = dist.group.WORLD - def validation_step(self, *args, **kwargs): - return self.model(*args, **kwargs) + op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM - def test_step(self, *args, **kwargs): - return self.model(*args, **kwargs) + if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"): + divide_by_world_size = True - def predict(self, *args, **kwargs): - return self.model(*args, **kwargs) + # sync all processes before reduction + dist.barrier(group=group) + dist.all_reduce(result, op=op, group=group, async_op=False) - def post_training_step(self): - if not self.lightning_module.automatic_optimization: - self.model.require_backward_grad_sync = True + if divide_by_world_size: + result = result / dist.get_world_size(group) + + return result def unwrap_lightning_module(self) -> LightningModule: model = self._model From 0380efb419860dc32ea6856ffd9b720dd4569f46 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:20:00 +0530 Subject: [PATCH 04/43] update smddp plugin --- .../plugins/environments/smdist_environment.py | 8 ++++---- pytorch_lightning/plugins/training_type/smddp.py | 8 +++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index 0698a7a16cfc9..84c252f9a319d 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -44,11 +44,11 @@ def master_port(self): port = os.environ.get('MASTER_PORT') return port - def world_size(self): - return len(os.environ['SM_HOSTS']) + def world_size(self) -> int: + return dist.get_world_size() - def local_rank(self): - return int(dist.get_local_rank()) + def local_rank(self) -> int: + return dist.get_local_rank() def node_rank(self) -> int: return dist.get_rank() diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index d33e54a0f856a..2f576f4701820 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -24,7 +24,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _GROUP_AVAILABLE, _SMDIST_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -40,7 +40,7 @@ from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel -class SMDDPPlugin(TrainingTypePlugin): +class SMDDPPlugin(ParallelPlugin): distributed_backend = "smddp" @@ -53,9 +53,11 @@ def __init__( if not _SMDIST_AVAILABLE: raise MisconfigurationException("`smdistributed` module is not available.") - super().__init__() parallel_device_ids = list(range(torch.cuda.device_count())) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] + + super().__init__(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment) + self.sync_batchnorm = sync_batchnorm self.dist = SMLightningDistributed() self.num_nodes = len(os.environ['SM_HOSTS']) From c115ecea2ccc731fb10b70081fe8471d5f6b0f89 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:25:37 +0530 Subject: [PATCH 05/43] update lightning distributed --- .../plugins/training_type/smddp.py | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 2f576f4701820..23d38e4c9ec74 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -20,6 +20,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward @@ -256,38 +257,9 @@ def unwrap_lightning_module(self) -> LightningModule: return model -class SMLightningDistributed: - - def __init__(self, rank=None, device=None): - self.rank = rank - self.device = device - - def broadcast(self, obj: Any, group=WORLD): - if self.rank == 0: - self._emit(obj, group) - else: - obj = self._receive(group) - return obj +class SMLightningDistributed(LightningDistributed): def _broadcast(self, tensor, src=0, group=WORLD): if group is None: return dist.broadcast(tensor, src=src) return dist.broadcast(tensor, src=0, group=group) - - def _emit(self, obj: Any, group=WORLD): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.ByteTensor(data).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - - def _receive(self, group=WORLD): - length_tensor = torch.tensor([0]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - buffer = io.BytesIO(data_tensor.cpu().numpy()) - obj = torch.load(buffer) - return obj From 50c045d6a5a1c7cf5e65945725b46a9308ab7417 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:30:12 +0530 Subject: [PATCH 06/43] fix typo --- pytorch_lightning/plugins/training_type/parallel.py | 2 +- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index f3c825fe9cd7a..c01b19cf00217 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -120,7 +120,7 @@ def block_backward_sync(self): else: yield None - def broadcast(self, obj: object, src: int) -> object: + def broadcast(self, obj: object, src: int = 0) -> object: buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 23d38e4c9ec74..bbde045607357 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -178,7 +178,7 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def init_smdddp_connection(self, global_rank: int, world_size: int) -> None: + def init_smddp_connection(self, global_rank: int, world_size: int) -> None: if not dist.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") From eab4f582eced8b91e595de008ff38b436f9d27ca Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:40:25 +0530 Subject: [PATCH 07/43] update smddp plugin --- pytorch_lightning/plugins/training_type/smddp.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index bbde045607357..251c2339e950b 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -26,16 +26,11 @@ from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.utilities import _GROUP_AVAILABLE, _SMDIST_AVAILABLE +from pytorch_lightning.utilities import _SMDIST_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything -WORLD = None -if _GROUP_AVAILABLE: - from torch.distributed import group - WORLD = group.WORLD - if _SMDIST_AVAILABLE: import smdistributed.dataparallel.torch.distributed as dist from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel @@ -153,7 +148,6 @@ def pre_dispatch(self): # on world_size=0 let everyone know training is starting if self.is_global_zero and not dist.is_initialized(): - print("===" * 10, "Inside the loop") log.info("-" * 100) log.info(f"distributed_backend={self.distributed_backend}") log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") @@ -259,7 +253,7 @@ def unwrap_lightning_module(self) -> LightningModule: class SMLightningDistributed(LightningDistributed): - def _broadcast(self, tensor, src=0, group=WORLD): + def _broadcast(self, tensor, src, group): if group is None: return dist.broadcast(tensor, src=src) return dist.broadcast(tensor, src=0, group=group) From f233b1ab9211b06c29f76ca782ebb3d47a7bfb8d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 00:55:55 +0530 Subject: [PATCH 08/43] add comment for parallel devices --- pytorch_lightning/plugins/training_type/smddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 251c2339e950b..f2f447c47b465 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -49,6 +49,7 @@ def __init__( if not _SMDIST_AVAILABLE: raise MisconfigurationException("`smdistributed` module is not available.") + # While running smdistributed, all the gpus in the instance are considered parallel_device_ids = list(range(torch.cuda.device_count())) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] From 2e9280c3ef60451075de98afd8c8382095cd0fea Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 01:18:26 +0530 Subject: [PATCH 09/43] debug --- pytorch_lightning/plugins/training_type/smddp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index f2f447c47b465..ff7f42943db1a 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -58,6 +58,10 @@ def __init__( self.sync_batchnorm = sync_batchnorm self.dist = SMLightningDistributed() self.num_nodes = len(os.environ['SM_HOSTS']) + # Debugging + print("====" * 10) + print("Nodes", os.environ['SM_HOSTS']) + print("====" * 10) self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices @property From c91e9f2c4081ad5776154a7016ddc609d5491e4f Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 2 Mar 2021 01:24:35 +0530 Subject: [PATCH 10/43] debug --- pytorch_lightning/plugins/training_type/smddp.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index ff7f42943db1a..f2f447c47b465 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -58,10 +58,6 @@ def __init__( self.sync_batchnorm = sync_batchnorm self.dist = SMLightningDistributed() self.num_nodes = len(os.environ['SM_HOSTS']) - # Debugging - print("====" * 10) - print("Nodes", os.environ['SM_HOSTS']) - print("====" * 10) self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices @property From a13a6751b60b2305827585efc9090fee39592ed4 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 13:28:25 +0530 Subject: [PATCH 11/43] ddp name consistency --- pytorch_lightning/plugins/training_type/smddp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index f2f447c47b465..73d9a860961a6 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -141,7 +141,7 @@ def pre_dispatch(self): # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table - self.init_smddp_connection(self.global_rank, self.world_size) + self.init_ddp_connection(self.global_rank, self.world_size) # TODO: we moved it to the trainer.fit after calling pre_dispatch # ... need to double check that it is the correct place @@ -164,7 +164,7 @@ def pre_dispatch(self): # move the model to the correct device self.model_to_device() - self.configure_smddp() + self.configure_ddp() self.barrier() @@ -173,13 +173,13 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def init_smddp_connection(self, global_rank: int, world_size: int) -> None: + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: if not dist.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) - def configure_smddp(self): + def configure_ddp(self): self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=[dist.get_local_rank()], From 88b2b4b844679906bc646329a0e1df1c242cc40b Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 14:44:36 +0530 Subject: [PATCH 12/43] fix global rank --- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 73d9a860961a6..6b98741a47309 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -123,7 +123,7 @@ def setup(self, model): self.node_rank = self.cluster_environment.node_rank() self.local_rank = self.cluster_environment.local_rank() - self.global_rank = self.node_rank * self.num_processes + self.local_rank + self.global_rank = dist.get_rank() self.world_size = self.cluster_environment.world_size() rank_zero_only.rank = self.global_rank From ffc85ea7e2de14fc0bbaa83f967404165fa30ee6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 15:05:16 +0530 Subject: [PATCH 13/43] update smdist environment --- pytorch_lightning/plugins/environments/smdist_environment.py | 5 +++++ pytorch_lightning/plugins/training_type/smddp.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index 84c252f9a319d..b4f5469dcbb5a 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -51,4 +51,9 @@ def local_rank(self) -> int: return dist.get_local_rank() def node_rank(self) -> int: + hosts = os.environ['SM_HOSTS'] + current_host = os.environ['SM_CURRENT_HOST'] + return hosts.index(current_host) if current_host in hosts else 0 + + def global_rank(self) -> int: return dist.get_rank() diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 6b98741a47309..9754684d975b0 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -123,7 +123,7 @@ def setup(self, model): self.node_rank = self.cluster_environment.node_rank() self.local_rank = self.cluster_environment.local_rank() - self.global_rank = dist.get_rank() + self.global_rank = self.cluster_environment.global_rank() self.world_size = self.cluster_environment.world_size() rank_zero_only.rank = self.global_rank From 833bb576ae0abd5bf892b661e4763e5626a7c036 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 15:10:26 +0530 Subject: [PATCH 14/43] Update Type Annotation --- pytorch_lightning/plugins/training_type/smddp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 9754684d975b0..0c21028181c55 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -24,7 +24,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _SMDIST_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp @@ -42,7 +42,7 @@ class SMDDPPlugin(ParallelPlugin): def __init__( self, - cluster_environment: Optional[ClusterEnvironment] = None, + cluster_environment: Optional[SMDistributedEnvironment] = None, sync_batchnorm: bool = False, **kwargs: Union[Any, Dict[str, Any]], ): From 467e76be60fa0da16499262a3437348dd558a376 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 16:54:11 +0530 Subject: [PATCH 15/43] DDP plugin as base class --- .../plugins/training_type/smddp.py | 133 +++++++----------- 1 file changed, 52 insertions(+), 81 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 0c21028181c55..f9f45c92faffd 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -25,7 +25,7 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment -from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _SMDIST_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -36,7 +36,7 @@ from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel -class SMDDPPlugin(ParallelPlugin): +class SMDDPPlugin(DDPPlugin): distributed_backend = "smddp" @@ -52,71 +52,17 @@ def __init__( # While running smdistributed, all the gpus in the instance are considered parallel_device_ids = list(range(torch.cuda.device_count())) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] + num_nodes = len(os.environ['SM_HOSTS']) - super().__init__(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment) + super().__init__( + parallel_devices=self.parallel_devices, + num_nodes=num_nodes, + cluster_environment=cluster_environment, + sync_batchnorm=sync_batchnorm + ) - self.sync_batchnorm = sync_batchnorm + self._ddp_kwargs = kwargs self.dist = SMLightningDistributed() - self.num_nodes = len(os.environ['SM_HOSTS']) - self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else self.parallel_devices - - @property - def root_device(self): - return self.parallel_devices[self.local_rank] - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) - return distributed_sampler_kwargs - - def training_step(self, *args, **kwargs): - return self.model(*args, **kwargs) - - def validation_step(self, *args, **kwargs): - return self.model(*args, **kwargs) - - def test_step(self, *args, **kwargs): - return self.model(*args, **kwargs) - - def predict(self, *args, **kwargs): - return self.model(*args, **kwargs) - - def post_training_step(self): - if not self.lightning_module.automatic_optimization: - self.model.require_backward_grad_sync = True - - def barrier(self, *args, **kwargs) -> None: - if dist.is_initialized(): - dist.barrier() - - def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) - - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): - """Run before precision plugin executes backward""" - if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: - prepare_for_backward(self.model, closure_loss) - - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): - """ - Reduces a tensor from several distributed processes to one aggregated tensor. - As this plugin only operates with a single device, the reduction is simply the identity. - - Args: - tensor: the tensor to sync and reduce - *args: ignored - **kwargs: ignored - - Return: - the unmodified input as reduction is not needed for single process operation - """ - if isinstance(tensor, torch.Tensor): - tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) - return tensor - - @property - def lightning_module(self): - return self.unwrap_lightning_module() def setup(self, model): self._model = model @@ -129,6 +75,20 @@ def setup(self, model): rank_zero_only.rank = self.global_rank self.model_to_device() + def configure_ddp(self): + self.pre_configure_ddp() + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=[dist.get_local_rank()], + **self._ddp_kwargs, + ) + + def init_ddp_connection(self, global_rank: int, world_size: int) -> None: + + if not dist.is_initialized(): + log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) + def pre_dispatch(self): # TODO: check if needed seed = os.environ.get("PL_GLOBAL_SEED") @@ -168,23 +128,6 @@ def pre_dispatch(self): self.barrier() - def model_to_device(self): - if self.on_gpu: - torch.cuda.set_device(self.root_device) - self.model.to(self.root_device) - - def init_ddp_connection(self, global_rank: int, world_size: int) -> None: - - if not dist.is_initialized(): - log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) - - def configure_ddp(self): - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), - device_ids=[dist.get_local_rank()], - ) - def sync_ddp_if_available( self, result: Union[torch.Tensor], @@ -243,6 +186,10 @@ def sync_ddp( return result + @property + def lightning_module(self): + return self.unwrap_lightning_module() + def unwrap_lightning_module(self) -> LightningModule: model = self._model if isinstance(model, (DistributedDataParallel)): @@ -251,6 +198,30 @@ def unwrap_lightning_module(self) -> LightningModule: model = model.module return model + def barrier(self, *args, **kwargs) -> None: + if dist.is_initialized(): + dist.barrier() + + def broadcast(self, obj: object, src: int = 0) -> object: + return self.dist.broadcast(obj) + + def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): + """ + Reduces a tensor from several distributed processes to one aggregated tensor. + As this plugin only operates with a single device, the reduction is simply the identity. + + Args: + tensor: the tensor to sync and reduce + *args: ignored + **kwargs: ignored + + Return: + the unmodified input as reduction is not needed for single process operation + """ + if isinstance(tensor, torch.Tensor): + tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) + return tensor + class SMLightningDistributed(LightningDistributed): From c2a508ca3f403fb59bd56d7a6ef898e6fc64fdeb Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 8 Mar 2021 19:23:10 +0530 Subject: [PATCH 16/43] add test --- .../plugins/training_type/smddp.py | 2 +- tests/plugins/training_type/test_smddp.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 tests/plugins/training_type/test_smddp.py diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index f9f45c92faffd..9709989f03a6e 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -45,7 +45,7 @@ def __init__( cluster_environment: Optional[SMDistributedEnvironment] = None, sync_batchnorm: bool = False, **kwargs: Union[Any, Dict[str, Any]], - ): + ) -> None: if not _SMDIST_AVAILABLE: raise MisconfigurationException("`smdistributed` module is not available.") diff --git a/tests/plugins/training_type/test_smddp.py b/tests/plugins/training_type/test_smddp.py new file mode 100644 index 0000000000000..4eba9bcbada10 --- /dev/null +++ b/tests/plugins/training_type/test_smddp.py @@ -0,0 +1,23 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.plugins import SMDDPPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +def test_smdistributed_not_available(): # type: ignore + + with pytest.raises(MisconfigurationException, match="`smdistributed` module is not available."): + SMDDPPlugin() From f5675e8ebbc4651476447da2979553306e03927d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 14:04:33 +0530 Subject: [PATCH 17/43] add creates_children mthod --- pytorch_lightning/plugins/environments/smdist_environment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index b4f5469dcbb5a..ff7f5f6a1e747 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -30,6 +30,9 @@ def __init__(self): raise MisconfigurationException("`smdistributed` module is not available.") super().__init__() + def creates_children(self) -> bool: + return False + def master_address(self): master_address = os.environ['SM_CURRENT_HOST'] log.debug(f"MASTER_ADDR: {master_address}") From 0dff5c7cc74a9ed3028be3814d50109ac5c9d637 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 14:50:02 +0530 Subject: [PATCH 18/43] change backend to mpi --- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 9709989f03a6e..b9456736b821d 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -87,7 +87,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: if not dist.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) + dist.init_process_group("mpi", rank=global_rank, world_size=world_size) def pre_dispatch(self): # TODO: check if needed From eacf9a80afc4638378a12bc6660c78a3dee62687 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 15:17:31 +0530 Subject: [PATCH 19/43] set broadcast buffers set to False --- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index b9456736b821d..6ec184389c4b5 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -80,7 +80,7 @@ def configure_ddp(self): self._model = DistributedDataParallel( LightningDistributedModule(self.model), device_ids=[dist.get_local_rank()], - **self._ddp_kwargs, + broadcast_buffers=False, ) def init_ddp_connection(self, global_rank: int, world_size: int) -> None: From d1bf909a6789a2ba0d5ce62856dc872460ef4ece Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 16:50:58 +0530 Subject: [PATCH 20/43] mini refactor --- .../trainer/connectors/accelerator_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 920cb7e6e3a60..de232c9ede5d3 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -273,6 +273,10 @@ def use_horovod(self) -> bool: def use_deepspeed(self) -> bool: return self._distrib_type == DistributedType.DEEPSPEED + @property + def use_smdistributed(self) -> bool: + return self.distributed_backend == DistributedType.SMDDP + @property def is_distributed(self) -> bool: is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod @@ -308,10 +312,6 @@ def is_using_torchelastic(self) -> bool: te_flags_passed = "WORLD_SIZE" in os.environ and ("GROUP_RANK" in os.environ or "NODE_RANK" in os.environ) return te_flags_passed - @property - def use_smdistributed(self) -> bool: - return self.distributed_backend == DistributedType.SMDDP - def select_precision_plugin(self) -> PrecisionPlugin: # set precision type self.amp_type = AMPType.from_str(self.amp_type) From da66a198d436423da859641c105637f11d0bfb64 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 16:56:20 +0530 Subject: [PATCH 21/43] address reviews --- pytorch_lightning/plugins/training_type/smddp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 6ec184389c4b5..3b7d7b4b6cb78 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -47,7 +47,11 @@ def __init__( **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not _SMDIST_AVAILABLE: - raise MisconfigurationException("`smdistributed` module is not available.") + raise MisconfigurationException( + "`smdistributed` module is not available." + " You would need to enable distributed=smdistributed" + " in the Sagemaker Estimator Object." + ) # While running smdistributed, all the gpus in the instance are considered parallel_device_ids = list(range(torch.cuda.device_count())) @@ -126,7 +130,7 @@ def pre_dispatch(self): self.configure_ddp() - self.barrier() + self.barrier("configure ddp") def sync_ddp_if_available( self, From fdaeb5bd5b3f2eeb40e49182c65d85b0167ed2fb Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 17:32:30 +0530 Subject: [PATCH 22/43] address reviews --- pytorch_lightning/distributed/dist.py | 13 ++++++++++++- .../plugins/environments/smdist_environment.py | 8 ++++---- pytorch_lightning/plugins/training_type/smddp.py | 15 ++------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 5da7dfa86084d..739b86e382066 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -17,13 +17,16 @@ import torch from torch import distributed as torch_distrib -from pytorch_lightning.utilities import _GROUP_AVAILABLE +from pytorch_lightning.utilities import _GROUP_AVAILABLE, _SMDIST_AVAILABLE WORLD = None if _GROUP_AVAILABLE: from torch.distributed import group WORLD = group.WORLD +if _SMDIST_AVAILABLE: + import smdistributed.dataparallel.torch.distributed as sm_dist + class LightningDistributed: @@ -60,3 +63,11 @@ def _receive(self, group=WORLD): buffer = io.BytesIO(data_tensor.cpu().numpy()) obj = torch.load(buffer) return obj + + +class SMLightningDistributed(LightningDistributed): + + def _broadcast(self, tensor, src, group): + if group is None: + return sm_dist.broadcast(tensor, src=src) + return sm_dist.broadcast(tensor, src=0, group=group) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index ff7f5f6a1e747..eab8714dc5d89 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -34,7 +34,7 @@ def creates_children(self) -> bool: return False def master_address(self): - master_address = os.environ['SM_CURRENT_HOST'] + master_address = os.environ["SM_CURRENT_HOST"] log.debug(f"MASTER_ADDR: {master_address}") return master_address @@ -44,7 +44,7 @@ def master_port(self): os.environ["MASTER_PORT"] = "12910" log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - port = os.environ.get('MASTER_PORT') + port = os.environ.get("MASTER_PORT") return port def world_size(self) -> int: @@ -54,8 +54,8 @@ def local_rank(self) -> int: return dist.get_local_rank() def node_rank(self) -> int: - hosts = os.environ['SM_HOSTS'] - current_host = os.environ['SM_CURRENT_HOST'] + hosts = os.environ["SM_HOSTS"] + current_host = os.environ["SM_CURRENT_HOST"] return hosts.index(current_host) if current_host in hosts else 0 def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 3b7d7b4b6cb78..de20dbcb2690f 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import torch -from torch.optim import Optimizer from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.distributed import LightningDistributed +from pytorch_lightning.distributed import SMLightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _SMDIST_AVAILABLE @@ -225,11 +222,3 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ if isinstance(tensor, torch.Tensor): tensor = self.sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean")) return tensor - - -class SMLightningDistributed(LightningDistributed): - - def _broadcast(self, tensor, src, group): - if group is None: - return dist.broadcast(tensor, src=src) - return dist.broadcast(tensor, src=0, group=group) From 01b2d3792479717f6d72d0bd03d21b4b992775f5 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 17:59:05 +0530 Subject: [PATCH 23/43] add missing init --- pytorch_lightning/distributed/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/distributed/__init__.py b/pytorch_lightning/distributed/__init__.py index ea060e551ad9d..141ab0720b835 100644 --- a/pytorch_lightning/distributed/__init__.py +++ b/pytorch_lightning/distributed/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.distributed.dist import LightningDistributed # noqa: F401 +from pytorch_lightning.distributed.dist import LightningDistributed, SMLightningDistributed # noqa: F401 From af070e3f355db7dd203e53e4c2210e2856728eb9 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 10 Mar 2021 20:36:30 +0530 Subject: [PATCH 24/43] change backend --- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index de20dbcb2690f..727236078a423 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -88,7 +88,7 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: if not dist.is_initialized(): log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") - dist.init_process_group("mpi", rank=global_rank, world_size=world_size) + dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): # TODO: check if needed From b7e554868037e60f1345c45b53752f7a86bc9e96 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 12:04:25 +0530 Subject: [PATCH 25/43] change all reduce --- pytorch_lightning/plugins/training_type/smddp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index 727236078a423..ddb581593dade 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union import torch +import torch.distributed as torch_distrib from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule @@ -180,7 +181,7 @@ def sync_ddp( # sync all processes before reduction dist.barrier(group=group) - dist.all_reduce(result, op=op, group=group, async_op=False) + torch_distrib.all_reduce(result, op=op, group=group, async_op=False) if divide_by_world_size: result = result / dist.get_world_size(group) From 1373f8fca5e646b05df9aea6f9b5929a1ec1c4d9 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 12:40:07 +0530 Subject: [PATCH 26/43] add type hints --- pytorch_lightning/distributed/dist.py | 2 +- .../plugins/environments/smdist_environment.py | 6 +++--- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 739b86e382066..b2f3d99c41079 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -67,7 +67,7 @@ def _receive(self, group=WORLD): class SMLightningDistributed(LightningDistributed): - def _broadcast(self, tensor, src, group): + def _broadcast(self, tensor: torch.Tensor, src: int, group: Optional[Any] = None): if group is None: return sm_dist.broadcast(tensor, src=src) return sm_dist.broadcast(tensor, src=0, group=group) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index eab8714dc5d89..78083cd64a128 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -33,18 +33,18 @@ def __init__(self): def creates_children(self) -> bool: return False - def master_address(self): + def master_address(self) -> str: master_address = os.environ["SM_CURRENT_HOST"] log.debug(f"MASTER_ADDR: {master_address}") return master_address - def master_port(self): + def master_port(self) -> str: if "MASTER_PORT" not in os.environ: rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910") os.environ["MASTER_PORT"] = "12910" log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - port = os.environ.get("MASTER_PORT") + port = os.environ["MASTER_PORT"] return port def world_size(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index ddb581593dade..fbbb0f17139cb 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -205,7 +205,7 @@ def barrier(self, *args, **kwargs) -> None: dist.barrier() def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj) + return self.dist.broadcast(obj, src) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ From 98176087a7ffd82f6f6a8f0d579b6ec63fd5f204 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 12:57:18 +0530 Subject: [PATCH 27/43] Add missing Import --- pytorch_lightning/distributed/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index b2f3d99c41079..13b95cba32507 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from typing import Any +from typing import Any, Optional import torch from torch import distributed as torch_distrib From 937b50cc5105306c7aae8dd07671db0802f4d0dc Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 13:13:56 +0530 Subject: [PATCH 28/43] broadcast fix --- pytorch_lightning/plugins/training_type/smddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/smddp.py index fbbb0f17139cb..ddb581593dade 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/smddp.py @@ -205,7 +205,7 @@ def barrier(self, *args, **kwargs) -> None: dist.barrier() def broadcast(self, obj: object, src: int = 0) -> object: - return self.dist.broadcast(obj, src) + return self.dist.broadcast(obj) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): """ From c7c16bac427b56e0d04dc8c14f474806d6d121cc Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 17:49:53 +0530 Subject: [PATCH 29/43] Change SMDDP to DDPSM --- pytorch_lightning/plugins/__init__.py | 4 ++-- pytorch_lightning/plugins/training_type/__init__.py | 2 +- .../plugins/training_type/{smddp.py => ddp_sm.py} | 7 +++---- .../trainer/connectors/accelerator_connector.py | 6 +++--- pytorch_lightning/utilities/enums.py | 2 +- tests/plugins/training_type/test_smddp.py | 4 ++-- 6 files changed, 12 insertions(+), 13 deletions(-) rename pytorch_lightning/plugins/training_type/{smddp.py => ddp_sm.py} (97%) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 53bd76ee3f62c..9aad230e3a4fa 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -7,6 +7,7 @@ from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 @@ -18,7 +19,6 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 -from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 @@ -45,5 +45,5 @@ 'Plugin', 'DDPShardedPlugin', 'DDPSpawnShardedPlugin', - 'SMDDPPlugin', + 'DDPSMPlugin', ] diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index d5c809d186082..1c586df83fdaf 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -1,5 +1,6 @@ from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp_sm import DDPSMPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 @@ -11,6 +12,5 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 -from pytorch_lightning.plugins.training_type.smddp import SMDDPPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/smddp.py b/pytorch_lightning/plugins/training_type/ddp_sm.py similarity index 97% rename from pytorch_lightning/plugins/training_type/smddp.py rename to pytorch_lightning/plugins/training_type/ddp_sm.py index ddb581593dade..e3cca23811989 100644 --- a/pytorch_lightning/plugins/training_type/smddp.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -15,7 +15,6 @@ from typing import Any, Dict, Optional, Union import torch -import torch.distributed as torch_distrib from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule @@ -34,9 +33,9 @@ from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel -class SMDDPPlugin(DDPPlugin): +class DDPSMPlugin(DDPPlugin): - distributed_backend = "smddp" + distributed_backend = "ddp_sm" def __init__( self, @@ -181,7 +180,7 @@ def sync_ddp( # sync all processes before reduction dist.barrier(group=group) - torch_distrib.all_reduce(result, op=op, group=group, async_op=False) + dist.all_reduce(result, op=op, group=group, async_op=False) if divide_by_world_size: result = result / dist.get_world_size(group) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index de232c9ede5d3..6844921826759 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -28,6 +28,7 @@ DDP2Plugin, DDPPlugin, DDPShardedPlugin, + DDPSMPlugin, DDPSpawnPlugin, DDPSpawnShardedPlugin, DeepSpeedPlugin, @@ -38,7 +39,6 @@ ShardedNativeMixedPrecisionPlugin, SingleDevicePlugin, SingleTPUPlugin, - SMDDPPlugin, TPUHalfPrecisionPlugin, TPUSpawnPlugin, TrainingTypePlugin, @@ -275,7 +275,7 @@ def use_deepspeed(self) -> bool: @property def use_smdistributed(self) -> bool: - return self.distributed_backend == DistributedType.SMDDP + return self.distributed_backend == DistributedType.DDP_SM @property def is_distributed(self) -> bool: @@ -410,7 +410,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: elif self.use_horovod: plugin = HorovodPlugin(parallel_devices=self.parallel_devices) elif self.use_smdistributed: - plugin = SMDDPPlugin(cluster_environment=self.cluster_environment, sync_batchnorm=self.sync_batchnorm) + plugin = DDPSMPlugin(cluster_environment=self.cluster_environment, sync_batchnorm=self.sync_batchnorm) elif self.on_tpu: if isinstance(self.tpu_cores, list): plugin = SingleTPUPlugin(self.tpu_id) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 4f9571544b1c0..d425d566e4a3a 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -76,7 +76,7 @@ def is_interactive_compatible(self) -> bool: HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' - SMDDP = 'smddp' + DDP_SM = 'ddp_sm' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/tests/plugins/training_type/test_smddp.py b/tests/plugins/training_type/test_smddp.py index 4eba9bcbada10..a5e427cbbba03 100644 --- a/tests/plugins/training_type/test_smddp.py +++ b/tests/plugins/training_type/test_smddp.py @@ -13,11 +13,11 @@ # limitations under the License. import pytest -from pytorch_lightning.plugins import SMDDPPlugin +from pytorch_lightning.plugins import DDPSMPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_smdistributed_not_available(): # type: ignore with pytest.raises(MisconfigurationException, match="`smdistributed` module is not available."): - SMDDPPlugin() + DDPSMPlugin() From 281231ea6af34d6ab491182dccd45b654beb5f6e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 20:14:56 +0530 Subject: [PATCH 30/43] Update num gpus --- pytorch_lightning/plugins/training_type/ddp_sm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index e3cca23811989..c6c41a9ede5f1 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -51,7 +51,7 @@ def __init__( ) # While running smdistributed, all the gpus in the instance are considered - parallel_device_ids = list(range(torch.cuda.device_count())) + parallel_device_ids = list(range(int(os.environ["SM_NUM_GPUS"]))) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] num_nodes = len(os.environ['SM_HOSTS']) From 6c2f229afa7ab98963c29afe18c1066b25123224 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Tue, 16 Mar 2021 23:24:47 +0530 Subject: [PATCH 31/43] fix --- pytorch_lightning/plugins/training_type/ddp_sm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index c6c41a9ede5f1..d965e6c89fa3d 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -51,9 +51,9 @@ def __init__( ) # While running smdistributed, all the gpus in the instance are considered - parallel_device_ids = list(range(int(os.environ["SM_NUM_GPUS"]))) + parallel_device_ids = list(range(torch.cuda.device_count())) self.parallel_devices = [torch.device("cuda", i) for i in parallel_device_ids] - num_nodes = len(os.environ['SM_HOSTS']) + num_nodes = len(os.environ['SM_HOSTS'].split(",")) super().__init__( parallel_devices=self.parallel_devices, From d514a6f8e8bced66457cce8c2adbcffb34a868c9 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Wed, 17 Mar 2021 01:47:35 +0530 Subject: [PATCH 32/43] Update accelerator_connector.py --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 6844921826759..e58f358ba6b71 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -279,7 +279,7 @@ def use_smdistributed(self) -> bool: @property def is_distributed(self) -> bool: - is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod + is_distributed = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_smdistributed if self.on_tpu: is_distributed |= self.training_type_plugin.is_distributed return is_distributed From 1c4a315f5fb27efe8f320f7d6d17ac5cf183067d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Wed, 31 Mar 2021 16:41:32 +0530 Subject: [PATCH 33/43] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14a3410a96bf3..8e710afd973e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677)) +- Added Sagemaker DDP Plugin ([#6271](https://github.com/PyTorchLightning/pytorch-lightning/pull/6271)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From 13dac0b9df57b03fd05f58e6ef8dde66983809c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 Jul 2021 19:14:18 +0000 Subject: [PATCH 34/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- pytorch_lightning/utilities/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index ca3df349cf08b..dd2daad7bbcb5 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -306,7 +306,7 @@ def use_deepspeed(self) -> bool: @property def use_smdistributed(self) -> bool: return self.distributed_backend == DistributedType.DDP_SM - + @property def _is_sharded_training_type(self) -> bool: return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 18dbc65d00e30..1be8a091634da 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -41,8 +41,8 @@ _module_available, _NATIVE_AMP_AVAILABLE, _OMEGACONF_AVAILABLE, - _SMDIST_AVAILABLE, _POPTORCH_AVAILABLE, + _SMDIST_AVAILABLE, _TORCH_GREATER_EQUAL_1_5, _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_7, From e38fb37635babff8ed8378f3fbed244315a5b9b9 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 8 Jul 2021 14:14:58 +0530 Subject: [PATCH 35/43] Fix circular import --- pytorch_lightning/plugins/training_type/ddp_sm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index d965e6c89fa3d..36131b29701bb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -17,7 +17,6 @@ import torch from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed import SMLightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase @@ -191,7 +190,7 @@ def sync_ddp( def lightning_module(self): return self.unwrap_lightning_module() - def unwrap_lightning_module(self) -> LightningModule: + def unwrap_lightning_module(self): model = self._model if isinstance(model, (DistributedDataParallel)): model = model.module From 77977704940faeffdb09ffbb56358d8558f2ce02 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 8 Jul 2021 18:40:45 +0530 Subject: [PATCH 36/43] Update environment --- .../plugins/environments/smdist_environment.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/plugins/environments/smdist_environment.py b/pytorch_lightning/plugins/environments/smdist_environment.py index 78083cd64a128..ab13051609f36 100644 --- a/pytorch_lightning/plugins/environments/smdist_environment.py +++ b/pytorch_lightning/plugins/environments/smdist_environment.py @@ -50,6 +50,9 @@ def master_port(self) -> str: def world_size(self) -> int: return dist.get_world_size() + def set_world_size(self, size: int) -> None: + log.debug("SMDistributedEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.") + def local_rank(self) -> int: return dist.get_local_rank() @@ -60,3 +63,8 @@ def node_rank(self) -> int: def global_rank(self) -> int: return dist.get_rank() + + def set_global_rank(self, rank: int) -> None: + log.debug( + "SMDistributedEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored." + ) From 4a0f78de19caf865b3e01b1a02ec0660a724fb83 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Thu, 8 Jul 2021 23:40:45 +0530 Subject: [PATCH 37/43] Add set_world_ranks --- pytorch_lightning/plugins/training_type/ddp_sm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index 36131b29701bb..3a59518d54aaf 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -75,6 +75,11 @@ def setup(self, model): rank_zero_only.rank = self.global_rank self.model_to_device() + def set_world_ranks(self) -> None: + if self.cluster_environment is None: + return + rank_zero_only.rank = self.cluster_environment.global_rank() + def configure_ddp(self): self.pre_configure_ddp() self._model = DistributedDataParallel( From 574887fe870dc4260b7b1defd7fe892364ae01be Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 13 Jul 2021 11:29:40 +0530 Subject: [PATCH 38/43] Add updates --- .../plugins/training_type/ddp_sm.py | 65 ++++++++----------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index 3a59518d54aaf..4198a2a39b758 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -23,9 +23,9 @@ from pytorch_lightning.plugins.environments.smdist_environment import SMDistributedEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _SMDIST_AVAILABLE -from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning.utilities.seed import reset_seed if _SMDIST_AVAILABLE: import smdistributed.dataparallel.torch.distributed as dist @@ -64,21 +64,32 @@ def __init__( self._ddp_kwargs = kwargs self.dist = SMLightningDistributed() - def setup(self, model): - self._model = model + def setup_environment(self) -> None: + self.setup_distributed() - self.node_rank = self.cluster_environment.node_rank() - self.local_rank = self.cluster_environment.local_rank() - self.global_rank = self.cluster_environment.global_rank() - self.world_size = self.cluster_environment.world_size() + def setup_distributed(self) -> None: + reset_seed() + # determine which process we are and world size + self.set_world_ranks() + + # set warning rank rank_zero_only.rank = self.global_rank - self.model_to_device() + + self.init_ddp_connection(self.global_rank, self.world_size) + + # # set the ranks and devices + self.dist.rank = self.global_rank + self.dist.device = self.root_device def set_world_ranks(self) -> None: if self.cluster_environment is None: return - rank_zero_only.rank = self.cluster_environment.global_rank() + + self.node_rank = self.cluster_environment.node_rank() + self.local_rank = self.cluster_environment.local_rank() + self.global_rank = self.cluster_environment.global_rank() + self.world_size = self.cluster_environment.world_size() def configure_ddp(self): self.pre_configure_ddp() @@ -94,34 +105,14 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None: log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") dist.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) - def pre_dispatch(self): - # TODO: check if needed - seed = os.environ.get("PL_GLOBAL_SEED") - if seed is not None: - seed_everything(int(seed)) - - # set warning rank - rank_zero_only.rank = self.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table - self.init_ddp_connection(self.global_rank, self.world_size) - - # TODO: we moved it to the trainer.fit after calling pre_dispatch - # ... need to double check that it is the correct place - # self.trainer.call_setup_hook(self.model) - - # on world_size=0 let everyone know training is starting - if self.is_global_zero and not dist.is_initialized(): - log.info("-" * 100) - log.info(f"distributed_backend={self.distributed_backend}") - log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes") - log.info("-" * 100) + rank_zero_info( + f"{'-' * 100}\n" + f"distributed_backend={self.torch_distributed_backend}\n" + f"All DDP processes registered. Starting smddp with {world_size} processes\n" + f"{'-' * 100}\n" + ) - # # set the ranks and devices - self.dist.rank = self.global_rank - self.dist.device = self.root_device + def pre_dispatch(self): if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) From 19fa54216cb624222694e017a2e1c3ea8abe2997 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 13 Jul 2021 12:42:44 +0530 Subject: [PATCH 39/43] Add updates --- pytorch_lightning/plugins/training_type/ddp_sm.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_sm.py b/pytorch_lightning/plugins/training_type/ddp_sm.py index 4198a2a39b758..f23540d85875c 100644 --- a/pytorch_lightning/plugins/training_type/ddp_sm.py +++ b/pytorch_lightning/plugins/training_type/ddp_sm.py @@ -73,9 +73,6 @@ def setup_distributed(self) -> None: # determine which process we are and world size self.set_world_ranks() - # set warning rank - rank_zero_only.rank = self.global_rank - self.init_ddp_connection(self.global_rank, self.world_size) # # set the ranks and devices @@ -85,11 +82,8 @@ def setup_distributed(self) -> None: def set_world_ranks(self) -> None: if self.cluster_environment is None: return - - self.node_rank = self.cluster_environment.node_rank() - self.local_rank = self.cluster_environment.local_rank() - self.global_rank = self.cluster_environment.global_rank() - self.world_size = self.cluster_environment.world_size() + # set warning rank + rank_zero_only.rank = self.global_rank def configure_ddp(self): self.pre_configure_ddp() From 5699a32de96911a9d6ffcd5848c329f060688649 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 13 Jul 2021 14:47:45 +0530 Subject: [PATCH 40/43] Fix broadcasting --- pytorch_lightning/distributed/dist.py | 123 +++++++++++++++++++------- 1 file changed, 91 insertions(+), 32 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 0f6980673f7ba..9531517c24a3f 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import io -from typing import Any, Optional +from typing import Any import torch - -from pytorch_lightning.overrides.torch_distributed import broadcast_object_list +from torch.distributed import Backend, get_backend + +from pytorch_lightning.overrides.torch_distributed import ( + _object_to_tensor, + _rank_not_in_group, + _tensor_to_object, + broadcast_object_list, +) from pytorch_lightning.utilities import _SMDIST_AVAILABLE from pytorch_lightning.utilities.distributed import group as _group @@ -45,31 +50,85 @@ def broadcast(self, obj: Any, group=_group.WORLD): class SMLightningDistributed(LightningDistributed): def broadcast(self, obj: Any, group=_group.WORLD): - if self.rank == 0: - self._emit(obj, group) - else: - obj = self._receive(group) - return obj - - def _broadcast(self, tensor: torch.Tensor, src: int, group: Optional[Any] = None): - if group is None: - return sm_dist.broadcast(tensor, src=src) - return sm_dist.broadcast(tensor, src=0, group=group) - - def _emit(self, obj: Any, group=_group.WORLD): - buffer = io.BytesIO() - torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor([len(data)]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.ByteTensor(data).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - - def _receive(self, group=_group.WORLD): - length_tensor = torch.tensor([0]).long().to(self.device) - self._broadcast(length_tensor, src=0, group=group) - data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - self._broadcast(data_tensor, src=0, group=group) - buffer = io.BytesIO(data_tensor.cpu().numpy()) - obj = torch.load(buffer) - return obj + # always wrap into a list so list can be brodcasted. + obj = [obj] + + obj = [obj] + + if self.rank != 0: + obj = [None] * len(obj) + + _broadcast_object_list(obj, self.rank, 0, group=group or _group.WORLD) + + return obj[0] + + # def _broadcast(self, tensor: torch.Tensor, src: int, group: Optional[Any] = None): + # if group is None: + # return sm_dist.broadcast(tensor, src=src) + # return sm_dist.broadcast(tensor, src=0, group=group) + + # def _emit(self, obj: Any, group=_group.WORLD): + # buffer = io.BytesIO() + # torch.save(obj, buffer) + # data = bytearray(buffer.getbuffer()) + # length_tensor = torch.tensor([len(data)]).long().to(self.device) + # self._broadcast(length_tensor, src=0, group=group) + # data_tensor = torch.ByteTensor(data).to(self.device) + # self._broadcast(data_tensor, src=0, group=group) + + # def _receive(self, group=_group.WORLD): + # length_tensor = torch.tensor([0]).long().to(self.device) + # self._broadcast(length_tensor, src=0, group=group) + # data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) + # self._broadcast(data_tensor, src=0, group=group) + # buffer = io.BytesIO(data_tensor.cpu().numpy()) + # obj = torch.load(buffer) + # return obj + + +# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327 +def _broadcast_object_list(object_list, rank, src=0, group=None): + if _rank_not_in_group(group): + return + + my_rank = rank + # Serialize object_list elements to tensors on src rank. + if my_rank == src: + tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.LongTensor(len(object_list)) + + group_backend = get_backend(group) + is_nccl_backend = group_backend == Backend.NCCL + current_device = torch.device("cpu") + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('cuda', torch.cuda.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) + object_sizes_tensor = object_sizes_tensor.to(current_device) + + # Broadcast object sizes + sm_dist.broadcast(object_sizes_tensor, src=src, group=group) + + # Concatenate and broadcast serialized object tensors + if my_rank == src: + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) + + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + sm_dist.broadcast(object_tensor, src=src, group=group) + + # Deserialize objects using their stored sizes. + offset = 0 + if my_rank != src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload] + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size) From 7a715a0a2f8ad7c5c8d5c768db370796f94a09db Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 13 Jul 2021 15:31:23 +0530 Subject: [PATCH 41/43] Fix broadcasting --- pytorch_lightning/distributed/dist.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 9531517c24a3f..ff1b13c66ebeb 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -14,7 +14,6 @@ from typing import Any import torch -from torch.distributed import Backend, get_backend from pytorch_lightning.overrides.torch_distributed import ( _object_to_tensor, @@ -99,16 +98,16 @@ def _broadcast_object_list(object_list, rank, src=0, group=None): else: object_sizes_tensor = torch.LongTensor(len(object_list)) - group_backend = get_backend(group) - is_nccl_backend = group_backend == Backend.NCCL - current_device = torch.device("cpu") - if is_nccl_backend: - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device('cuda', torch.cuda.current_device()) - object_sizes_tensor = object_sizes_tensor.to(current_device) - object_sizes_tensor = object_sizes_tensor.to(current_device) + # group_backend = get_backend(group) + # is_nccl_backend = group_backend == Backend.NCCL + # current_device = torch.device("cpu") + # if is_nccl_backend: + # # See note about using torch.cuda.current_device() here in docstring. + # # We cannot simply use my_rank since rank == device is not necessarily + # # true. + # current_device = torch.device('cuda', torch.cuda.current_device()) + # object_sizes_tensor = object_sizes_tensor.to(current_device) + # object_sizes_tensor = object_sizes_tensor.to(current_device) # Broadcast object sizes sm_dist.broadcast(object_sizes_tensor, src=src, group=group) @@ -119,8 +118,8 @@ def _broadcast_object_list(object_list, rank, src=0, group=None): else: object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item()) - if is_nccl_backend: - object_tensor = object_tensor.to(current_device) + # if is_nccl_backend: + # object_tensor = object_tensor.to(current_device) sm_dist.broadcast(object_tensor, src=src, group=group) From e43fc1b9401df31a53f1355adaac504f5e126510 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Tue, 13 Jul 2021 21:38:29 +0530 Subject: [PATCH 42/43] Update group --- pytorch_lightning/distributed/dist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index ff1b13c66ebeb..1b456951a6247 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -48,7 +48,7 @@ def broadcast(self, obj: Any, group=_group.WORLD): class SMLightningDistributed(LightningDistributed): - def broadcast(self, obj: Any, group=_group.WORLD): + def broadcast(self, obj: Any, group=sm_dist.group.WORLD): # always wrap into a list so list can be brodcasted. obj = [obj] @@ -57,7 +57,7 @@ def broadcast(self, obj: Any, group=_group.WORLD): if self.rank != 0: obj = [None] * len(obj) - _broadcast_object_list(obj, self.rank, 0, group=group or _group.WORLD) + _broadcast_object_list(obj, self.rank, 0, group=group) return obj[0] From 1ac88a8dc4f9ec377bc4655899c48356aaa04d48 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Wed, 14 Jul 2021 18:24:39 +0530 Subject: [PATCH 43/43] Update logger --- pytorch_lightning/loggers/tensorboard.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index ea0937016550d..e9a7b821fe489 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -266,16 +266,14 @@ def version(self) -> int: return self._version def _get_next_version(self): - root_dir = self.root_dir + root_dir = os.path.join(self.save_dir, self.name) - try: - listdir_info = self._fs.listdir(root_dir) - except OSError: + if not self._fs.isdir(root_dir): log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for listing in listdir_info: + for listing in self._fs.listdir(root_dir): d = listing["name"] bn = os.path.basename(d) if self._fs.isdir(d) and bn.startswith("version_"):