From 4d71065066bfa77810874698d93b9c36a5aa1264 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 9 Feb 2021 17:01:48 +0000 Subject: [PATCH 1/2] Fix RPC related tests, clean out old API, update for new accelerator API --- .../plugins/training_type/rpc.py | 54 +---------- .../plugins/training_type/rpc_sequential.py | 96 ++++++++++--------- .../legacy/test_ddp_sequential_plugin.py | 34 +++---- tests/plugins/legacy/test_rpc_plugin.py | 40 +------- tests/special_tests.sh | 5 +- 5 files changed, 71 insertions(+), 158 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index 40ca4fe6b9a4b..be81cd2a03c56 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -25,6 +25,7 @@ DEFAULT_RPC_TIMEOUT_SEC = 60. if _RPC_AVAILABLE: from torch.distributed import rpc + with suppress(ModuleNotFoundError, ImportError): from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC @@ -76,60 +77,11 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> No """ raise NotImplementedError - def on_main_rpc_connection(self, trainer) -> None: - """ - Called when main rpc connection has been established. - - Args: - trainer: The trainer object. - """ - raise NotImplementedError - - def on_accelerator_exit_rpc_process(self) -> None: - """ - Called to exit RPC process within the accelerator, that is being managed by main process. - - Args: - trainer: The trainer object. - """ - self.exit_rpc_process() - def exit_rpc_process(self): if self._is_rpc_initialized: torch.distributed.rpc.shutdown() self._is_rpc_initialized = False @property - def return_after_exit_rpc_process(self) -> bool: - """ - Override to decide whether to skip train/test function after shutdown completed. - Usually RPC shutdown is a join/exit function, afterwards we want to exit the process. - - Returns: - Whether to return after RPC exit. - """ - raise NotImplementedError - - def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: - """ - Called when optimizer step is run on the main process. Used to signal any RPC workers to run optimizer step. - - Args: - model: The LightningModule. - opt_idx: The idx of the optimizer to carry out step on. - """ - raise NotImplementedError - - @property - def is_main_rpc_process(self) -> bool: - """ - Override to add logic to determine current process is main RPC process. - """ - raise NotImplementedError - - def barrier(self, name: Optional[str] = None) -> None: - """ - Override to define distributed sync communication. This needs to be handled differently due to - the RPC connection managing certain processes at the same time. 
- """ - raise NotImplementedError + def rpc_enabled(self) -> bool: + return True diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index b6e2bd9ecc93d..249959cb12e19 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,7 @@ # limitations under the License import logging import os -from typing import Any, List, Optional, Sequence +from typing import List, Optional import torch import torch.distributed as torch_distrib @@ -22,8 +22,7 @@ from torch.optim import Optimizer from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only @@ -97,15 +96,18 @@ def __init__( self.checkpoint = checkpoint self.balance_mode = balance_mode self.pipelined_backward = pipelined_backward - self.main_rpc_process = False # Updated by main process, default for all secondary processes + self._main_rpc_process = True def init_ddp_connection( self, global_rank: int, world_size: int, ) -> None: - # what is this used for? - self.prepared_for_backwards = False + if self.lightning_module.trainer.amp_backend is not None: + raise MisconfigurationException( + 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision' + ) + if self._skip_init_connections(): return super().init_ddp_connection( @@ -119,21 +121,18 @@ def init_ddp_connection( self.set_main_rpc_process() self._check_sequential_model_exists(model) + + # check if user given balance is valid + if self.balance is not None: + self._assert_valid_model_balance() + if self.main_rpc_process: if self.balance is None: self._infer_model_balance() - self._assert_valid_model_balance() - - if not self.is_main_rpc_process: - self.on_accelerator_exit_rpc_process() - self.exit_rpc_process() - if self.return_after_exit_rpc_process: - return + self.init_pipe_module() else: - self.on_main_rpc_connection() - - def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any): - pass + self.handle_transferred_pipe_module() + self.exit_rpc_process() def _infer_model_balance(self): log.info(f'Inferring model balance using {self.balance_mode} mode') @@ -231,21 +230,16 @@ def _infer_check_num_gpus(self): # Assume that the user wants to balance his model on all GPUs return self.world_size - def on_accelerator_exit_rpc_process(self) -> None: + def handle_transferred_pipe_module(self) -> None: if not self.lightning_module.running_stage == RunningStage.TESTING: torch_distrib.barrier() # Ensure we await main process initialization - # Add trainer/configure_optimizers to the pipe model for access in all worker processes rpc_pipe.PipeModel.trainer = self.lightning_module.trainer del rpc_pipe.PipeModel.trainer.model.sequential_module rpc_pipe.PipeModel.trainer.model.sequential_module = rpc_pipe.PipeModel rpc_pipe.PipeModel.configure_optimizers = self.lightning_module.configure_optimizers - super().on_accelerator_exit_rpc_process() - def set_main_rpc_process(self): - self.main_rpc_process = 
torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0 - - def on_main_rpc_connection(self) -> None: + def init_pipe_module(self) -> None: # Create pipe_module model = self.lightning_module self._find_and_init_pipe_module(model) @@ -253,21 +247,23 @@ def on_main_rpc_connection(self) -> None: torch_distrib.barrier() # Ensure we join main process initialization model.sequential_module.foreach_worker(register_optimizers, include_self=True) - # TODO: Move this to the connector - def _check_arguments(self, trainer): - if trainer.amp_backend is not None: - raise MisconfigurationException( - 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision' - ) + # TODO: Move this to the connector def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" - def configure_ddp(self) -> None: - # process_group=mpu.get_data_parallel_group() - super().configure_ddp() - # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel - self._model.require_backward_grad_sync = False + def configure_ddp(self): + if self.main_rpc_process: + self.pre_configure_ddp() + + self._model = DistributedDataParallel( + LightningDistributedModule(self.model), + device_ids=self.determine_ddp_device_ids(), + process_group=mpu.get_data_parallel_group(), + **self._ddp_kwargs, + ) + # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel + self._model.require_backward_grad_sync = False @rank_zero_only def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: @@ -302,16 +298,19 @@ def distributed_sampler_kwargs(self): def data_parallel_group(self): return mpu.get_data_parallel_group() - @property - def is_main_rpc_process(self) -> bool: - return self.main_rpc_process + def set_main_rpc_process(self): + self.main_rpc_process = torch_distrib.get_rank(group=mpu.get_pipeline_parallel_group()) == 0 @property - def return_after_exit_rpc_process(self) -> bool: - return True + def main_rpc_process(self) -> bool: + return self._main_rpc_process + + @main_rpc_process.setter + def main_rpc_process(self, is_main_process): + self._main_rpc_process = is_main_process def barrier(self, name: Optional[str] = None) -> None: - if torch_distrib.is_initialized() and self.is_main_rpc_process: + if torch_distrib.is_initialized() and self.main_rpc_process: torch_distrib.barrier(group=self.data_parallel_group) def _check_pipe_available(self): @@ -322,11 +321,22 @@ def _check_pipe_available(self): def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: """Hook to do something after each optimizer step.""" - if self.rpc_enabled and self.is_main_rpc_process: - + if self.rpc_enabled and self.main_rpc_process: # Initialize optimizer step on main process self.worker_optimizer_step(model=self.lightning_module, opt_idx=optimizer_idx, **kwargs) + def post_training(self): + if self.main_rpc_process: + super().post_training() + + def start_training(self, trainer: 'Trainer') -> None: + if self.main_rpc_process: + super().start_training(trainer) + + def start_testing(self, trainer: 'Trainer') -> None: + if self.main_rpc_process: + super().start_testing(trainer) + class LightningPipeModule(nn.Module): """ diff --git a/tests/plugins/legacy/test_ddp_sequential_plugin.py b/tests/plugins/legacy/test_ddp_sequential_plugin.py index 2cf347aeb6ea6..c59d4beac7214 100644 --- 
a/tests/plugins/legacy/test_ddp_sequential_plugin.py +++ b/tests/plugins/legacy/test_ddp_sequential_plugin.py @@ -26,13 +26,6 @@ from tests.helpers.boring_model import RandomDataset -def cleanup(ctx, model): - """ - Cleanup function required to ensure we delete the pipe module at the end of the the test on all workers - """ - del model - - @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -54,12 +47,12 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): trainer.fit(model) - if torch_distrib.get_rank() == 0: + if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0: assert len(trainer.dev_debugger.pbar_added_metrics) > 0 if trainer.accelerator_backend.rpc_enabled: # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + trainer.accelerator_backend.training_type_plugin.exit_rpc_process() @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") @@ -81,14 +74,11 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): distributed_backend="ddp", plugins=[RPCSequentialPlugin(balance=[2, 1])], ) - try: + with pytest.raises( + MisconfigurationException, match='RPCSequentialPlugin is currently not supported in Automatic Mixed Precision' + ): trainer.fit(model) - assert len(trainer.dev_debugger.pbar_added_metrics) > 0 - - except MisconfigurationException as e: - assert str(e) == 'RPCSequentialPlugin is currently not supported in Automatic Mixed Precision' - @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -110,13 +100,12 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): trainer.fit(model) - if torch_distrib.get_rank() == 0: + if torch_distrib.is_initialized() and torch_distrib.get_rank() == 0: assert len(trainer.dev_debugger.pbar_added_metrics) > 0 if trainer.accelerator_backend.rpc_enabled: - # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + trainer.accelerator_backend.training_type_plugin.exit_rpc_process() @pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed") @@ -137,15 +126,14 @@ def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None): plugins=[RPCSequentialPlugin(balance=[2, 2])], ) - try: + with pytest.raises( + MisconfigurationException, match="The provided balance sum: 4 does not match your Sequential length: 3" + ): trainer.fit(model) - except MisconfigurationException as e: - assert str(e) == 'The provided balance sum: 4 does not match your Sequential length: 3' - if trainer.accelerator_backend.rpc_enabled: # Called at the end of trainer to ensure all processes are killed - trainer.accelerator_backend.ddp_plugin.exit_rpc_process() + trainer.accelerator_backend.training_type_plugin.exit_rpc_process() class SequentialModelRPCManual(LightningModule): diff --git a/tests/plugins/legacy/test_rpc_plugin.py b/tests/plugins/legacy/test_rpc_plugin.py index 67e72df5dc93d..2c074e6c3afda 100644 --- a/tests/plugins/legacy/test_rpc_plugin.py +++ b/tests/plugins/legacy/test_rpc_plugin.py @@ -5,7 +5,7 @@ import pytest import torch -from pytorch_lightning import LightningModule, 
Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.training_type.rpc_sequential import RPCPlugin from pytorch_lightning.utilities import _RPC_AVAILABLE @@ -56,39 +56,15 @@ class CustomRPCPlugin(RPCPlugin): def __init__(self, **kwargs): super().__init__(**kwargs) self.rpc_save_model_count = 0 - self.on_main_rpc_connect_count = 0 self.worker_optimizer_step_count = 0 - self.is_main_rpc_process_count = 0 - self.on_exit_rpc_process_count = 0 - self.return_after_exit_rpc_process_count = 0 - - def on_accelerator_exit_rpc_process(self) -> None: - self.on_exit_rpc_process_count += 1 def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None: self.rpc_save_model_count += 1 - def on_main_rpc_connection(self) -> None: - self.on_main_rpc_connect_count += 1 - - def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None: - self.worker_optimizer_step_count += 1 - - @property - def is_main_rpc_process(self) -> bool: - self.is_main_rpc_process_count += 1 - return torch.distributed.get_rank() == 0 - - @property - def return_after_exit_rpc_process(self) -> bool: - self.return_after_exit_rpc_process_count += 1 - return False - def barrier(self, name: Optional[str] = None) -> None: return -@pytest.mark.skipif(True, reason="This test is currently broken") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available") @@ -112,17 +88,5 @@ def test_rpc_function_calls_ddp(tmpdir): trainer.fit(model) if trainer.global_rank == 0: # Main process assert plugin.rpc_save_model_count == max_epochs - assert plugin.on_main_rpc_connect_count == 1 - assert plugin.worker_optimizer_step_count == max_epochs * limit_train_batches - # Call once at init, and at optim step - assert plugin.is_main_rpc_process_count == 1 + plugin.worker_optimizer_step_count - assert plugin.on_exit_rpc_process_count == 0 else: # Worker process - assert plugin.rpc_save_model_count == 0 - assert plugin.on_main_rpc_connect_count == 0 - # Never signaled by worker, only by main process - assert plugin.worker_optimizer_step_count == 0 - # Call once at init, and at optim step - assert plugin.is_main_rpc_process_count == 1 + (max_epochs * limit_train_batches) - # Called at init - assert plugin.on_exit_rpc_process_count == 1 + assert plugin.rpc_save_model_count == max_epochs diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 200ea1c2fd772..dff36b1cbc001 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,11 +17,10 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -# todo: resolve this test -# python ${DEFAULTS} tests/plugins/legacy/test_rpc_plugin.py::test_rpc_function_calls_ddp +python ${DEFAULTS} tests/plugins/legacy/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp -# python ${DEFAULTS} 
tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic +python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp From e109f7ce2b41335c41d913349cc3c3de83aa0754 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 9 Feb 2021 17:07:20 +0000 Subject: [PATCH 2/2] Move tests out of legacy folder, update paths and names --- tests/plugins/legacy/__init__.py | 1 - .../plugins/{legacy => }/test_ddp_sequential_plugin.py | 8 ++++---- tests/plugins/{legacy => }/test_rpc_plugin.py | 0 tests/special_tests.sh | 10 +++++----- 4 files changed, 9 insertions(+), 10 deletions(-) delete mode 100644 tests/plugins/legacy/__init__.py rename tests/plugins/{legacy => }/test_ddp_sequential_plugin.py (96%) rename tests/plugins/{legacy => }/test_rpc_plugin.py (100%) diff --git a/tests/plugins/legacy/__init__.py b/tests/plugins/legacy/__init__.py deleted file mode 100644 index b1fca65e60042..0000000000000 --- a/tests/plugins/legacy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# todo: feel free to move any of these "legacy" tests up... diff --git a/tests/plugins/legacy/test_ddp_sequential_plugin.py b/tests/plugins/test_ddp_sequential_plugin.py similarity index 96% rename from tests/plugins/legacy/test_ddp_sequential_plugin.py rename to tests/plugins/test_ddp_sequential_plugin.py index c59d4beac7214..6daf2d1998bbe 100644 --- a/tests/plugins/legacy/test_ddp_sequential_plugin.py +++ b/tests/plugins/test_ddp_sequential_plugin.py @@ -32,7 +32,7 @@ @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): +def test_rpc_sequential_plugin_manual(tmpdir, args=None): model = SequentialModelRPCManual() trainer = Trainer( max_epochs=2, @@ -61,7 +61,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): +def test_rpc_sequential_plugin_manual_amp(tmpdir, args=None): model = SequentialModelRPCManual() trainer = Trainer( max_epochs=2, @@ -86,7 +86,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None): @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): +def test_rpc_sequential_plugin_automatic(tmpdir, args=None): model = SequentialModelRPCAutomatic() trainer = Trainer( max_epochs=2, @@ -114,7 +114,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None): @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance(tmpdir, args=None): +def test_rpc_sequential_plugin_with_wrong_balance(tmpdir, args=None): model = SequentialModelRPCAutomatic() trainer = Trainer( max_epochs=2, diff --git a/tests/plugins/legacy/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py similarity index 100% rename from 
tests/plugins/legacy/test_rpc_plugin.py rename to tests/plugins/test_rpc_plugin.py diff --git a/tests/special_tests.sh b/tests/special_tests.sh index dff36b1cbc001..3ad6e65512585 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,11 +17,11 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -python ${DEFAULTS} tests/plugins/legacy/test_rpc_plugin.py::test_rpc_function_calls_ddp -python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual -python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp -python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic -python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance +python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp +python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual +python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp +python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_automatic +python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_with_wrong_balance python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp
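
For reference, a minimal usage sketch of RPCSequentialPlugin as exercised by the tests touched in this patch. The module, layer sizes, and dataset below are placeholders invented for illustration; only the Trainer arguments (gpus=2, distributed_backend="ddp", plugins=[RPCSequentialPlugin(balance=[2, 1])]) and the sequential_module attribute name come from the test code itself, and the balance entries must sum to the length of the nn.Sequential being split (the wrong-balance test expects exactly that error).

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin


class ToyPipeModel(LightningModule):
    # Placeholder model for illustration only; the plugin splits the single
    # nn.Sequential child (named ``sequential_module`` as in the tests) across GPUs.
    def __init__(self):
        super().__init__()
        self.sequential_module = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, x):
        return self.sequential_module(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Random placeholder data, 2 classes to match the output dimension above.
        return DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=2,
        gpus=2,
        distributed_backend="ddp",
        # balance = layers per pipeline stage; [2, 1] sums to the 3 layers in sequential_module
        plugins=[RPCSequentialPlugin(balance=[2, 1])],
    )
    trainer.fit(ToyPipeModel())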