From a7dea0da801f851fec325b5b4341c0854de1954e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 12:16:26 +0000 Subject: [PATCH 01/20] Added changes for RPC plugin --- pytorch_lightning/accelerators/accelerator.py | 19 +++-- .../accelerators/ddp2_accelerator.py | 33 +++++++-- .../accelerators/ddp_accelerator.py | 42 ++++++++--- .../accelerators/ddp_cpu_spawn_accelerator.py | 14 +++- .../accelerators/ddp_hpc_accelerator.py | 30 ++++++-- .../accelerators/ddp_spawn_accelerator.py | 33 +++++++-- .../accelerators/horovod_accelerator.py | 4 ++ .../accelerators/tpu_accelerator.py | 4 ++ .../callbacks/model_checkpoint.py | 9 ++- pytorch_lightning/core/optimizer.py | 12 ++++ pytorch_lightning/distributed/dist.py | 20 +++--- pytorch_lightning/plugins/ddp_plugin.py | 21 +++++- pytorch_lightning/plugins/rpc_plugin.py | 72 +++++++++++++++++++ pytorch_lightning/trainer/data_loading.py | 15 +--- tests/backends/test_accelerator_connector.py | 7 +- 15 files changed, 280 insertions(+), 55 deletions(-) create mode 100644 pytorch_lightning/plugins/rpc_plugin.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c5b744c3384ec..efcc8ef9b9b45 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,20 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from contextlib import contextmanager from enum import Enum from typing import Any, Optional, Union import torch import torch.distributed as torch_distrib +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict if torch.distributed.is_available(): @@ -146,10 +143,10 @@ def setup_optimizers(self, model): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies - def init_ddp_connection( + def init_distributed_connection( self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True ) -> None: - self.ddp_plugin.init_ddp_connection( + self.ddp_plugin.init_distributed_connection( self.trainer, self.cluster_environment, global_rank, @@ -222,6 +219,16 @@ def __setstate__(self, d): def on_save(self, checkpoint): return checkpoint + @property + def rpc_enabled(self): + if self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin): + return True + return False + + @property + def distributed_sampler_kwargs(self): + raise NotImplementedError + @contextmanager def block_ddp_plugin_sync_behaviour(self): """ diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..90e8d7a3e7efa 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -23,6 +23,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import 
RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available @@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx): def broadcast(self, obj, src=0): return self.dist.broadcast(obj) - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = process_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -133,16 +136,26 @@ def ddp_train(self, process_idx, mp_queue, model): # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_ddp_connection( + self.init_distributed_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + return + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -158,12 +171,14 @@ def ddp_train(self, process_idx, mp_queue, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -189,7 +204,7 @@ def ddp_train(self, process_idx, mp_queue, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -219,3 +234,13 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 687b5c21874fb..52356d4f73dad 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -27,6 +27,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -162,8 +163,11 @@ def _step(self, args): return output 
def barrier(self, name: Optional[str] = None): - if torch_distrib.is_initialized(): - torch_distrib.barrier() + if self.rpc_enabled: + # Allow RPC to handle barrier on main RPC processes + self.ddp_plugin.barrier() + elif torch_distrib.is_initialized(): + torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group) def _check_can_spawn_children(self): if self._has_spawned_children: @@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx): self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank] torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -192,12 +198,12 @@ def on_train_end(self): def early_stopping_should_stop(self, pl_module): stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device) torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM) - torch_distrib.barrier() + self.barrier('early_stopping') should_stop = stop == self.trainer.world_size return should_stop def broadcast(self, obj, src=0): - return self.dist.broadcast(obj) + return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group) def ddp_train(self, process_idx, model): """ @@ -226,16 +232,26 @@ def ddp_train(self, process_idx, model): # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_ddp_connection( + self.init_distributed_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + return + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -251,7 +267,7 @@ def ddp_train(self, process_idx, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well @@ -284,7 +300,7 @@ def ddp_train(self, process_idx, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -317,3 +333,13 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 
982da2f53216b..4e7a322f1d69b 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -24,6 +24,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import ( find_free_network_port, @@ -101,12 +102,19 @@ def ddp_train(self, process_idx, mp_queue, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_ddp_connection( + self.init_distributed_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + return + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -128,6 +136,8 @@ def ddp_train(self, process_idx, mp_queue, model): # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -221,7 +231,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): mp_queue.put(results) def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..5842f3518b39a 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -23,6 +23,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available @@ -62,9 +63,11 @@ def set_world_ranks(self, process_idx): self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx): + def init_device(self, process_idx): self.trainer.root_gpu = process_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -130,12 +133,19 @@ def ddp_train(self, process_idx, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_ddp_connection( + self.init_distributed_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not 
self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + return + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -151,12 +161,14 @@ def ddp_train(self, process_idx, model): model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -183,7 +195,7 @@ def ddp_train(self, process_idx, model): return results def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -213,3 +225,13 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..dda1c8aec6294 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -25,6 +25,7 @@ from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed import LightningDistributed +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import load as pl_load @@ -109,16 +110,26 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set warning rank rank_zero_only.rank = self.trainer.global_rank + # Initialize cuda device + self.init_device(process_idx, is_master) + # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_ddp_connection( + self.init_distributed_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks ) + if isinstance(self.ddp_plugin, RPCPlugin): + if not self.ddp_plugin.is_main_rpc_process: + self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.exit_rpc_process() + return + self.ddp_plugin.on_main_rpc_connection(self.trainer) + # call setup after the ddp process has connected self.trainer.call_setup_hook(model) @@ -134,12 +145,14 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 model = self.configure_sync_batchnorm(model) # move the model to the correct device - self.model_to_device(model, process_idx, is_master) + self.model_to_device(model) # CHOOSE OPTIMIZER # allow for lr schedulers as well 
self.setup_optimizers(model) + self.ddp_plugin.on_after_setup_optimizers(self.trainer) + # set model properties before going into wrapper self.trainer.model_connector.copy_trainer_model_properties(model) @@ -174,10 +187,12 @@ def set_world_ranks(self, process_idx): self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes - def model_to_device(self, model, process_idx, is_master): + def init_device(self, process_idx, is_master): gpu_idx = self.trainer.data_parallel_device_ids[self.trainer.local_rank] self.trainer.root_gpu = gpu_idx torch.cuda.set_device(self.trainer.root_gpu) + + def model_to_device(self, model): model.cuda(self.trainer.root_gpu) def get_device_ids(self): @@ -248,7 +263,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): mp_queue.put(last_path) def configure_ddp( - self, model: LightningModule, device_ids: List[int] + self, model: LightningModule, device_ids: List[int] ) -> DistributedDataParallel: model = self.ddp_plugin.configure_ddp(model, device_ids) return model @@ -278,3 +293,13 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 460f5a83d2582..364415de47e63 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -206,3 +206,7 @@ def sync_tensor(self, # sync all processes before reduction hvd.join() return hvd.allreduce(tensor, op=reduce_op) + + @property + def distributed_sampler_kwargs(self): + return dict(num_replicas=hvd.size(), rank=hvd.rank()) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index cd6b99fa64eef..01b876ade94ad 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -364,3 +364,7 @@ def on_save(self, checkpoint): https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors """ return move_data_to_device(checkpoint, torch.device("cpu")) + + @property + def distributed_sampler_kwargs(self): + return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index eb669736ada3a..1354f7f5056b3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -33,6 +33,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -548,7 +549,13 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) 
) last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") - self._save_model(last_filepath, trainer, pl_module) + accelerator_backend = trainer.accelerator_backend + + if accelerator_backend is not None and accelerator_backend.rpc_enabled: + # RPCPlugin manages saving all model states + accelerator_backend.ddp_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module) + else: + self._save_model(last_filepath, trainer, pl_module) if ( self.last_model_path and self.last_model_path != last_filepath diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index f07f467810c09..37bb1c286dee8 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -102,6 +102,18 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n optimizer = self._optimizer model = trainer.get_model() + accelerator_backend = trainer.accelerator_backend + if accelerator_backend is not None and accelerator_backend.rpc_enabled: + if accelerator_backend.ddp_plugin.is_main_rpc_process: + # Initialize optimizer step on main process + accelerator_backend.ddp_plugin.optimizer_step( + model=model, + lightning_optimizer=self, + closure=closure, + *args, + **kwargs + ) + if trainer.on_tpu: with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 37706523c8fdd..b66e9bd390cde 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -15,7 +15,7 @@ import torch from typing import Any from torch import distributed as torch_distrib - +from torch.distributed import group class LightningDistributed: @@ -23,27 +23,27 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any): + def broadcast(self, obj: Any, group: group.WORLD = group.WORLD): if self.rank == 0: - self._emit(obj) + self._emit(obj, group) else: - obj = self._receive() + obj = self._receive(group) return obj - def _emit(self, obj): + def _emit(self, obj: Any, group: group.WORLD = group.WORLD): buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) length_tensor = torch.tensor([len(data)]).long().to(self.device) - length_tensor = torch_distrib.broadcast(length_tensor, src=0) + length_tensor = torch_distrib.broadcast(length_tensor, src=0, group=group) data_tensor = torch.ByteTensor(data).to(self.device) - data_tensor = torch_distrib.broadcast(data_tensor, src=0) + data_tensor = torch_distrib.broadcast(data_tensor, src=0, group=group) - def _receive(self): + def _receive(self, group: group.WORLD = group.WORLD): length_tensor = torch.tensor([0]).long().to(self.device) - torch_distrib.broadcast(length_tensor, src=0) + torch_distrib.broadcast(length_tensor, src=0, group=group) data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - torch_distrib.broadcast(data_tensor, src=0) + torch_distrib.broadcast(data_tensor, src=0, group=group) buffer = io.BytesIO(data_tensor.cpu().numpy()) obj = torch.load(buffer) return obj diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 7e481dfade421..664e9eec01609 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -70,7 +70,7 @@ def configure_ddp(self, model, device_ids): ) return model - def 
init_ddp_connection( + def init_distributed_connection( self, trainer, cluster_environment, @@ -112,6 +112,13 @@ def on_before_forward(self, model, *args): def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() + def on_after_setup_optimizers(self, trainer): + """ + Called after optimizers have been set-up. This is useful for doing any configuration options in RPC, or + state sharding. + """ + pass + def get_model_from_plugin( self, model: Union[LightningDistributedDataParallel, LightningModule] @@ -148,3 +155,15 @@ def on_before_manual_backward(self, model: LightningDistributedDataParallel, out def on_after_manual_backward(self, model: LightningDistributedDataParallel): model.reducer_reset_hooks() + + def distributed_sampler_kwargs(self, distributed_sampler_kwargs): + return distributed_sampler_kwargs + + @property + def data_parallel_group(self) -> torch_distrib.group: + """ + Return the group that this process exists in. By default, this is the world size. + Useful for when additional parallel groups have been created, to select certain processes. + Returns: The ProcessGroup this process exists in. + """ + return torch_distrib.group.WORLD diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py new file mode 100644 index 0000000000000..111442ec71314 --- /dev/null +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from torch.distributed import rpc + +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin + + +class RPCPlugin(DDPPlugin): + """ + Backbone for RPC Plugins built on top of DDP. + RPC introduces different communication behaviour than DDP: unlike DDP, processes are not necessarily + required to run the same code as the main process. + This leads to edge cases where logic needs to be re-defined. This class contains the special cases + that need to be addressed when building custom RPC Plugins that rely on RPC communication. 
+ """ + + def __init__(self, **kwargs): + self.rpc_initialized = False + super().__init__(**kwargs) + + def init_rpc_connection(self, + global_rank: int, + world_size: int): + os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') + rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + self.rpc_initialized = True + + def rpc_save_model(self, + save_model_fn, + last_filepath, + trainer, + pl_module): + raise NotImplementedError + + def on_main_rpc_connection(self, trainer): + raise NotImplementedError + + def on_exit_rpc_process(self, trainer): + self.exit_rpc_process() + + def exit_rpc_process(self): + if self.rpc_initialized: + torch.distributed.rpc.shutdown() + self.rpc_initialized = False + + def optimizer_step(self, + model, + lightning_optimizer, + closure, + *args, + **kwargs): + raise NotImplementedError + + def is_main_rpc_process(self): + raise NotImplementedError + + def barrier(self): + raise NotImplementedError diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 4a7b14d0b1fe9..3813e21ef7cdf 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -131,20 +131,7 @@ def replace_sampler(self, dataloader, sampler): return dataloader def _get_distributed_sampler(self, dataloader, shuffle): - if self.use_tpu: - kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - elif self.use_horovod: - kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) - else: - world_size = { - "ddp": self.num_nodes * self.num_processes, - "ddp_spawn": self.num_nodes * self.num_processes, - "ddp2": self.num_nodes, - "ddp_cpu": self.num_processes * self.num_nodes - } - assert self.distributed_backend is not None - kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) - + kwargs = self.accelerator_backend.distributed_sampler_kwargs kwargs['shuffle'] = shuffle and not self.overfit_batches sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index 7eeada3d5ddd1..d5521241fdd4d 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -288,6 +288,7 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(tmpdir): """ Test that we choose the custom cluster even when SLURM or TE flags are around """ + class CustomCluster(ClusterEnvironment): def master_address(self): return 'asdf' @@ -322,7 +323,11 @@ def on_fit_start(self, trainer, pl_module): @mock.patch('torch.cuda.device_count', return_value=0) def test_custom_accelerator(tmpdir): class Accel(Accelerator): - def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True) -> None: + def init_distributed_connection( + self, + global_rank: int, + world_size: int, + is_slurm_managing_tasks: bool = True) -> None: pass class CB(Callback): From 836073bf91ce2e1aff8c170155d7245fe735bb54 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 12:36:17 +0000 Subject: [PATCH 02/20] Add missing kwargs --- .../accelerators/ddp_cpu_spawn_accelerator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 4e7a322f1d69b..fe3d037f2b2a1 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ 
b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -261,3 +261,13 @@ def sync_tensor(self, def get_reference_model(self, model) -> LightningModule: return self.ddp_plugin.get_model_from_plugin(model) + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs From 075bc9b0ad3f112186b143a77d4cc786301fc137 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 12:41:20 +0000 Subject: [PATCH 03/20] Fix code format --- pytorch_lightning/distributed/dist.py | 1 + pytorch_lightning/plugins/ddp_plugin.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index b66e9bd390cde..50374ea923336 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -17,6 +17,7 @@ from torch import distributed as torch_distrib from torch.distributed import group + class LightningDistributed: def __init__(self, rank=None, device=None): diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 664e9eec01609..14f2a0271a1f1 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -117,7 +117,6 @@ def on_after_setup_optimizers(self, trainer): Called after optimizers have been set-up. This is useful for doing any configuration options in RPC, or state sharding. """ - pass def get_model_from_plugin( self, From 4a713a594d95ea2e384aee58f049153e68be8492 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 13:20:07 +0000 Subject: [PATCH 04/20] Loading refactors by introducing is_distributed var, fix optimizer step flow --- pytorch_lightning/accelerators/accelerator.py | 5 +++- .../accelerators/cpu_accelerator.py | 4 ++++ .../accelerators/ddp2_accelerator.py | 4 ++++ .../accelerators/ddp_accelerator.py | 4 ++++ .../accelerators/ddp_cpu_spawn_accelerator.py | 4 ++++ .../accelerators/ddp_hpc_accelerator.py | 4 ++++ .../accelerators/ddp_spawn_accelerator.py | 4 ++++ .../accelerators/dp_accelerator.py | 4 ++++ .../accelerators/gpu_accelerator.py | 4 ++++ .../accelerators/horovod_accelerator.py | 4 ++++ .../accelerators/tpu_accelerator.py | 4 ++++ pytorch_lightning/core/optimizer.py | 23 +++++++++---------- pytorch_lightning/plugins/rpc_plugin.py | 11 ++++----- pytorch_lightning/trainer/data_loading.py | 2 +- 14 files changed, 61 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index efcc8ef9b9b45..2b7ae6567be76 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -229,6 +229,10 @@ def rpc_enabled(self): def distributed_sampler_kwargs(self): raise NotImplementedError + @property + def is_distributed(self): + raise NotImplementedError + @contextmanager def block_ddp_plugin_sync_behaviour(self): """ @@ -239,7 +243,6 @@ def block_ddp_plugin_sync_behaviour(self): cm = self.ddp_plugin.block_backward_sync(self.trainer.model) if self.ddp_plugin else None yield cm - # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos... 
class BackendType(Enum): DP = 'dp' diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index fe0ab59fb554f..4bd638cc9623b 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -90,3 +90,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def is_distributed(self): + return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 90e8d7a3e7efa..26b560f559985 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -244,3 +244,7 @@ def distributed_sampler_kwargs(self): if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) return distributed_sampler_kwargs + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 52356d4f73dad..2673ca77d5c0d 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -343,3 +343,7 @@ def distributed_sampler_kwargs(self): if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) return distributed_sampler_kwargs + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index fe3d037f2b2a1..1bfd5046bf7b3 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -271,3 +271,7 @@ def distributed_sampler_kwargs(self): if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) return distributed_sampler_kwargs + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 5842f3518b39a..f6df7b62bd586 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -235,3 +235,7 @@ def distributed_sampler_kwargs(self): if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) return distributed_sampler_kwargs + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index dda1c8aec6294..ad4449cabf738 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -303,3 +303,7 @@ def distributed_sampler_kwargs(self): if self.ddp_plugin is not None: distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) return distributed_sampler_kwargs + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4b4e1eac8a66c..cb2e812227e2d 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ 
b/pytorch_lightning/accelerators/dp_accelerator.py @@ -181,3 +181,7 @@ def get_reference_model(self, model) -> LightningModule: if isinstance(model, LightningDataParallel): return model.module return model + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index b12d275c8ac26..2ba72b12882be 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -129,3 +129,7 @@ def sync_tensor(self, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: return tensor + + @property + def is_distributed(self): + return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 364415de47e63..540ff2a3e408a 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -210,3 +210,7 @@ def sync_tensor(self, @property def distributed_sampler_kwargs(self): return dict(num_replicas=hvd.size(), rank=hvd.rank()) + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 01b876ade94ad..fc0f9932c5f12 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -368,3 +368,7 @@ def on_save(self, checkpoint): @property def distributed_sampler_kwargs(self): return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + @property + def is_distributed(self): + return True diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 37bb1c286dee8..dc63231ba6ccb 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -102,18 +102,6 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n optimizer = self._optimizer model = trainer.get_model() - accelerator_backend = trainer.accelerator_backend - if accelerator_backend is not None and accelerator_backend.rpc_enabled: - if accelerator_backend.ddp_plugin.is_main_rpc_process: - # Initialize optimizer step on main process - accelerator_backend.ddp_plugin.optimizer_step( - model=model, - lightning_optimizer=self, - closure=closure, - *args, - **kwargs - ) - if trainer.on_tpu: with trainer.profiler.profile(profiler_name): xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs}) @@ -125,6 +113,17 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n with trainer.profiler.profile(profiler_name): optimizer.step(closure=closure, *args, **kwargs) + accelerator_backend = trainer.accelerator_backend + if accelerator_backend is not None and accelerator_backend.rpc_enabled: + if accelerator_backend.ddp_plugin.is_main_rpc_process: + # Initialize optimizer step on main process + accelerator_backend.ddp_plugin.worker_optimizer_step( + model=model, + opt_idx=self._optimizer_idx, + *args, + **kwargs + ) + trainer.train_loop.on_before_zero_grad(self) model.optimizer_zero_grad( diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 111442ec71314..ca9cd794955ca 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -57,12 +57,11 @@ def exit_rpc_process(self): torch.distributed.rpc.shutdown() self.rpc_initialized = 
False - def optimizer_step(self, - model, - lightning_optimizer, - closure, - *args, - **kwargs): + def worker_optimizer_step(self, + model, + opt_idx, + *args, + **kwargs): raise NotImplementedError def is_main_rpc_process(self): diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 3813e21ef7cdf..f87b544d42679 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -100,7 +100,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + is_in_dist = self.accelerator_backend.is_distributed if self.accelerator_backend is not None else False need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): From 2475064ea6e20b092176aa3b654435d43874cb46 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 13:21:52 +0000 Subject: [PATCH 05/20] Add rpc guard --- pytorch_lightning/plugins/rpc_plugin.py | 5 ++++- pytorch_lightning/utilities/__init__.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index ca9cd794955ca..f44579c7128eb 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -14,9 +14,12 @@ import os import torch -from torch.distributed import rpc from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from pytorch_lightning.utilities import RPC_AVAILABLE + +if RPC_AVAILABLE: + from torch.distributed import rpc class RPCPlugin(DDPPlugin): diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 1e2eeea9f456c..610c33c6c2e8e 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -51,6 +51,7 @@ def _module_available(module_path: str) -> bool: TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') +RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps From 1abf7723da5e0a6ca543736f31e216f31da680be Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 13:29:02 +0000 Subject: [PATCH 06/20] Added docstrings and typing --- pytorch_lightning/plugins/rpc_plugin.py | 52 ++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index f44579c7128eb..7765e04a7a1de 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -14,6 +14,7 @@ import os import torch +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import RPC_AVAILABLE @@ -37,7 +38,7 @@ def __init__(self, **kwargs): def init_rpc_connection(self, global_rank: int, - world_size: int): + world_size: int) -> None: os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) self.rpc_initialized = True @@ 
-46,13 +47,32 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer, - pl_module): + pl_module) -> None: + """ + Override to save the model to disk. + This is required as the main process needs to handle aggregating model states from RPC processes. + Args: + save_model_fn: The saving function used to save the final model. + last_filepath: The filepath to save the model to. + trainer: The trainer object. + pl_module: The LightningModule. + """ raise NotImplementedError - def on_main_rpc_connection(self, trainer): + def on_main_rpc_connection(self, trainer) -> None: + """ + Called when the main RPC connection has been established. + Args: + trainer: The trainer object. + """ raise NotImplementedError - def on_exit_rpc_process(self, trainer): + def on_exit_rpc_process(self, trainer) -> None: + """ + Called to exit RPC process that is being managed by main process. + Args: + trainer: The trainer object. + """ self.exit_rpc_process() def exit_rpc_process(self): @@ -61,14 +81,28 @@ def exit_rpc_process(self): self.rpc_initialized = False def worker_optimizer_step(self, - model, - opt_idx, + model: LightningModule, + opt_idx: int, *args, - **kwargs): + **kwargs) -> None: + """ + Called when the optimizer step is run on the main process. Used to signal any RPC workers to run their optimizer step. + Args: + model: The LightningModule. + opt_idx: The index of the optimizer to carry out the step on. + """ raise NotImplementedError - def is_main_rpc_process(self): + @property + def is_main_rpc_process(self) -> bool: + """ + Override to add logic that determines whether the current process is the main RPC process. + """ raise NotImplementedError - def barrier(self): + def barrier(self) -> None: + """ + Override to define distributed sync communication. This needs to be handled differently due to + the RPC connection managing certain processes at the same time. + """ raise NotImplementedError From a12d396e9b89e00ffa34af71b1d18f4c426f22e7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 15:55:20 +0000 Subject: [PATCH 07/20] resolve comments --- pytorch_lightning/accelerators/accelerator.py | 7 +++---- pytorch_lightning/accelerators/dp_accelerator.py | 10 ++++++++++ pytorch_lightning/plugins/rpc_plugin.py | 5 +++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2b7ae6567be76..6df6a9f1661e8 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -17,10 +17,10 @@ import torch import torch.distributed as torch_distrib -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.parsing import AttributeDict @@ -221,9 +221,7 @@ def on_save(self, checkpoint): @property def rpc_enabled(self): - if self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin): - return True - return False + return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin) @property def distributed_sampler_kwargs(self): @@ -243,6 +241,7 @@ def block_ddp_plugin_sync_behaviour(self): cm = self.ddp_plugin.block_backward_sync(self.trainer.model) if self.ddp_plugin else None yield cm + # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos... 
class BackendType(Enum): DP = 'dp' diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index cb2e812227e2d..2bda5e27a0b67 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -185,3 +185,13 @@ def get_reference_model(self, model) -> LightningModule: @property def is_distributed(self): return True + + @property + def distributed_sampler_kwargs(self): + distributed_sampler_kwargs = dict( + num_replicas=self.trainer.num_nodes, + rank=self.trainer.global_rank + ) + if self.ddp_plugin is not None: + distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) + return distributed_sampler_kwargs diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index 7765e04a7a1de..a14a66ce2d964 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Optional import torch -from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.utilities import RPC_AVAILABLE @@ -100,7 +101,7 @@ def is_main_rpc_process(self) -> bool: """ raise NotImplementedError - def barrier(self) -> None: + def barrier(self, name: Optional[str] = None) -> None: """ Override to define distributed sync communication. This needs to be handled differently due to the RPC connection managing certain processes at the same time. From 691dacc87dac8c0c4b8b850bb6eb28a03844c7c4 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 16:08:35 +0000 Subject: [PATCH 08/20] Add additional rpc hook, refactor name of exit process hook for clarity --- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- .../accelerators/ddp_cpu_spawn_accelerator.py | 5 +++-- .../accelerators/ddp_hpc_accelerator.py | 2 +- .../accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/plugins/rpc_plugin.py | 13 +++++++++++-- 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 26b560f559985..16c187b5f9e0d 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -151,7 +151,7 @@ def ddp_train(self, process_idx, mp_queue, model): if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() return self.ddp_plugin.on_main_rpc_connection(self.trainer) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 2673ca77d5c0d..8c3b14439a669 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -247,7 +247,7 @@ def ddp_train(self, process_idx, model): if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() 
return self.ddp_plugin.on_main_rpc_connection(self.trainer) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 1bfd5046bf7b3..33206bdf4dcdf 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -110,9 +110,10 @@ def ddp_train(self, process_idx, mp_queue, model): if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() - return + if self.ddp_plugin.return_after_exit_rpc_process: + return self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index f6df7b62bd586..f5b5b67ad5806 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -141,7 +141,7 @@ def ddp_train(self, process_idx, model): if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() return self.ddp_plugin.on_main_rpc_connection(self.trainer) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index ad4449cabf738..fd5889a3b5c7d 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -125,7 +125,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 if isinstance(self.ddp_plugin, RPCPlugin): if not self.ddp_plugin.is_main_rpc_process: - self.ddp_plugin.on_exit_rpc_process(self.trainer) + self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() return self.ddp_plugin.on_main_rpc_connection(self.trainer) diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index a14a66ce2d964..776ac17c3d4eb 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -68,9 +68,9 @@ def on_main_rpc_connection(self, trainer) -> None: """ raise NotImplementedError - def on_exit_rpc_process(self, trainer) -> None: + def on_accelerator_exit_rpc_process(self, trainer) -> None: """ - Called to exit RPC process that is being managed by main process. + Called to exit the RPC process within the accelerator, which is being managed by the main process. Args: trainer: The trainer object. """ @@ -81,6 +81,15 @@ def exit_rpc_process(self): torch.distributed.rpc.shutdown() self.rpc_initialized = False + @property + def return_after_exit_rpc_process(self) -> bool: + """ + Override to decide whether to skip the train/test function after RPC shutdown has completed. + Usually RPC shutdown is a join/exit function; afterwards we want to exit the process. + Returns: Whether to return after RPC exit. 
+ """ + raise NotImplementedError + def worker_optimizer_step(self, model: LightningModule, opt_idx: int, From 3569e8fa81802d88daec7fa1e9be68755d361e06 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 16:16:21 +0000 Subject: [PATCH 09/20] remove annotation --- pytorch_lightning/plugins/ddp_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index 14f2a0271a1f1..f13e519fea31e 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import Any, Dict, List, Union, Optional +from typing import Any, Dict, List, Optional, Union import torch.distributed as torch_distrib from torch.optim import Optimizer @@ -159,7 +159,7 @@ def distributed_sampler_kwargs(self, distributed_sampler_kwargs): return distributed_sampler_kwargs @property - def data_parallel_group(self) -> torch_distrib.group: + def data_parallel_group(self): """ Return the group that this process exists in. By default, this is the world size. Useful for when additional parallel groups have been created, to select certain processes. From 168f9b616e1e8c7084a7d738c5346208d55b3dd7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 16:50:24 +0000 Subject: [PATCH 10/20] Modify behaviour to allow optional return, add test for rpc plugin --- .../accelerators/ddp2_accelerator.py | 6 +- .../accelerators/ddp_accelerator.py | 6 +- .../accelerators/ddp_cpu_spawn_accelerator.py | 3 +- .../accelerators/ddp_hpc_accelerator.py | 6 +- .../accelerators/ddp_spawn_accelerator.py | 6 +- tests/plugins/test_rpc_plugin.py | 124 ++++++++++++++++++ tests/special_tests.sh | 1 + 7 files changed, 143 insertions(+), 9 deletions(-) create mode 100644 tests/plugins/test_rpc_plugin.py diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 16c187b5f9e0d..ed5de4af90c5f 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -153,8 +153,10 @@ def ddp_train(self, process_idx, mp_queue, model): if not self.ddp_plugin.is_main_rpc_process: self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() - return - self.ddp_plugin.on_main_rpc_connection(self.trainer) + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 8c3b14439a669..65d04911ac4fc 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -249,8 +249,10 @@ def ddp_train(self, process_idx, model): if not self.ddp_plugin.is_main_rpc_process: self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() - return - self.ddp_plugin.on_main_rpc_connection(self.trainer) + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 
33206bdf4dcdf..2f699f231a3e6 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -114,7 +114,8 @@ def ddp_train(self, process_idx, mp_queue, model): self.ddp_plugin.exit_rpc_process() if self.ddp_plugin.return_after_exit_rpc_process: return - self.ddp_plugin.on_main_rpc_connection(self.trainer) + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index f5b5b67ad5806..d3346d697e4da 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -143,8 +143,10 @@ def ddp_train(self, process_idx, model): if not self.ddp_plugin.is_main_rpc_process: self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() - return - self.ddp_plugin.on_main_rpc_connection(self.trainer) + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index fd5889a3b5c7d..e5cb205f87365 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -127,8 +127,10 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 if not self.ddp_plugin.is_main_rpc_process: self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer) self.ddp_plugin.exit_rpc_process() - return - self.ddp_plugin.on_main_rpc_connection(self.trainer) + if self.ddp_plugin.return_after_exit_rpc_process: + return + else: + self.ddp_plugin.on_main_rpc_connection(self.trainer) # call setup after the ddp process has connected self.trainer.call_setup_hook(model) diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py new file mode 100644 index 0000000000000..8c5584ce66ad4 --- /dev/null +++ b/tests/plugins/test_rpc_plugin.py @@ -0,0 +1,124 @@ +import os +from typing import Optional +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin +from pytorch_lightning.utilities import RPC_AVAILABLE +from tests.base.boring_model import BoringModel + + +@mock.patch.dict( + os.environ, + { + "CUDA_VISIBLE_DEVICES": "0,1", + "SLURM_NTASKS": "2", + "SLURM_JOB_NAME": "SOME_NAME", + "SLURM_NODEID": "0", + "LOCAL_RANK": "0", + "SLURM_LOCALID": "0", + }, +) +@mock.patch("torch.cuda.device_count", return_value=2) +@pytest.mark.parametrize( + ["ddp_backend", "gpus", "num_processes"], + [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)], +) +@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available") +def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes): + class CB(Callback): + def on_fit_start(self, trainer, pl_module): + assert isinstance(trainer.accelerator_backend.ddp_plugin, RPCPlugin) + raise RuntimeError('finished plugin check') + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=gpus, + num_processes=num_processes, + 
distributed_backend=ddp_backend, + callbacks=[CB()], + plugins=[RPCPlugin()] + ) + + with pytest.raises(RuntimeError, match='finished plugin check'): + trainer.fit(model) + + +class CustomRPCPlugin(RPCPlugin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.rpc_save_model_count = 0 + self.on_main_rpc_connect_count = 0 + self.worker_optimizer_step_count = 0 + self.is_main_rpc_process_count = 0 + self.on_exit_rpc_process_count = 0 + self.return_after_exit_rpc_process_count = 0 + + def on_accelerator_exit_rpc_process(self, trainer) -> None: + self.on_exit_rpc_process_count += 1 + + def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: + self.rpc_save_model_count += 1 + + def on_main_rpc_connection(self, trainer) -> None: + self.on_main_rpc_connect_count += 1 + + def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: + self.worker_optimizer_step_count += 1 + + @property + def is_main_rpc_process(self) -> bool: + self.is_main_rpc_process_count += 1 + return torch.distributed.get_rank() == 0 + + @property + def return_after_exit_rpc_process(self) -> bool: + self.return_after_exit_rpc_process_count += 1 + return False + + def barrier(self, name: Optional[str] = None) -> None: + return + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_rpc_function_calls_ddp(tmpdir): + model = BoringModel() + plugin = CustomRPCPlugin() + max_epochs = 2 + limit_train_batches = 2 + trainer = Trainer( + limit_train_batches=limit_train_batches, + limit_val_batches=2, + max_epochs=max_epochs, + gpus=2, + distributed_backend='ddp', + plugins=[plugin] + ) + + trainer.fit(model) + if trainer.global_rank == 0: # Main process + assert plugin.rpc_save_model_count == max_epochs + assert plugin.on_main_rpc_connect_count == 1 + assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches + # Call once at init, and at optim step + assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count + assert plugin.on_exit_rpc_process_count == 0 + else: # Worker process + assert plugin.rpc_save_model_count == 0 + assert plugin.on_main_rpc_connect_count == 0 + # Never signaled by worker, only by main process + assert plugin.worker_optimizer_step_count == 0 + # Call once at init, and at optim step + assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count + # Called at init + assert plugin.on_exit_rpc_process_count == 1 diff --git a/tests/special_tests.sh b/tests/special_tests.sh index a87e380dbe275..7ea0f77ca2971 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -15,3 +15,4 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp +python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp \ No newline at end of file From 3e10f8f684e307a7d9d3f52db25f2d81bbf95fc6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 17:13:03 +0000 Subject: [PATCH 11/20] resolve tests --- 
pytorch_lightning/trainer/data_loading.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f87b544d42679..dadde6cddf533 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,14 +16,14 @@ import platform from abc import ABC from copy import deepcopy -from typing import Union, List, Tuple, Callable, Optional, Iterable +from typing import Callable, Iterable, List, Optional, Tuple, Union from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE, HOROVOD_AVAILABLE +from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -131,7 +131,23 @@ def replace_sampler(self, dataloader, sampler): return dataloader def _get_distributed_sampler(self, dataloader, shuffle): - kwargs = self.accelerator_backend.distributed_sampler_kwargs + if self.accelerator_backend is not None: + kwargs = self.accelerator_backend.distributed_sampler_kwargs + else: + # todo: Find a way to remove this part + if self.use_tpu: + kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + elif self.use_horovod: + kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + else: + world_size = { + "ddp": self.num_nodes * self.num_processes, + "ddp_spawn": self.num_nodes * self.num_processes, + "ddp2": self.num_nodes, + "ddp_cpu": self.num_processes * self.num_nodes + } + assert self.distributed_backend is not None + kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) kwargs['shuffle'] = shuffle and not self.overfit_batches sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler From b4236dadf012b48c9029458c76ec596f2c1ed5f4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 17:24:29 +0000 Subject: [PATCH 12/20] rename is_ddp_based --- pytorch_lightning/accelerators/accelerator.py | 3 ++- pytorch_lightning/accelerators/cpu_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- .../accelerators/ddp_cpu_spawn_accelerator.py | 2 +- .../accelerators/ddp_hpc_accelerator.py | 2 +- .../accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/dp_accelerator.py | 14 ++------------ pytorch_lightning/accelerators/gpu_accelerator.py | 2 +- .../accelerators/horovod_accelerator.py | 4 ++-- pytorch_lightning/accelerators/tpu_accelerator.py | 2 +- pytorch_lightning/trainer/data_loading.py | 4 ++-- 12 files changed, 16 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 6df6a9f1661e8..83716bab30d21 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -31,6 +31,7 @@ class ReduceOp: SUM = None + class Accelerator(object): def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None): @@ -228,7 +229,7 @@ def 
distributed_sampler_kwargs(self): raise NotImplementedError @property - def is_distributed(self): + def is_ddp_based(self): raise NotImplementedError @contextmanager diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 4bd638cc9623b..77904e265ecaa 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -92,5 +92,5 @@ def sync_tensor(self, return tensor @property - def is_distributed(self): + def is_ddp_based(self): return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index ed5de4af90c5f..26eba03e8b28c 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -248,5 +248,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 65d04911ac4fc..c139bac37f87d 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -347,5 +347,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 2f699f231a3e6..85448e16f8725 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -275,5 +275,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index d3346d697e4da..b7e2e336eb3e6 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -239,5 +239,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index e5cb205f87365..22251fe0a7499 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -307,5 +307,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 2bda5e27a0b67..4309baaf7c917 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -183,15 +183,5 @@ def get_reference_model(self, model) -> LightningModule: return model @property - def is_distributed(self): - return True - - @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict( - num_replicas=self.trainer.num_nodes, - rank=self.trainer.global_rank - ) - if self.ddp_plugin is not None: - distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs) - return 
distributed_sampler_kwargs + def is_ddp_based(self): + return False diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 2ba72b12882be..7b9635511700e 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -131,5 +131,5 @@ def sync_tensor(self, return tensor @property - def is_distributed(self): + def is_ddp_based(self): return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 540ff2a3e408a..c892c9c72bea4 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE +from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType from pytorch_lightning.utilities.distributed import rank_zero_only if HOROVOD_AVAILABLE: @@ -212,5 +212,5 @@ def distributed_sampler_kwargs(self): return dict(num_replicas=hvd.size(), rank=hvd.rank()) @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index fc0f9932c5f12..6d9e8ad1b2a3f 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -370,5 +370,5 @@ def distributed_sampler_kwargs(self): return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) @property - def is_distributed(self): + def is_ddp_based(self): return True diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index dadde6cddf533..cad9510e62c72 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -100,8 +100,8 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - is_in_dist = self.accelerator_backend.is_distributed if self.accelerator_backend is not None else False - need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler) + is_ddp_based = self.accelerator_backend.is_ddp_based if self.accelerator_backend is not None else False + need_dist_sampler = is_ddp_based and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( From ebbdd3eefb1d6c4d2952172202b55a4b16c0cd12 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 17:33:21 +0000 Subject: [PATCH 13/20] update --- pytorch_lightning/accelerators/accelerator.py | 1 - pytorch_lightning/trainer/data_loading.py | 29 ++------------- pytorch_lightning/trainer/properties.py | 37 ++++++++++++++++++- 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 83716bab30d21..9886b210deed3 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -31,7 +31,6 @@ class ReduceOp: SUM = None - class Accelerator(object): def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None): diff --git 
a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index cad9510e62c72..d829eb7a78f7a 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -23,18 +23,12 @@ from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden -if TPU_AVAILABLE: - import torch_xla.core.xla_model as xm - -if HOROVOD_AVAILABLE: - import horovod.torch as hvd - class TrainerDataLoadingMixin(ABC): @@ -100,8 +94,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - is_ddp_based = self.accelerator_backend.is_ddp_based if self.accelerator_backend is not None else False - need_dist_sampler = is_ddp_based and not isinstance(dataloader.sampler, DistributedSampler) + need_dist_sampler = self.is_ddp_based and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( @@ -131,23 +124,7 @@ def replace_sampler(self, dataloader, sampler): return dataloader def _get_distributed_sampler(self, dataloader, shuffle): - if self.accelerator_backend is not None: - kwargs = self.accelerator_backend.distributed_sampler_kwargs - else: - # todo: Find a way to remove this part - if self.use_tpu: - kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - elif self.use_horovod: - kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) - else: - world_size = { - "ddp": self.num_nodes * self.num_processes, - "ddp_spawn": self.num_nodes * self.num_processes, - "ddp2": self.num_nodes, - "ddp_cpu": self.num_processes * self.num_nodes - } - assert self.distributed_backend is not None - kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) + kwargs = self.distributed_sampler_kwargs kwargs['shuffle'] = shuffle and not self.overfit_batches sampler = DistributedSampler(dataloader.dataset, **kwargs) return sampler diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index ba94ec2d95abb..5954c800bd632 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -27,10 +27,16 @@ from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import argparse_utils +from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.model_utils import is_overridden +if TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + +if HOROVOD_AVAILABLE: + import horovod.torch as hvd + class TrainerProperties(ABC): @@ -242,6 +248,35 @@ def __setstate__(self, d): # wrap optimizers in enable_pl_optimzer is True 
self.convert_to_lightning_optimizers() + @property + def is_ddp_based(self): + if self.accelerator_backend is not None: + return self.accelerator_backend.is_ddp_based + return self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu + + @property + def distributed_sampler_kwargs(self): + if self.accelerator_backend is not None: + return self.accelerator_backend.distributed_sampler_kwargs + + if self.use_tpu: + kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) + + elif self.use_horovod: + kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank()) + + else: + world_size = { + "ddp": self.num_nodes * self.num_processes, + "ddp_spawn": self.num_nodes * self.num_processes, + "ddp2": self.num_nodes, + "ddp_cpu": self.num_processes * self.num_nodes + } + assert self.distributed_backend is not None + kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank) + + return kwargs + # Used to represent the concrete type TrainerProperties class methods are called on. _T = TypeVar('_T', bound=TrainerProperties) From aa7cdb2cd132ef0ec82f9f7596b9cc3105653dfb Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 17:49:40 +0000 Subject: [PATCH 14/20] update for windows --- pytorch_lightning/distributed/dist.py | 17 ++++++++++++----- pytorch_lightning/utilities/__init__.py | 1 + 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index 50374ea923336..d909d05c341d6 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -12,10 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -import torch from typing import Any + +import torch from torch import distributed as torch_distrib -from torch.distributed import group + +from pytorch_lightning.utilities import GROUP_AVAILABLE + +WORLD = None +if GROUP_AVAILABLE: + from torch.distributed import group + WORLD = group.WORLD class LightningDistributed: @@ -24,14 +31,14 @@ def __init__(self, rank=None, device=None): self.rank = rank self.device = device - def broadcast(self, obj: Any, group: group.WORLD = group.WORLD): + def broadcast(self, obj: Any, group=WORLD): if self.rank == 0: self._emit(obj, group) else: obj = self._receive(group) return obj - def _emit(self, obj: Any, group: group.WORLD = group.WORLD): + def _emit(self, obj: Any, group=WORLD): buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) @@ -40,7 +47,7 @@ def _emit(self, obj: Any, group: group.WORLD = group.WORLD): data_tensor = torch.ByteTensor(data).to(self.device) data_tensor = torch_distrib.broadcast(data_tensor, src=0, group=group) - def _receive(self, group: group.WORLD = group.WORLD): + def _receive(self, group=WORLD): length_tensor = torch.tensor([0]).long().to(self.device) torch_distrib.broadcast(length_tensor, src=0, group=group) data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 610c33c6c2e8e..0178eea2d60d1 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -52,6 +52,7 @@ def _module_available(module_path: str) -> bool: TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') RPC_AVAILABLE = platform.system() != 'Windows' and 
_module_available('torch.distributed.rpc') +GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.grpup') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps From 85dcd25b78e8e30a7b79757faa563f7eb3ee3b97 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 17:59:55 +0000 Subject: [PATCH 15/20] update --- pytorch_lightning/utilities/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 0178eea2d60d1..3a04a325905a9 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -52,7 +52,7 @@ def _module_available(module_path: str) -> bool: TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc') -GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.grpup') +GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps From 18a37e7a5d75a429f8adeb1e48e5d4ed52f4ba20 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 18:15:48 +0000 Subject: [PATCH 16/20] resolve test --- pytorch_lightning/distributed/dist.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index d909d05c341d6..d3539c93bad6c 100644 --- a/pytorch_lightning/distributed/dist.py +++ b/pytorch_lightning/distributed/dist.py @@ -38,20 +38,26 @@ def broadcast(self, obj: Any, group=WORLD): obj = self._receive(group) return obj + def _broadcast(self, tensor, src=0, group=WORLD): + if group is None: + return torch_distrib.broadcast(tensor, src=src) + else: + return torch_distrib.broadcast(tensor, src=0, group=group) + def _emit(self, obj: Any, group=WORLD): buffer = io.BytesIO() torch.save(obj, buffer) data = bytearray(buffer.getbuffer()) length_tensor = torch.tensor([len(data)]).long().to(self.device) - length_tensor = torch_distrib.broadcast(length_tensor, src=0, group=group) + length_tensor = self._broadcast(length_tensor, src=0, group=group) data_tensor = torch.ByteTensor(data).to(self.device) - data_tensor = torch_distrib.broadcast(data_tensor, src=0, group=group) + data_tensor = self._broadcast(data_tensor, src=0, group=group) def _receive(self, group=WORLD): length_tensor = torch.tensor([0]).long().to(self.device) - torch_distrib.broadcast(length_tensor, src=0, group=group) + self._broadcast(length_tensor, src=0, group=group) data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8).to(self.device) - torch_distrib.broadcast(data_tensor, src=0, group=group) + self._broadcast(data_tensor, src=0, group=group) buffer = io.BytesIO(data_tensor.cpu().numpy()) obj = torch.load(buffer) return obj From 6dae0f4f0a0740e29ee3e371b61edb17fa0557d7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 8 Dec 2020 18:21:05 +0000 Subject: [PATCH 17/20] code smell --- pytorch_lightning/distributed/dist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/distributed/dist.py b/pytorch_lightning/distributed/dist.py index d3539c93bad6c..429121f71feeb 100644 --- a/pytorch_lightning/distributed/dist.py 
+++ b/pytorch_lightning/distributed/dist.py @@ -41,8 +41,7 @@ def broadcast(self, obj: Any, group=WORLD): def _broadcast(self, tensor, src=0, group=WORLD): if group is None: return torch_distrib.broadcast(tensor, src=src) - else: - return torch_distrib.broadcast(tensor, src=0, group=group) + return torch_distrib.broadcast(tensor, src=0, group=group) def _emit(self, obj: Any, group=WORLD): buffer = io.BytesIO() From 975228113bbbf1c4ea8541224a7c9127984de366 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 20:11:21 +0000 Subject: [PATCH 18/20] Revert back to init_ddp_connection for backwards compat --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/plugins/ddp_plugin.py | 2 +- tests/backends/test_accelerator_connector.py | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9886b210deed3..9e63dce5768ec 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -143,10 +143,10 @@ def setup_optimizers(self, model): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies - def init_distributed_connection( + def init_ddp_connection( self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True ) -> None: - self.ddp_plugin.init_distributed_connection( + self.ddp_plugin.init_ddp_connection( self.trainer, self.cluster_environment, global_rank, diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 26eba03e8b28c..748d46584aca6 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -143,7 +143,7 @@ def ddp_train(self, process_idx, mp_queue, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_distributed_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 60d4a51ddb870..c3dd6dda36b34 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -239,7 +239,7 @@ def ddp_train(self, process_idx, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_distributed_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 85448e16f8725..e10c1edd786af 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -102,7 +102,7 @@ def ddp_train(self, process_idx, mp_queue, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - 
self.init_distributed_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index b7e2e336eb3e6..e6ebf83bac298 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -133,7 +133,7 @@ def ddp_train(self, process_idx, model): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_distributed_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 22251fe0a7499..7208040fcdede 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -117,7 +117,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - self.init_distributed_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py index f13e519fea31e..281074cb37813 100644 --- a/pytorch_lightning/plugins/ddp_plugin.py +++ b/pytorch_lightning/plugins/ddp_plugin.py @@ -70,7 +70,7 @@ def configure_ddp(self, model, device_ids): ) return model - def init_distributed_connection( + def init_ddp_connection( self, trainer, cluster_environment, diff --git a/tests/backends/test_accelerator_connector.py b/tests/backends/test_accelerator_connector.py index d5521241fdd4d..551de95c7e480 100644 --- a/tests/backends/test_accelerator_connector.py +++ b/tests/backends/test_accelerator_connector.py @@ -323,7 +323,7 @@ def on_fit_start(self, trainer, pl_module): @mock.patch('torch.cuda.device_count', return_value=0) def test_custom_accelerator(tmpdir): class Accel(Accelerator): - def init_distributed_connection( + def init_ddp_connection( self, global_rank: int, world_size: int, From 35be26bef22c4c2ff2a786038b4ab376dc845ed7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 20:16:37 +0000 Subject: [PATCH 19/20] Swap to explicit name for property --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/accelerators/cpu_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/dp_accelerator.py | 2 +- pytorch_lightning/accelerators/gpu_accelerator.py | 2 +- pytorch_lightning/accelerators/horovod_accelerator.py | 2 +- pytorch_lightning/accelerators/tpu_accelerator.py | 2 +- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/trainer/properties.py | 4 ++-- 13 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9e63dce5768ec..11af9e4d8f91e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ 
b/pytorch_lightning/accelerators/accelerator.py @@ -228,7 +228,7 @@ def distributed_sampler_kwargs(self): raise NotImplementedError @property - def is_ddp_based(self): + def require_distributed_sampler(self): raise NotImplementedError @contextmanager diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 77904e265ecaa..9113331ef0a7d 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -92,5 +92,5 @@ def sync_tensor(self, return tensor @property - def is_ddp_based(self): + def require_distributed_sampler(self): return False diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 748d46584aca6..f47b389faf436 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -248,5 +248,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index c3dd6dda36b34..d3d4c1fa1b766 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -347,5 +347,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index e10c1edd786af..50bd1b7ab9051 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -275,5 +275,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index e6ebf83bac298..50267afa525dc 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -239,5 +239,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 7208040fcdede..7db2d5a309d9c 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -307,5 +307,5 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4309baaf7c917..a7f3c260e682c 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -183,5 +183,5 @@ def get_reference_model(self, model) -> LightningModule: return model @property - def is_ddp_based(self): + def require_distributed_sampler(self): return False diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py 
b/pytorch_lightning/accelerators/gpu_accelerator.py index 7b9635511700e..abc065cd39ed4 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -131,5 +131,5 @@ def sync_tensor(self, return tensor @property - def is_ddp_based(self): + def require_distributed_sampler(self): return False diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index c892c9c72bea4..93983369f17a9 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -212,5 +212,5 @@ def distributed_sampler_kwargs(self): return dict(num_replicas=hvd.size(), rank=hvd.rank()) @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 6d9e8ad1b2a3f..a7752e42a96cf 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -370,5 +370,5 @@ def distributed_sampler_kwargs(self): return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) @property - def is_ddp_based(self): + def require_distributed_sampler(self): return True diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d829eb7a78f7a..3bb444622cebc 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -94,7 +94,7 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: if not is_dataloader or is_iterable_ds: return dataloader - need_dist_sampler = self.is_ddp_based and not isinstance(dataloader.sampler, DistributedSampler) + need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler) if self.replace_sampler_ddp and need_dist_sampler: if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)): raise MisconfigurationException( diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 001453d8007c9..355bbad3a037e 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -249,9 +249,9 @@ def __setstate__(self, d): self.convert_to_lightning_optimizers() @property - def is_ddp_based(self): + def require_distributed_sampler(self): if self.accelerator_backend is not None: - return self.accelerator_backend.is_ddp_based + return self.accelerator_backend.require_distributed_sampler return self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu @property From ca5cce5228dbc6351b41ac400610503f7efe4f34 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 8 Dec 2020 21:06:41 +0000 Subject: [PATCH 20/20] Add missing speed parity increase for CI variability, fix call counts for child process --- benchmarks/test_sharded_parity.py | 1 + tests/plugins/test_rpc_plugin.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index 15a5871ca643b..1240710674c59 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, + max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers ) 
diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py index 8c5584ce66ad4..7411fe9774334 100644 --- a/tests/plugins/test_rpc_plugin.py +++ b/tests/plugins/test_rpc_plugin.py @@ -114,11 +114,11 @@ def test_rpc_function_calls_ddp(tmpdir): assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count assert plugin.on_exit_rpc_process_count == 0 else: # Worker process - assert plugin.rpc_save_model_count == 0 + assert plugin.rpc_save_model_count == max_epochs assert plugin.on_main_rpc_connect_count == 0 # Never signaled by worker, only by main process assert plugin.worker_optimizer_step_count == 0 # Call once at init, and at optim step - assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count + assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches) # Called at init assert plugin.on_exit_rpc_process_count == 1
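
For reference, a minimal sketch of how the hooks introduced in this series are meant to be consumed by a downstream plugin. It mirrors the CustomRPCPlugin used in tests/plugins/test_rpc_plugin.py; the class name and the no-op hook bodies are illustrative only (not part of the patches), and running it requires a 2-GPU machine with torch.distributed.rpc available.

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
    from tests.base.boring_model import BoringModel


    class MyRPCPlugin(RPCPlugin):
        # Illustrative subclass: overrides the hooks added in this series with no-ops.

        def on_main_rpc_connection(self, trainer) -> None:
            # Runs only on the main RPC process once the RPC connection is up.
            pass

        def on_accelerator_exit_rpc_process(self, trainer) -> None:
            # Runs on non-main processes inside the accelerator before RPC shutdown.
            pass

        def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
            # Invoked when the checkpoint callback saves the model.
            pass

        def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
            # Invoked each optimizer step; in the new test this is counted on the main RPC process.
            pass

        def barrier(self, name=None) -> None:
            pass

        @property
        def is_main_rpc_process(self) -> bool:
            return torch.distributed.get_rank() == 0

        @property
        def return_after_exit_rpc_process(self) -> bool:
            # False keeps worker processes inside ddp_train after RPC shutdown completes.
            return False


    if __name__ == '__main__':
        model = BoringModel()
        trainer = Trainer(
            max_epochs=1,
            limit_train_batches=2,
            limit_val_batches=2,
            gpus=2,
            distributed_backend='ddp',
            plugins=[MyRPCPlugin()],
        )
        trainer.fit(model)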