From e697ca6c945eb6a5a51c47f7e1a19f44b097e2ff Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 16 Feb 2022 22:15:00 -0800 Subject: [PATCH 01/25] Create optimizers in DDP strategy after moving model to device --- CHANGELOG.md | 3 +++ pytorch_lightning/strategies/ddp.py | 15 ++++++++------- pytorch_lightning/strategies/ddp_spawn.py | 6 ++++-- pytorch_lightning/strategies/sharded.py | 17 +++++++++++++++++ tests/strategies/test_ddp_strategy.py | 21 +++++++++++++++++++++ 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c189e6ce006c2..fba153490a5ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -853,6 +853,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267)) +- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#]()) + + ## [1.5.10] - 2022-02-08 ### Fixed diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index af09417ef7249..e0a33b7a8ecfb 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -146,7 +146,6 @@ def setup_environment(self) -> None: super().setup_environment() def setup(self, trainer: "pl.Trainer") -> None: - super().setup(trainer) # share ddp pids to all processes self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) if self._should_run_deadlock_detection(): @@ -165,6 +164,14 @@ def setup(self, trainer: "pl.Trainer") -> None: self.configure_ddp() + # set up optimizers after the wrapped module has been moved to the device + super().setup(trainer) + if _TORCH_GREATER_EQUAL_1_10 and trainer.state.fn == TrainerFn.FITTING: + import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + + if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): + self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter) + def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() @@ -228,12 +235,6 @@ def _register_ddp_hooks(self) -> None: ddp_comm_wrapper=self._ddp_comm_wrapper, ) - if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: - import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD - - if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): - self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter) - def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD") optimizers = self.optimizers diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 129719ba14459..84e6dd59e3ff0 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -120,7 +120,6 @@ def _configure_launcher(self): def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - super().setup(trainer) # move the model to the correct device self.model_to_device() @@ -133,7 +132,10 @@ def setup(self, trainer: "pl.Trainer") -> None: self.model = 
self._layer_sync.apply(self.model) # skip wrapping the model if we are not fitting as no gradients need to be exchanged - self.configure_ddp() + trainer_fn = trainer.state.fn + if trainer_fn == TrainerFn.FITTING: + self.configure_ddp() + super().setup(trainer) def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 6811721ecaab7..5634115f904d7 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -40,6 +40,23 @@ class DDPShardedStrategy(DDPStrategy): strategy_name = "ddp_sharded" _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M + def setup(self, trainer: "pl.Trainer") -> None: + # share ddp pids to all processes + self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts) + if self._should_run_deadlock_detection(): + self._share_information_to_prevent_deadlock() + + # move the model to the correct device + self.model_to_device() + + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + + # skip wrapping the model if we are not fitting as no gradients need to be exchanged + trainer_fn = trainer.state.fn + if trainer_fn == TrainerFn.FITTING: + self.configure_ddp() + def configure_ddp(self) -> None: trainer = self.lightning_module.trainer if "reduce_buffer_size" not in self._ddp_kwargs: diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py index e1ed780275f0f..6cb55cad945d3 100644 --- a/tests/strategies/test_ddp_strategy.py +++ b/tests/strategies/test_ddp_strategy.py @@ -147,3 +147,24 @@ def test_ddp_dont_configure_sync_batchnorm(trainer_fn): trainer.strategy.setup(trainer) # because TrainerFn is not FITTING, model is not configured with sync batchnorm assert not isinstance(trainer.strategy.model.layer, torch.nn.modules.batchnorm.SyncBatchNorm) + + +class CheckOptimizerDeviceModel(BoringModel): + def configure_optimizers(self): + assert all(param.device.type == "cuda" for param in self.parameters()) + super().configure_optimizers() + + +@RunIf(min_gpus=1) +@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn")) +def test_model_parameters_on_device_for_optimizer(strategy): + """Test that the strategy has moved the parameters to the device by the time the optimizer gets created.""" + model = CheckOptimizerDeviceModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + fast_dev_run=1, + accelerator="gpu", + devices=1, + strategy=strategy, + ) + trainer.fit(model) From a6cf5db5b588d54cf89ec14104c190c359ae4476 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 16 Feb 2022 22:16:35 -0800 Subject: [PATCH 02/25] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fba153490a5ae..dad12b004b028 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -853,7 +853,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267)) -- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#]()) +- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11952](https://github.com/PyTorchLightning/pytorch-lightning/pull/11952)) ## [1.5.10] - 2022-02-08 From 5d362fdea7b0a428e7a1f8c7d2e6ab81f3101fe9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 4 Mar 2022 22:34:58 -0800 Subject: [PATCH 03/25] Update sharded.py --- pytorch_lightning/strategies/sharded.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 5634115f904d7..b8fcad22ec878 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -25,6 +25,7 @@ from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -57,6 +58,11 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() + self.accelerator.setup(trainer) + self.setup_optimizers(trainer) + self.setup_precision_plugin() + optimizers_to_device(self.optimizers, self.root_device) + def configure_ddp(self) -> None: trainer = self.lightning_module.trainer if "reduce_buffer_size" not in self._ddp_kwargs: From f2b6f81adb7ef5d25e58ae5d73f46e2ff3a6aa9d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 4 Mar 2022 23:08:01 -0800 Subject: [PATCH 04/25] update --- pytorch_lightning/strategies/ddp.py | 9 ++++++++- pytorch_lightning/strategies/sharded.py | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index e0a33b7a8ecfb..27d444bf28e40 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -52,6 +52,7 @@ _TORCH_GREATER_EQUAL_1_11, ) from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -151,6 +152,8 @@ def setup(self, trainer: "pl.Trainer") -> None: if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() + self.accelerator.setup(trainer) + # move the model to the correct device self.model_to_device() @@ -165,7 +168,10 @@ def setup(self, trainer: "pl.Trainer") -> None: self.configure_ddp() # set up optimizers after the wrapped module has been moved to the device - super().setup(trainer) + self.setup_optimizers(trainer) + self.setup_precision_plugin() + optimizers_to_device(self.optimizers, self.root_device) + if _TORCH_GREATER_EQUAL_1_10 and trainer.state.fn == TrainerFn.FITTING: import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD @@ -276,6 +282,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): optimizers[x] = post_localSGD_optimizer del optimizer self.optimizers = optimizers + 
optimizers_to_device(self.optimizers, self.root_device) def configure_ddp(self) -> None: log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel") diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index b8fcad22ec878..504b15df10736 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -47,6 +47,8 @@ def setup(self, trainer: "pl.Trainer") -> None: if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() + self.accelerator.setup(trainer) + # move the model to the correct device self.model_to_device() @@ -58,7 +60,6 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - self.accelerator.setup(trainer) self.setup_optimizers(trainer) self.setup_precision_plugin() optimizers_to_device(self.optimizers, self.root_device) From f3f39aff7ef9298505a0a060b04a638f7eae4a0b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 4 Mar 2022 23:22:52 -0800 Subject: [PATCH 05/25] ordering --- pytorch_lightning/strategies/ddp.py | 4 ++-- pytorch_lightning/strategies/ddp_spawn.py | 9 ++++++++- pytorch_lightning/strategies/sharded.py | 2 +- pytorch_lightning/strategies/tpu_spawn.py | 6 ++---- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 27d444bf28e40..60c9e1886848f 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -169,15 +169,15 @@ def setup(self, trainer: "pl.Trainer") -> None: # set up optimizers after the wrapped module has been moved to the device self.setup_optimizers(trainer) - self.setup_precision_plugin() optimizers_to_device(self.optimizers, self.root_device) - if _TORCH_GREATER_EQUAL_1_10 and trainer.state.fn == TrainerFn.FITTING: import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter) + self.setup_precision_plugin() + def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 84e6dd59e3ff0..79ded35b31c3c 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -37,6 +37,7 @@ from pytorch_lightning.utilities.distributed import group as _group from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -121,6 +122,8 @@ def _configure_launcher(self): def setup(self, trainer: "pl.Trainer") -> None: os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) + self.accelerator.setup(trainer) + # move the model to the correct device self.model_to_device() @@ -135,7 +138,11 @@ def setup(self, trainer: "pl.Trainer") -> None: trainer_fn = trainer.state.fn if 
trainer_fn == TrainerFn.FITTING: self.configure_ddp() - super().setup(trainer) + + # set up optimizers after the wrapped module has been moved to the device + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + self.setup_precision_plugin() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 504b15df10736..77997721db553 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -61,8 +61,8 @@ def setup(self, trainer: "pl.Trainer") -> None: self.configure_ddp() self.setup_optimizers(trainer) - self.setup_precision_plugin() optimizers_to_device(self.optimizers, self.root_device) + self.setup_precision_plugin() def configure_ddp(self) -> None: trainer = self.lightning_module.trainer diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 3eb776e3d8ab5..4ecb587d529dc 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -126,9 +126,6 @@ def _configure_launcher(self): def setup(self, trainer: "pl.Trainer") -> None: self.start_method = "fork" self.accelerator.setup(trainer) - self.setup_optimizers(trainer) - self.setup_precision_plugin() - optimizers_to_device(self.optimizers, self.root_device) if self.debug: os.environ["PT_XLA_DEBUG"] = str(1) @@ -141,7 +138,8 @@ def setup(self, trainer: "pl.Trainer") -> None: set_shared_parameters(self.model.module, shared_params) self.setup_optimizers(trainer) - self.precision_plugin.connect(self.model, None, None) + optimizers_to_device(self.optimizers, self.root_device) + self.setup_precision_plugin() def _setup_model(self, model: Module) -> Module: return model From f23eb99d97d6272af3e61cf7b698fd4c469257c4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 7 Mar 2022 22:43:14 -0800 Subject: [PATCH 06/25] update bagua --- pytorch_lightning/strategies/bagua.py | 35 +++++++++++++++++++++++-- tests/strategies/test_bagua_strategy.py | 4 +-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index 17318331b840d..c09d0acd489cb 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -12,9 +12,11 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.strategies.strategy import TBroadcast +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.seed import reset_seed if _BAGUA_AVAILABLE: @@ -148,6 +150,35 @@ def _set_node_environment_variables(self) -> None: os.environ["WORLD_SIZE"] = str(self.world_size) os.environ["LOCAL_RANK"] = str(self.local_rank) + def setup(self, trainer: "pl.Trainer") -> None: + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) + if self._should_run_deadlock_detection(): + self._share_information_to_prevent_deadlock() + + self.accelerator.setup(trainer) + + # move the model to the 
correct device + self.model_to_device() + + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) + + # skip wrapping the model if we are not fitting as no gradients need to be exchanged + trainer_fn = trainer.state.fn + + # set up optimizers after the module has been moved to the device + # but before the module has been wrapped + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + + if trainer_fn == TrainerFn.FITTING: + self._configure_bagua_model(trainer) + + self.setup_precision_plugin() + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) + if self._should_run_deadlock_detection(): + self._share_information_to_prevent_deadlock() + def _check_qadam_optimizer(self) -> None: has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers]) @@ -156,12 +187,12 @@ def _check_qadam_optimizer(self) -> None: self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0] - def configure_ddp(self) -> None: + def _configure_bagua_model(self, trainer: "pl.Trainer") -> None: model = LightningBaguaModule(self.model) # type: ignore[arg-type] self._model = self._setup_model(model) # start the background communication for async algorithm - if self.lightning_module.trainer.training and self._bagua_algorithm == "async": + if trainer.training and self._bagua_algorithm == "async": self.model.bagua_algorithm.resume(self.model) # type: ignore def _setup_model(self, model: Module) -> BaguaDistributedDataParallel: diff --git a/tests/strategies/test_bagua_strategy.py b/tests/strategies/test_bagua_strategy.py index 2aee59f9e6565..da966ae1c037e 100644 --- a/tests/strategies/test_bagua_strategy.py +++ b/tests/strategies/test_bagua_strategy.py @@ -85,9 +85,9 @@ def test_configuration(algorithm, tmpdir): ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True): if algorithm == "qadam": with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"): - trainer.strategy.configure_ddp() + trainer.strategy._configure_bagua_model() else: - trainer.strategy.configure_ddp() + trainer.strategy._configure_bagua_model() @RunIf(bagua=True, min_gpus=1) From d15378eafa912d2b29cf7a1e92d5c1c84526b77a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 7 Mar 2022 22:47:21 -0800 Subject: [PATCH 07/25] Update fully_sharded.py --- pytorch_lightning/strategies/fully_sharded.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py index b61429264d80a..a982aa435d2fb 100644 --- a/pytorch_lightning/strategies/fully_sharded.py +++ b/pytorch_lightning/strategies/fully_sharded.py @@ -138,6 +138,9 @@ def setup_distributed(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) + self.setup_optimizers(trainer) + self.setup_precision_plugin() + optimizers_to_device(self.optimizers, self.root_device) if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: self.model = self._layer_sync.apply(self.model) @@ -145,8 +148,6 @@ def setup(self, trainer: "pl.Trainer") -> None: self.configure_ddp() self.barrier() self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - self.setup_precision_plugin() @contextlib.contextmanager def model_sharded_context(self) -> Generator: @@ -183,6 +184,9 @@ def configure_ddp(self) -> None: # (TODO: need to figure out solution) self.model_to_device() + 
# setup optimizers after fully sharded has wrapped the lightning module + self.setup_optimizers(self.lightning_module.trainer) + def model_to_device(self) -> None: log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...") # ensure we update the device type in the lightning module From bfbbf752f5807954e5c0df082a9c29b0a4c7edcb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 7 Mar 2022 23:39:23 -0800 Subject: [PATCH 08/25] Update bagua.py --- pytorch_lightning/strategies/bagua.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index c09d0acd489cb..c215eb93aad8f 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -160,7 +160,7 @@ def setup(self, trainer: "pl.Trainer") -> None: # move the model to the correct device self.model_to_device() - if self._layer_sync: + if self._layer_sync and self.model: self.model = self._layer_sync.apply(self.model) # skip wrapping the model if we are not fitting as no gradients need to be exchanged From 05c12df3b696d51ac99bf86d191a84e5541ed8c4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 7 Mar 2022 23:51:55 -0800 Subject: [PATCH 09/25] update --- pytorch_lightning/strategies/sharded.py | 2 +- tests/strategies/test_bagua_strategy.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 77997721db553..b50181bb4a565 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -43,7 +43,7 @@ class DDPShardedStrategy(DDPStrategy): def setup(self, trainer: "pl.Trainer") -> None: # share ddp pids to all processes - self._rank_0_has_called_call_children_scripts = self.broadcast(self._rank_0_has_called_call_children_scripts) + self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) if self._should_run_deadlock_detection(): self._share_information_to_prevent_deadlock() diff --git a/tests/strategies/test_bagua_strategy.py b/tests/strategies/test_bagua_strategy.py index da966ae1c037e..416c84543d0dd 100644 --- a/tests/strategies/test_bagua_strategy.py +++ b/tests/strategies/test_bagua_strategy.py @@ -85,9 +85,9 @@ def test_configuration(algorithm, tmpdir): ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True): if algorithm == "qadam": with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"): - trainer.strategy._configure_bagua_model() + trainer.strategy._configure_bagua_model(trainer) else: - trainer.strategy._configure_bagua_model() + trainer.strategy._configure_bagua_model(trainer) @RunIf(bagua=True, min_gpus=1) @@ -109,7 +109,7 @@ def test_qadam_configuration(tmpdir): with mock.patch( "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True): - trainer.strategy.configure_ddp() + trainer.strategy._configure_bagua_model(trainer) def test_bagua_not_available(monkeypatch): From b4520fad5d65157f01cc7cc9b2120b499c82d3bb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 8 Mar 2022 00:29:04 -0800 Subject: [PATCH 10/25] Update sharded.py --- pytorch_lightning/strategies/sharded.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/sharded.py 
b/pytorch_lightning/strategies/sharded.py index b50181bb4a565..44f27758b670c 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -52,8 +52,8 @@ def setup(self, trainer: "pl.Trainer") -> None: # move the model to the correct device self.model_to_device() - if self.sync_batchnorm: - self.model = self.configure_sync_batchnorm(self.model) + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn From 2823d63c77ffb7e08edbb9ad1d09342b27c7f111 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 9 Mar 2022 21:31:11 -0800 Subject: [PATCH 11/25] sharded opt --- pytorch_lightning/strategies/ddp.py | 2 +- pytorch_lightning/strategies/ddp_spawn.py | 2 +- pytorch_lightning/strategies/sharded.py | 23 +++++++++++------------ tests/strategies/test_sharded_strategy.py | 3 ++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 60c9e1886848f..c641225405d25 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -341,7 +341,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): - if isinstance(self.model, DistributedDataParallel): + if isinstance(self.model, LightningDistributedModule): # used when calling `trainer.fit` return self.model(*args, **kwargs) else: diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 79ded35b31c3c..516c5b7647b7b 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -257,7 +257,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): - if isinstance(self.model, DistributedDataParallel): + if isinstance(self.model, LightningDistributedModule): # used when calling `trainer.fit` return self.model(*args, **kwargs) else: diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 44f27758b670c..1fa140d5bfc85 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -58,22 +58,21 @@ def setup(self, trainer: "pl.Trainer") -> None: # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: - self.configure_ddp() + self._configure_sdp(trainer) - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) self.setup_precision_plugin() - def configure_ddp(self) -> None: - trainer = self.lightning_module.trainer + def _configure_sdp(self, trainer: "pl.Trainer") -> None: if "reduce_buffer_size" not in self._ddp_kwargs: # For multi-node training, enabling bucketing will improve performance. 
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 + self.setup_optimizers(trainer) self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), - optimizers=trainer.optimizers, + optimizers=self.optimizers, ) + optimizers_to_device(self.optimizers, self.root_device) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Wraps the model and optimizers with fairscale components. @@ -86,6 +85,12 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer] model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs) return model, optimizers + def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: + if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: + return optimizers + + return self._reinit_optimizers_with_oss(optimizers) + def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]: for x, optimizer in enumerate(optimizers): if isinstance(optimizer, LightningOptimizer): @@ -103,12 +108,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin del optimizer return optimizers - def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]: - if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: - return optimizers - - return self._reinit_optimizers_with_oss(optimizers) - def optimizer_state(self, optimizer: "OSS") -> Optional[dict]: if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index 5a454af46c0bb..66c30b6d099c8 100644 --- a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -257,9 +257,10 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe strategy.num_nodes = num_nodes strategy.model = Mock(spec=LightningModule) strategy.model.trainer = Mock() + strategy.model.trainer.state.fn = TrainerFn.FITTING with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded: - strategy.configure_ddp() + strategy._configure_sdp(strategy.model.trainer) args, kwargs = mock_sharded.call_args assert "reduce_buffer_size" in kwargs From f9843c95ec54e05692773c959fc75c51000d3f4f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 9 Mar 2022 23:17:04 -0800 Subject: [PATCH 12/25] Update ddp.py --- pytorch_lightning/strategies/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index c641225405d25..60c9e1886848f 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -341,7 +341,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): - if isinstance(self.model, LightningDistributedModule): + if isinstance(self.model, DistributedDataParallel): # used when calling `trainer.fit` return self.model(*args, **kwargs) else: From 917e14a8ef645f666828a7d5d6d2389b1e37d576 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 10 Mar 2022 13:26:51 -0800 Subject: [PATCH 13/25] update test --- pytorch_lightning/strategies/sharded.py | 10 ++++++---- 
tests/strategies/test_sharded_strategy.py | 3 +-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 1fa140d5bfc85..1e9ff750c8348 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -63,10 +63,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() def _configure_sdp(self, trainer: "pl.Trainer") -> None: - if "reduce_buffer_size" not in self._ddp_kwargs: - # For multi-node training, enabling bucketing will improve performance. - self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 - + self._set_ddp_kwargs() self.setup_optimizers(trainer) self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), @@ -74,6 +71,11 @@ def _configure_sdp(self, trainer: "pl.Trainer") -> None: ) optimizers_to_device(self.optimizers, self.root_device) + def _set_ddp_kwargs(self) -> None: + if "reduce_buffer_size" not in self._ddp_kwargs: + # For multi-node training, enabling bucketing will improve performance. + self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 + def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Wraps the model and optimizers with fairscale components. diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index 66c30b6d099c8..9dc8f27aead7d 100644 --- a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -257,10 +257,9 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe strategy.num_nodes = num_nodes strategy.model = Mock(spec=LightningModule) strategy.model.trainer = Mock() - strategy.model.trainer.state.fn = TrainerFn.FITTING with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded: - strategy._configure_sdp(strategy.model.trainer) + strategy._set_ddp_kwargs() args, kwargs = mock_sharded.call_args assert "reduce_buffer_size" in kwargs From ec53da67e05cad69a9e3cde26d51973799588f2f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 21 Mar 2022 21:30:36 -0700 Subject: [PATCH 14/25] rebase --- pytorch_lightning/strategies/ddp.py | 12 +++++------- pytorch_lightning/strategies/ddp_spawn.py | 9 ++------- pytorch_lightning/strategies/sharded.py | 8 +++++--- tests/strategies/test_sharded_strategy.py | 2 +- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 60c9e1886848f..858627111b8f0 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -159,18 +159,16 @@ def setup(self, trainer: "pl.Trainer") -> None: # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn - if trainer_fn != TrainerFn.FITTING: - return - - if self._layer_sync: - self.model = self._layer_sync.apply(self.model) - self.configure_ddp() + if trainer_fn == TrainerFn.FITTING: + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) + self.configure_ddp() # set up optimizers after the wrapped module has been moved to the device self.setup_optimizers(trainer) optimizers_to_device(self.optimizers, self.root_device) - if _TORCH_GREATER_EQUAL_1_10 and trainer.state.fn == 
TrainerFn.FITTING: + if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING: import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 516c5b7647b7b..b4f478c8bca05 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -127,16 +127,11 @@ def setup(self, trainer: "pl.Trainer") -> None: # move the model to the correct device self.model_to_device() - trainer_fn = self.lightning_module.trainer.state.fn - if trainer_fn != TrainerFn.FITTING: - return - - if self._layer_sync: - self.model = self._layer_sync.apply(self.model) - # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) self.configure_ddp() # set up optimizers after the wrapped module has been moved to the device diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 1e9ff750c8348..91a4b7ddf12d1 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -58,13 +58,15 @@ def setup(self, trainer: "pl.Trainer") -> None: # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: - self._configure_sdp(trainer) + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) + self.configure_ddp() self.setup_precision_plugin() - def _configure_sdp(self, trainer: "pl.Trainer") -> None: + def configure_ddp(self) -> None: self._set_ddp_kwargs() - self.setup_optimizers(trainer) + self.setup_optimizers(self.model.trainer) self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=self.optimizers, diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index 9dc8f27aead7d..e5fbdff68c25f 100644 --- a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -241,7 +241,7 @@ def test_custom_kwargs_sharded(tmpdir, cls): class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" with mock.patch(f"pytorch_lightning.strategies.{class_name}.ShardedDataParallel", autospec=True) as mock_sharded: - strategy.configure_ddp() + strategy._configure_sdp(strategy.model.trainer) args, kwargs = mock_sharded.call_args assert "reduce_fp16" in kwargs assert kwargs["reduce_fp16"] From 6fe6eaf77d040c1db7ea6b04cb51e6220c3c966e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 21 Mar 2022 21:31:41 -0700 Subject: [PATCH 15/25] Update test_sharded_strategy.py --- tests/strategies/test_sharded_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index e5fbdff68c25f..9dc8f27aead7d 100644 --- a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -241,7 +241,7 @@ def test_custom_kwargs_sharded(tmpdir, cls): class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" with mock.patch(f"pytorch_lightning.strategies.{class_name}.ShardedDataParallel", autospec=True) as mock_sharded: - 
strategy._configure_sdp(strategy.model.trainer) + strategy.configure_ddp() args, kwargs = mock_sharded.call_args assert "reduce_fp16" in kwargs assert kwargs["reduce_fp16"] From 2650067f53295aae6b5dc71cc4305c501c735d54 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 21 Mar 2022 21:39:11 -0700 Subject: [PATCH 16/25] Update test_sharded_strategy.py --- tests/strategies/test_sharded_strategy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index 9dc8f27aead7d..3e6265eff8388 100644 --- a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -241,7 +241,7 @@ def test_custom_kwargs_sharded(tmpdir, cls): class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" with mock.patch(f"pytorch_lightning.strategies.{class_name}.ShardedDataParallel", autospec=True) as mock_sharded: - strategy.configure_ddp() + strategy._set_ddp_kwargs() args, kwargs = mock_sharded.call_args assert "reduce_fp16" in kwargs assert kwargs["reduce_fp16"] From a98b070559b70e563331af42dad65b0677d97dea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Mar 2022 22:30:43 +0000 Subject: [PATCH 17/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/strategies/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 858627111b8f0..93ed445460c2b 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -51,8 +51,8 @@ _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11, ) -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.optimizer import optimizers_to_device +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT From e600a7ba36fee3808d3b936d9fc5edd2fa1ffb7d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 4 May 2022 18:12:41 +0530 Subject: [PATCH 18/25] bagua fix --- pytorch_lightning/strategies/bagua.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index 850a5b7c08f97..a53783ce91db2 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -174,9 +174,6 @@ def setup(self, trainer: "pl.Trainer") -> None: self._configure_bagua_model(trainer) self.setup_precision_plugin() - self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) - if self._should_run_deadlock_detection(): - self._share_information_to_prevent_deadlock() def _check_qadam_optimizer(self) -> None: has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers]) From bd09a3de0715027c01c4619341f5f8ee8098cf6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 May 2022 12:43:33 +0000 Subject: [PATCH 19/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/strategies/ddp_spawn.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git 
a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 6168ccbf87a99..85e06149518be 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -35,18 +35,15 @@ get_default_process_group_backend_for_device, ) from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_11 -from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.distributed import ( init_dist_connection, ReduceOp, register_ddp_comm_hook, sync_ddp_if_available, ) -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.optimizer import optimizers_to_device +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT From 2d8cd03ee8e12325c713db7270247952c3519376 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 4 May 2022 18:22:57 +0530 Subject: [PATCH 20/25] fix --- pytorch_lightning/strategies/ddp_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 85e06149518be..471e3ebe80ea9 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -41,9 +41,9 @@ register_ddp_comm_hook, sync_ddp_if_available, ) -from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.optimizer import optimizers_to_device -from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import STEP_OUTPUT From 9dcf463dd084a51e85d6a5d8661abe1b4e866079 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 5 May 2022 16:14:38 +0530 Subject: [PATCH 21/25] fix tests --- pytorch_lightning/strategies/sharded.py | 3 --- tests/strategies/test_sharded_strategy.py | 6 ++++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index f5eec250f3711..9186a31f5d997 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -52,9 +52,6 @@ def setup(self, trainer: "pl.Trainer") -> None: # move the model to the correct device self.model_to_device() - if self._layer_sync: - self.model = self._layer_sync.apply(self.model) - # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py index d747b607df196..f75ff41654b2b 100644 --- 
a/tests/strategies/test_sharded_strategy.py +++ b/tests/strategies/test_sharded_strategy.py @@ -270,10 +270,11 @@ def test_custom_kwargs_sharded(tmpdir, cls): strategy = cls(reduce_fp16=True) strategy.model = Mock(spec=LightningModule) strategy.model.trainer = Mock() + strategy.parallel_devices = [Mock()] class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" with mock.patch(f"pytorch_lightning.strategies.{class_name}.ShardedDataParallel", autospec=True) as mock_sharded: - strategy._set_ddp_kwargs() + strategy.configure_ddp() args, kwargs = mock_sharded.call_args assert "reduce_fp16" in kwargs assert kwargs["reduce_fp16"] @@ -289,9 +290,10 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe strategy.num_nodes = num_nodes strategy.model = Mock(spec=LightningModule) strategy.model.trainer = Mock() + strategy.parallel_devices = [Mock()] with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded: - strategy._set_ddp_kwargs() + strategy.configure_ddp() args, kwargs = mock_sharded.call_args assert "reduce_buffer_size" in kwargs From bbf5547094d7aa44353a131155c382ec7b1827e3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 5 May 2022 08:41:59 -0400 Subject: [PATCH 22/25] keep DDP as top wrapper --- pytorch_lightning/strategies/bagua.py | 13 +++++++------ pytorch_lightning/strategies/ddp.py | 6 ++++-- pytorch_lightning/strategies/ddp_spawn.py | 7 +++++-- pytorch_lightning/strategies/fully_sharded.py | 3 +-- pytorch_lightning/strategies/sharded.py | 4 +++- pytorch_lightning/strategies/tpu_spawn.py | 2 +- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index a53783ce91db2..8cb86cb3ffc3d 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -159,22 +159,23 @@ def setup(self, trainer: "pl.Trainer") -> None: # move the model to the correct device self.model_to_device() - if self._layer_sync and self.model: - self.model = self._layer_sync.apply(self.model) - - # skip wrapping the model if we are not fitting as no gradients need to be exchanged trainer_fn = trainer.state.fn + if trainer_fn == TrainerFn.FITTING: + if self._layer_sync and self.model: + self.model = self._layer_sync.apply(self.model) + + self.setup_precision_plugin() + # set up optimizers after the module has been moved to the device # but before the module has been wrapped self.setup_optimizers(trainer) optimizers_to_device(self.optimizers, self.root_device) + # skip wrapping the model if we are not fitting as no gradients need to be exchanged if trainer_fn == TrainerFn.FITTING: self._configure_bagua_model(trainer) - self.setup_precision_plugin() - def _check_qadam_optimizer(self) -> None: has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers]) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index cdafa14e1ffa7..ea53a44bb9bb4 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -169,6 +169,10 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: if self._layer_sync: self.model = self._layer_sync.apply(self.model) + + self.setup_precision_plugin() + + if trainer_fn == TrainerFn.FITTING: self.configure_ddp() # set up optimizers after the wrapped module has been moved to the device @@ -180,8 +184,6 @@ def setup(self, trainer: "pl.Trainer") -> None: if 
isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): self._enable_model_averaging(self._ddp_comm_state.start_localSGD_iter) - self.setup_precision_plugin() - def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" device_ids = self.determine_ddp_device_ids() diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 471e3ebe80ea9..68ae579858a57 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -134,12 +134,15 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: if self._layer_sync: self.model = self._layer_sync.apply(self.model) + + self.setup_precision_plugin() + + if trainer_fn == TrainerFn.FITTING: self.configure_ddp() # set up optimizers after the wrapped module has been moved to the device self.setup_optimizers(trainer) optimizers_to_device(self.optimizers, self.root_device) - self.setup_precision_plugin() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" @@ -245,7 +248,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT: def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): - if isinstance(self.model, LightningDistributedModule): + if isinstance(self.model, DistributedDataParallel): # used when calling `trainer.fit` return self.model(*args, **kwargs) else: diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py index 4ff59ba3b1398..fc42fb5237346 100644 --- a/pytorch_lightning/strategies/fully_sharded.py +++ b/pytorch_lightning/strategies/fully_sharded.py @@ -139,15 +139,14 @@ def setup_distributed(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) self.setup_optimizers(trainer) - self.setup_precision_plugin() optimizers_to_device(self.optimizers, self.root_device) if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: self.model = self._layer_sync.apply(self.model) + self.setup_precision_plugin() self.configure_ddp() self.barrier() - self.setup_optimizers(trainer) @contextlib.contextmanager def model_sharded_context(self) -> Generator: diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py index 9186a31f5d997..8a76520755345 100644 --- a/pytorch_lightning/strategies/sharded.py +++ b/pytorch_lightning/strategies/sharded.py @@ -57,10 +57,12 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: if self._layer_sync: self.model = self._layer_sync.apply(self.model) - self.configure_ddp() self.setup_precision_plugin() + if trainer_fn == TrainerFn.FITTING: + self.configure_ddp() + def configure_ddp(self) -> None: self._set_ddp_kwargs() self.setup_optimizers(self.model.trainer) diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 4ecb587d529dc..262e37360ad10 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -137,9 +137,9 @@ def setup(self, trainer: "pl.Trainer") -> None: else: set_shared_parameters(self.model.module, shared_params) + self.setup_precision_plugin() self.setup_optimizers(trainer) optimizers_to_device(self.optimizers, self.root_device) - 
self.setup_precision_plugin() def _setup_model(self, model: Module) -> Module: return model From 7aa67771d288459cf7460f7309f31054806ea109 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 9 May 2022 05:43:52 -0400 Subject: [PATCH 23/25] fixes --- pytorch_lightning/strategies/bagua.py | 12 ++++++------ pytorch_lightning/strategies/ddp.py | 7 ++++--- pytorch_lightning/strategies/ddp_spawn.py | 8 ++++---- pytorch_lightning/strategies/fully_sharded.py | 10 ++++++---- pytorch_lightning/strategies/sharded_spawn.py | 4 ++++ pytorch_lightning/strategies/tpu_spawn.py | 7 +++++-- 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/strategies/bagua.py b/pytorch_lightning/strategies/bagua.py index 8cb86cb3ffc3d..94bc9aff8f0af 100644 --- a/pytorch_lightning/strategies/bagua.py +++ b/pytorch_lightning/strategies/bagua.py @@ -167,13 +167,13 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() - # set up optimizers after the module has been moved to the device - # but before the module has been wrapped - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - - # skip wrapping the model if we are not fitting as no gradients need to be exchanged if trainer_fn == TrainerFn.FITTING: + # set up optimizers after the module has been moved to the device + # but before the module has been wrapped + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + + # skip wrapping the model if we are not fitting as no gradients need to be exchanged self._configure_bagua_model(trainer) def _check_qadam_optimizer(self) -> None: diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index ea53a44bb9bb4..8a84c285aaac0 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -175,9 +175,10 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - # set up optimizers after the wrapped module has been moved to the device - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) + # set up optimizers after the wrapped module has been moved to the device + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING: import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py index 68ae579858a57..f44a2218ab8e1 100644 --- a/pytorch_lightning/strategies/ddp_spawn.py +++ b/pytorch_lightning/strategies/ddp_spawn.py @@ -140,10 +140,6 @@ def setup(self, trainer: "pl.Trainer") -> None: if trainer_fn == TrainerFn.FITTING: self.configure_ddp() - # set up optimizers after the wrapped module has been moved to the device - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) @@ -193,6 +189,10 @@ def configure_ddp(self) -> None: self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() + # set up optimizers after the wrapped module has been moved to the device + 
self.setup_optimizers(self.lightning_module.trainer) + optimizers_to_device(self.optimizers, self.root_device) + def determine_ddp_device_ids(self): if self.root_device.type == "cpu": return None diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py index fc42fb5237346..450f44de623d7 100644 --- a/pytorch_lightning/strategies/fully_sharded.py +++ b/pytorch_lightning/strategies/fully_sharded.py @@ -138,11 +138,13 @@ def setup_distributed(self) -> None: def setup(self, trainer: "pl.Trainer") -> None: self.accelerator.setup(trainer) - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) - if trainer.state.fn == TrainerFn.FITTING and self._layer_sync: - self.model = self._layer_sync.apply(self.model) + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) + + if self._layer_sync: + self.model = self._layer_sync.apply(self.model) self.setup_precision_plugin() self.configure_ddp() diff --git a/pytorch_lightning/strategies/sharded_spawn.py b/pytorch_lightning/strategies/sharded_spawn.py index 8cb6ca8b62028..58ad47f464bfc 100644 --- a/pytorch_lightning/strategies/sharded_spawn.py +++ b/pytorch_lightning/strategies/sharded_spawn.py @@ -23,6 +23,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only if _FAIRSCALE_AVAILABLE: @@ -38,9 +39,12 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): strategy_name = "ddp_sharded_spawn" def configure_ddp(self) -> None: + # set up optimizers after the wrapped module has been moved to the device + self.setup_optimizers(self.lightning_module.trainer) self.model, self.optimizers = self._setup_model_and_optimizers( model=LightningShardedDataParallel(self.model), optimizers=self.optimizers ) + optimizers_to_device(self.optimizers, self.root_device) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Wraps the model and optimizers with fairscale components. 
diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index 262e37360ad10..ee719954c7950 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -26,6 +26,7 @@ from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher from pytorch_lightning.trainer.connectors.data_connector import DataConnector +from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.data import has_len from pytorch_lightning.utilities.distributed import ReduceOp @@ -138,8 +139,10 @@ def setup(self, trainer: "pl.Trainer") -> None: set_shared_parameters(self.model.module, shared_params) self.setup_precision_plugin() - self.setup_optimizers(trainer) - optimizers_to_device(self.optimizers, self.root_device) + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + optimizers_to_device(self.optimizers, self.root_device) def _setup_model(self, model: Module) -> Module: return model From 7cd85450c22240b1c38b0e725e32353b719efe75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 31 May 2022 21:32:25 +0200 Subject: [PATCH 24/25] Bad merge --- CHANGELOG.md | 1 - tests/strategies/test_ddp_strategy.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 602292c5a662a..d22cb9df09d5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,7 +68,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `teardown()` method to `Accelerator` ([#11935](https://github.com/PyTorchLightning/pytorch-lightning/pull/11935)) - ### Changed - Enable validation during overfitting ([#12527](https://github.com/PyTorchLightning/pytorch-lightning/pull/12527)) diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py index 9c2127a4cc40a..5ee17e8a8d831 100644 --- a/tests/strategies/test_ddp_strategy.py +++ b/tests/strategies/test_ddp_strategy.py @@ -171,7 +171,7 @@ def configure_optimizers(self): super().configure_optimizers() -@RunIf(min_gpus=1) +@RunIf(min_cuda_gpus=1) @pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn")) def test_model_parameters_on_device_for_optimizer(strategy): """Test that the strategy has moved the parameters to the device by the time the optimizer gets created.""" From 713d692456d15d5eea074817215096e1c221091a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 1 Jun 2022 01:34:45 +0200 Subject: [PATCH 25/25] Remove extra argument --- pytorch_lightning/strategies/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 1ab2c2a981761..c431c2d790aec 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -183,7 +183,7 @@ def setup(self, trainer: "pl.Trainer") -> None: import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): - self._enable_model_averaging(self._ddp_comm_state.start_localSGD_iter) + self._enable_model_averaging() def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
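
Taken together, the series converges on the same ordering in each strategy's setup(): move the module to its device, apply any layer sync and precision setup, wrap the module (DDP / sharded / Bagua) only when fitting, and only then create the optimizers and move their state to the device. Below is a condensed, illustrative sketch of the final DDPStrategy.setup() shape assembled from the hunks above — not the verbatim merged code, and it omits the deadlock-detection and post-localSGD branches shown in the diffs:

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        # 1. place the module on the correct device before anything
        #    captures references to its parameters
        self.model_to_device()

        trainer_fn = trainer.state.fn
        if trainer_fn == TrainerFn.FITTING and self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        self.setup_precision_plugin()

        if trainer_fn == TrainerFn.FITTING:
            # 2. wrap the model only when fitting, as no gradients
            #    need to be exchanged otherwise
            self.configure_ddp()

            # 3. create the optimizers only after the wrapped module is
            #    on the device, then move any optimizer state there too
            self.setup_optimizers(trainer)
            optimizers_to_device(self.optimizers, self.root_device)

This is the ordering the new test_model_parameters_on_device_for_optimizer exercises: by the time configure_optimizers runs, every model parameter is already on the target device.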