From 386ce9440826e3fefd71b7a71bf018165d163e81 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 21:40:02 +0200
Subject: [PATCH 1/7] sharded

---
 CHANGELOG.md                          |  2 +
 .../plugins/training_type/sharded.py  | 54 ++++++++++++-------
 2 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3bd21cbaca1b..870ba8754ef69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -208,6 +208,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
   * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
   * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+  * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#XXXXX](https://github.com/PyTorchLightning/pytorch-lightning/pull/XXXXX))
+
 
 ### Changed
 
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index d684a34784f4c..b749e7ca9f5fc 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,24 +35,39 @@ class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
 
-    def configure_ddp(self) -> None:
-        self._wrap_optimizers()
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None
+
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        if len(models) > 1:
+            raise ValueError(
+                f"DDPSharded only supports a single model with one or several optimizers. Got {len(models)} models."
+            )
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def configure_ddp(self) -> None:
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=self.lightning_module.trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer = self.lightning_module.trainer
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -58,7 +75,7 @@ def _reinit_optimizers_with_oss(self):
             optim_class = type(optimizer)
             zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
             if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                precision = self.lightning_module.trainer.precision
+                precision = self._precision or self.lightning_module.trainer.precision
                 is_fp16 = precision in ("mixed", 16)
                 # For multi-node training, compressing the model shards in fp16 before broadcasting
                 # improves performance. When using PyTorch AMP, it will not degrade
@@ -66,14 +83,13 @@ def _reinit_optimizers_with_oss(self):
                 zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
             optimizers[x] = zero_optimizer
             del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
+
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):

From 45806cffcc76b069cde49c6f1116d08ee550e1df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 21:40:28 +0200
Subject: [PATCH 2/7] sharded spawn

---
 .../plugins/training_type/sharded.py        | 27 ++++++------
 .../plugins/training_type/sharded_spawn.py  | 43 ++++++++++++-------
 2 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index b749e7ca9f5fc..2640eacf4fb63 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -41,19 +41,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._precision = None
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPSharded only supports a single model with one or several optimizers. Got {len(models)} models."
-            )
-        optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
-        return [model], optimizers
-
     def configure_ddp(self) -> None:
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
@@ -67,6 +54,20 @@ def configure_ddp(self) -> None:
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
 
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        if len(models) > 1:
+            raise ValueError(
+                f"DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 78b54d029a5f6..6b49dd80bcb82 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, Optional, List, Tuple
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -36,29 +38,40 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=self.lightning_module.trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        self.lightning_module.trainer.optimizers = optimizers
+
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        if len(models) > 1:
+            raise ValueError(
+                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
 
-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
+        return optimizers
+
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):

From fc159bd6c43ef332f5efdff6c4fe4f923538de8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 21:50:22 +0200
Subject: [PATCH 3/7] add docs

---
 pytorch_lightning/plugins/training_type/sharded.py       | 8 ++++++++
 pytorch_lightning/plugins/training_type/sharded_spawn.py | 8 ++++++++
 2 files changed, 16 insertions(+)

diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 2640eacf4fb63..ef1273d16b391 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -57,6 +57,14 @@ def _setup_models_and_optimizers(
         self, models: List[Module], optimizers: List[Optimizer]
     ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
         if len(models) > 1:
             raise ValueError(
                 f"DDPSharded only supports setting up a single model with one or several optimizers."
                 f" Got {len(models)} models."
             )
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 6b49dd80bcb82..04c97ea9abe3c 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -47,6 +47,14 @@ def configure_ddp(self) -> None:
     def _setup_models_and_optimizers(
         self, models: List[Module], optimizers: List[Optimizer]
     ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.
+
+        Currently only one model can be set up at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizers wrapped in :class:`~fairscale.optim.OSS`.
+        """
         if len(models) > 1:
             raise ValueError(
                 f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
From 5505abb2d0fb47a0a7ce26bd83fdeea8e805c57b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 19 Oct 2021 21:53:32 +0200
Subject: [PATCH 4/7] update changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 870ba8754ef69..b9353b3d1176e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -208,7 +208,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
   * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
   * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-  * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#XXXXX](https://github.com/PyTorchLightning/pytorch-lightning/pull/XXXXX))
+  * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 
 
 ### Changed

From f2b59de15242aec628ef21912bb7eb200a85f89c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 19 Oct 2021 19:53:52 +0000
Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/plugins/training_type/sharded_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 04c97ea9abe3c..f4f434de429a1 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional, List, Tuple
+from typing import Dict, Generator, List, Optional, Tuple
 
 import torch
 from torch.nn import Module
 from torch.optim import Optimizer

From c0d1079141253dffccb90a3600c4bdbc16777365 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 20 Oct 2021 00:05:39 +0200
Subject: [PATCH 6/7] fix access to trainer, mocking in tests

---
 pytorch_lightning/plugins/training_type/sharded.py       | 4 ++--
 pytorch_lightning/plugins/training_type/sharded_spawn.py | 5 +++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index ef1273d16b391..0422eb7c762d6 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -42,15 +42,15 @@ def __init__(self, *args, **kwargs):
         self._precision = None
 
     def configure_ddp(self) -> None:
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
         [self._model], optimizers = self._setup_models_and_optimizers(
             models=[LightningShardedDataParallel(self.model)],
-            optimizers=self.lightning_module.trainer.optimizers,
+            optimizers=trainer.optimizers,
         )
-        trainer = self.lightning_module.trainer
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index f4f434de429a1..5d48c489a37e8 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -38,11 +38,12 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self) -> None:
+        trainer = self.lightning_module.trainer
         [self._model], optimizers = self._setup_models_and_optimizers(
             models=[LightningShardedDataParallel(self.model)],
-            optimizers=self.lightning_module.trainer.optimizers,
+            optimizers=trainer.optimizers,
         )
-        self.lightning_module.trainer.optimizers = optimizers
+        trainer.optimizers = optimizers
 
     def _setup_models_and_optimizers(
         self, models: List[Module], optimizers: List[Optimizer]
     ) -> Tuple[List[Module], List[Optimizer]]:

From c91fb567ceb36c89a50c14c60dee25e6ecf912d6 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Wed, 20 Oct 2021 13:24:42 +0530
Subject: [PATCH 7/7] Update pytorch_lightning/plugins/training_type/sharded.py

---
 pytorch_lightning/plugins/training_type/sharded.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 0422eb7c762d6..63ac7f5105945 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -67,7 +67,7 @@ def _setup_models_and_optimizers(
         """
         if len(models) > 1:
             raise ValueError(
-                f"DDPSharded only supports setting up a single model with one or several optimizers."
+                "DDPSharded only supports setting up a single model with one or several optimizers."
                 f" Got {len(models)} models."
             )
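
Note: the wrapping performed by `_setup_models_and_optimizers` in the patches above follows the usual FairScale pattern of first sharding optimizer state with `OSS` and then wrapping the module in `ShardedDataParallel`, so that each rank reduces gradients only to the rank that owns the matching optimizer shard. The minimal standalone sketch below is illustrative only and not part of the patch set; it assumes FairScale is installed, that a torch.distributed process group has already been initialized, and the function name and hyperparameters are made up for the example.

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def wrap_for_sharded_training(model: torch.nn.Module):
    # Shard the optimizer state across ranks: each rank keeps and updates only
    # its own shard of the state (plain SGD is used here as the inner optimizer).
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
    # Wrap the module so gradients are reduced to the rank owning the shard,
    # mirroring `ShardedDataParallel(models[0], sharded_optimizer=optimizers, ...)`
    # in the plugins above.
    model = ShardedDataParallel(model, sharded_optimizer=[optimizer])
    return model, optimizer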