diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index a8e63776f93d8..968469f104cba 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -279,6 +279,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
         if make_optimizer_step:
             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
+        self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
 
     def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index 8393a21104704..3b3cfa50cc045 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -161,6 +161,7 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
             if isinstance(plug, TrainingTypePlugin):
                 if training_type is None:
                     training_type = plug
+
                 else:
                     raise MisconfigurationException(
                         'You can only specify one precision and one training type plugin. '
@@ -191,20 +192,20 @@ def handle_given_plugins(self, plugins: Optional[Sequence]):
 
         self._training_type_plugin = training_type
         self._precision_plugin = precision
-        self._cluster_environment = cluster_environment
+        self._cluster_environment = cluster_environment or self.select_cluster_environment()
 
     @property
     def precision_plugin(self) -> PrecisionPlugin:
         if self._precision_plugin is None:
             self._precision_plugin = self.select_precision_plugin()
-
         return self._precision_plugin
 
     @property
     def training_type_plugin(self) -> TrainingTypePlugin:
         if self._training_type_plugin is None:
             self._training_type_plugin = self.select_training_type_plugin()
-
+        else:
+            self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin)
         return self._training_type_plugin
 
     @property
@@ -283,9 +284,6 @@ def select_precision_plugin(self):
         if self.on_tpu:
             return TPUHalfPrecisionPlugin()
 
-        if isinstance(self.training_type_plugin, RPCPlugin):
-            raise MisconfigurationException
-
         if self.amp_type == "native":
             if not _NATIVE_AMP_AVAILABLE:
                 rank_zero_warn(
@@ -324,9 +322,8 @@ def select_precision_plugin(self):
         raise NotImplementedError("We only support precisions 32 and 16!")
 
     def select_training_type_plugin(self):
-        cluster_environment = self.select_cluster_environment()
         if self.use_ddp2:
-            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=cluster_environment)
+            plugin = DDP2Plugin(parallel_devices=self.parallel_devices, cluster_environment=self._cluster_environment)
         elif self.use_ddp:
             use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
@@ -358,7 +355,7 @@ def select_training_type_plugin(self):
             plugin = ddp_plugin_cls(
                 parallel_devices=self.parallel_devices,
                 num_nodes=self.num_nodes,
-                cluster_environment=cluster_environment,
+                cluster_environment=self.select_cluster_environment(),
                 sync_batchnorm=self.sync_batchnorm,
             )
         elif self.use_dp:
@@ -371,6 +368,22 @@ def select_training_type_plugin(self):
             plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
         return plugin
 
+
+    def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
+        # necessary for RPC, when user has to provide balance
+        if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
+            training_type.parallel_devices = self.parallel_devices
+            if hasattr(training_type, 'num_processes'):
+                training_type.num_processes = len(self.parallel_devices)
+
+        if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
+            training_type.cluster_environment = self.select_cluster_environment()
+
+        if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
+            training_type.num_nodes = self.num_nodes
+
+        return training_type
+
     def select_accelerator(self):
         if isinstance(self.distributed_backend, Accelerator):
             # custom accelerator from user
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 274078d8a80d4..e2571b8a96eac 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -73,7 +73,7 @@ def __init__(
         self._has_spawned_children = False
         self.task_idx = None
         self.node_rank = 0
-        self.num_processes = len(parallel_devices)
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
 
     @property
     def root_device(self):
diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py
index 4aff83189b6bc..dc1c731da4ffa 100644
--- a/pytorch_lightning/plugins/training_type/rpc.py
+++ b/pytorch_lightning/plugins/training_type/rpc.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import Optional
+from typing import Optional, Sequence
 
 import torch
 
@@ -40,11 +40,11 @@ class RPCPlugin(DDPPlugin):
 
     def __init__(
         self,
-        parallel_devices,
-        num_nodes=1,
-        cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm=False,
         rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC,
+        parallel_devices: Sequence[int] = (),
+        num_nodes: Optional[int] = None,
+        cluster_environment: Optional[ClusterEnvironment] = None,
+        sync_batchnorm: Optional[bool] = None,
         **kwargs
     ):
         self.rpc_timeout_sec = rpc_timeout_sec
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index 79cecac3fbb4d..cf02776eb5881 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,7 +13,7 @@
 # limitations under the License
 import logging
 import os
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence
 
 import torch
 import torch.distributed as torch_distrib
@@ -21,6 +21,7 @@
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning.core.lightning import LightningModule
+from torch.optim import Optimizer
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
@@ -42,11 +43,7 @@ class RPCSequentialPlugin(RPCPlugin):
 
     def __init__(
         self,
-        parallel_devices,
-        num_nodes: int = 1,
-        cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm=False,
-        balance: Optional[List[int]] = None,
+        balance: List[int],
         microbatches: int = 8,
         checkpoint: str = 'except_last',
         balance_mode: str = "balance_by_size",
@@ -93,10 +90,6 @@ def __init__(
         """
         self._check_pipe_available()
         super().__init__(
-            parallel_devices=parallel_devices,
-            num_nodes=num_nodes,
-            cluster_environment=cluster_environment,
-            sync_batchnorm=sync_batchnorm,
             rpc_timeout_sec=rpc_timeout_sec,
             **kwargs
         )
@@ -324,6 +317,12 @@ def _check_pipe_available(self):
                 'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
             )
 
+    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
+        """Hook to do something after each optimizer step."""
+        if self.rpc_enabled and self.is_main_rpc_process:
+
+            # Initialize optimizer step on main process
+            self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs)
 
 class LightningPipeModule(nn.Module):
     """
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index c26f5fbc1b743..acc8288a31ada 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -75,6 +75,9 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
     def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
         """Run after precision plugin executes backward"""
 
+    def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None:
+        """Hook to do something after each optimizer step."""
+
     @property
     def model(self) -> Module:
         """Returns the potentially wrapped LightningModule"""
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 6c539dec7fd3a..c7796b433f1ed 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -65,6 +65,7 @@ class DistributedType(LightningEnum):
     HOROVOD = 'horovod'
     DDP_SHARDED = 'ddp_sharded'
     DDP_SHARDED_SPAWN = 'ddp_sharded_spawn'
+    RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential'
 
 
 class DeviceType(LightningEnum):
diff --git a/tests/plugins/legacy/test_ddp_sequential_plugin.py b/tests/plugins/legacy/test_ddp_sequential_plugin.py
index 8c6061d12cf11..00e163827269e 100644
--- a/tests/plugins/legacy/test_ddp_sequential_plugin.py
+++ b/tests/plugins/legacy/test_ddp_sequential_plugin.py
@@ -20,7 +20,7 @@
 from torch import nn
 
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.plugins.legacy.ddp_sequential_plugin import DDPSequentialPlugin
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import RandomDataset
@@ -47,7 +47,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
         limit_test_batches=2,
         gpus=2,
         distributed_backend="ddp",
-        plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
+        plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
         enable_pl_optimizer=True,
     )
 
@@ -77,7 +77,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
         precision=16,
         amp_backend="native",
         distributed_backend="ddp",
-        plugins=[DDPSequentialPlugin(balance=[2, 1])],
+        plugins=[RPCSequentialPlugin(balance=[2, 1])],
     )
     try:
         trainer.fit(model)
@@ -85,7 +85,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
         assert len(trainer.dev_debugger.pbar_added_metrics) > 0
 
     except MisconfigurationException as e:
-        assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
+        assert str(e) == 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision'
 
 
 @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@@ -102,7 +102,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
         limit_test_batches=2,
         gpus=2,
         distributed_backend="ddp",
-        plugins=[DDPSequentialPlugin(balance=[2, 1])],
+        plugins=[RPCSequentialPlugin(balance=[2, 1])],
     )
 
     trainer.fit(model)
@@ -130,7 +130,7 @@ def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None):
         limit_test_batches=2,
         gpus=2,
         distributed_backend="ddp",
-        plugins=[DDPSequentialPlugin(balance=[2, 2])],
+        plugins=[RPCSequentialPlugin(balance=[2, 2])],
     )
 
     try:
diff --git a/tests/plugins/legacy/test_rpc_plugin.py b/tests/plugins/legacy/test_rpc_plugin.py
index 77937c16058dc..a1e28d22ace58 100644
--- a/tests/plugins/legacy/test_rpc_plugin.py
+++ b/tests/plugins/legacy/test_rpc_plugin.py
@@ -7,7 +7,7 @@
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins.legacy.rpc_plugin import RPCPlugin
+from pytorch_lightning.plugins.training_type.rpc_sequential import RPCPlugin
 from pytorch_lightning.utilities import _RPC_AVAILABLE
 from tests.base.boring_model import BoringModel
 
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 577e49cec49d2..3da35696e44b7 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -21,7 +21,7 @@ python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_
 python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
 python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
 python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp
 python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp
 python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
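For orientation, a minimal usage sketch (not part of the patch): after this rework, `balance` is the only required argument of `RPCSequentialPlugin`, and `parallel_devices`, `num_nodes` and `cluster_environment` are filled in by the accelerator connector's `resolve_training_type_plugin` instead of being passed to the plugin. `MySequentialModel` is a hypothetical LightningModule wrapping its layers in an `nn.Sequential`; the Trainer arguments mirror the updated tests above.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin

# Hypothetical LightningModule with a 3-layer nn.Sequential pipe model.
model = MySequentialModel()

trainer = Trainer(
    gpus=2,
    distributed_backend="ddp",
    # balance=[2, 1]: 2 layers on the first GPU, 1 on the second; devices and the
    # cluster environment are resolved by the accelerator connector, not passed here.
    plugins=[RPCSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
)
trainer.fit(model)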