From 66449860a17be013f14887fd096dd80b985e1b5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 15:04:04 +0200 Subject: [PATCH 01/17] setup --- pytorch_lightning/plugins/training_type/ddp.py | 10 ++++++++++ pytorch_lightning/plugins/training_type/dp.py | 10 +++++++--- .../training_type/training_type_plugin.py | 16 +++++++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a26b63151f5a8..7fc53b2b05b3d 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,7 +27,9 @@ import numpy as np import torch import torch.distributed +from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader, DistributedSampler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer @@ -181,6 +183,14 @@ def setup_environment(self) -> None: self.setup_distributed() + def setup_model(self, model: Module) -> Module: + model = DistributedDataParallel( + module=model.to(self.root_device), + device_ids=self.determine_ddp_device_ids(), + **self._ddp_kwargs, + ) + return model + def _call_children_scripts(self): # bookkeeping of spawned processes self._check_can_spawn_children() diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index fe970bb5a3bbc..a5a346f82698c 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Sequence import torch -from torch.nn import DataParallel +from torch.nn import DataParallel, Module +from torch.optim import Optimizer from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -54,7 +55,10 @@ def world_size(self) -> int: def setup(self) -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) + self._model = self.setup_model(LightningParallelModule(self._model)) + + def setup_model(self, model: Module) -> Module: + return DataParallel(module=model, device_ids=self.parallel_devices) def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 9c53069063a52..e0bc057cf4b41 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,11 +13,12 @@ # limitations under the License. 
import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor from torch.nn import Module +from torch.optim import Optimizer from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -60,6 +61,19 @@ def setup_environment(self) -> None: def setup(self) -> None: """Called by the accelerator to finish setup.""" + def setup_models_and_optimizers( + self, models: List[Module], optimizers: List[Optimizer] + ) -> Tuple[List[Module], List[Optimizer]]: + models = [self.setup_model(model) for model in models] + optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] + return models, optimizers + + def setup_model(self, model: Module) -> Module: + return model + + def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + return optimizer + @property @abstractmethod def on_gpu(self) -> bool: From ca8c6c1051866baa50e4bb200029cac0f6ecfb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 15:55:52 +0200 Subject: [PATCH 02/17] update ddp and spawn plugins --- pytorch_lightning/plugins/training_type/ddp.py | 11 ++--------- pytorch_lightning/plugins/training_type/ddp_spawn.py | 8 +++++--- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 7fc53b2b05b3d..2bcb032e3adcf 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -184,12 +184,7 @@ def setup_environment(self) -> None: self.setup_distributed() def setup_model(self, model: Module) -> Module: - model = DistributedDataParallel( - module=model.to(self.root_device), - device_ids=self.determine_ddp_device_ids(), - **self._ddp_kwargs, - ) - return model + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def _call_children_scripts(self): # bookkeeping of spawned processes @@ -365,9 +360,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = self.setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index eb1acaec4100b..32ca38d07c27b 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -21,6 +21,7 @@ import torch import torch.distributed import torch.multiprocessing as mp +from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel import pytorch_lightning as pl @@ -147,6 +148,9 @@ def setup(self) -> None: smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() + def setup_model(self, model: Module) -> Module: + return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) + def set_world_ranks(self, process_idx: int = 0) -> None: self._local_rank = process_idx if self.cluster_environment is None: @@ -256,9 +260,7 @@ def _register_ddp_hooks(self) -> 
None: def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = DistributedDataParallel( - LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs - ) + self._model = self.setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): From ce19737a98a8d92dafb98fda8658ead825d24905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 16:01:35 +0200 Subject: [PATCH 03/17] remove unused imports --- pytorch_lightning/plugins/training_type/dp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index a5a346f82698c..a86bdcf8665b1 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence +from typing import List, Optional import torch from torch.nn import DataParallel, Module -from torch.optim import Optimizer from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO From a6ef0a11b0c3d1970749eec849ad0b8f7f78ec3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 16:35:44 +0200 Subject: [PATCH 04/17] Update pytorch_lightning/plugins/training_type/ddp.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 2bcb032e3adcf..0f28c1a0c0765 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -183,7 +183,7 @@ def setup_environment(self) -> None: self.setup_distributed() - def setup_model(self, model: Module) -> Module: + def setup_model(self, model: Module) -> DistributedDataParallel: return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def _call_children_scripts(self): From 4db0e5b9272647831ee6d783acaac30354a5831c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 16:36:28 +0200 Subject: [PATCH 05/17] update typehint for ddp spawn --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 32ca38d07c27b..bc868559a73e3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -148,7 +148,7 @@ def setup(self) -> None: smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() - def setup_model(self, model: Module) -> Module: + def setup_model(self, model: Module) -> DistributedDataParallel: return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def set_world_ranks(self, process_idx: int = 0) -> None: From 8d11d321578e4de15da7b3bd875e6bda43fc2cca Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 17:44:40 +0200 Subject: [PATCH 06/17] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 888d22a520f75..6c804d01193bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - LightningLite: * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) - + * Added `TrainingTypePlugin.{setup_model, setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994)) ### Changed From 6d8562519cb34c92aba3c5222cb67e523db24562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 18 Oct 2021 17:45:22 +0200 Subject: [PATCH 07/17] unused imports --- pytorch_lightning/plugins/training_type/ddp.py | 1 - pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 0f28c1a0c0765..3a384f0abf7f6 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -29,7 +29,6 @@ import torch.distributed from torch.nn import Module from torch.nn.parallel.distributed import DistributedDataParallel -from torch.utils.data import DataLoader, DistributedSampler import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e0bc057cf4b41..d20ee3212f662 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. 
import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor From f3e5044ea3362365059f4ba537c2315ae7d946a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 03:21:00 +0200 Subject: [PATCH 08/17] update comment --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac20014d4986e..361974ce97edf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1538,7 +1538,7 @@ def local_rank(self) -> int: @property def node_rank(self) -> int: - # some training types define a local rank + # some training types define a node rank return getattr(self.training_type_plugin, "node_rank", 0) @property From c4b3d258c3fb4c3d5ea2a422e00d7a7f1e8c1867 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 10:51:12 +0200 Subject: [PATCH 09/17] split --- pytorch_lightning/plugins/training_type/dp.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index a86bdcf8665b1..fe970bb5a3bbc 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -14,7 +14,7 @@ from typing import List, Optional import torch -from torch.nn import DataParallel, Module +from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -54,10 +54,7 @@ def world_size(self) -> int: def setup(self) -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - self._model = self.setup_model(LightningParallelModule(self._model)) - - def setup_model(self, model: Module) -> Module: - return DataParallel(module=model, device_ids=self.parallel_devices) + self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. From 3024abb831eb8ebaec6c1a05d54ac84200a3d336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 11:12:18 +0200 Subject: [PATCH 10/17] add todo and docs --- .../plugins/training_type/training_type_plugin.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d20ee3212f662..4fc4e6d123ad7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -64,14 +64,24 @@ def setup(self) -> None: def setup_models_and_optimizers( self, models: List[Module], optimizers: List[Optimizer] ) -> Tuple[List[Module], List[Optimizer]]: + """Setup multiple models and multiple optimizers together. + + The returned objects are expected to be in the same order they were passed in. + The default implementation will call :meth:`setup_model` and :meth:`setup_optimizer` on the input lists. 
+ """ + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite models = [self.setup_model(model) for model in models] optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] return models, optimizers def setup_model(self, model: Module) -> Module: + """Performs setup for the model, e.g., by wrapping it by another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite return model def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + """Performs setup for the optimizer, e.g., by wrapping it by another class.""" + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite return optimizer @property From 71d8e35aba893629116aae6ec0cb05ebc561630e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 11:13:59 +0200 Subject: [PATCH 11/17] update docs --- pytorch_lightning/plugins/training_type/ddp.py | 1 + pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 3a384f0abf7f6..e9442bc15ad3a 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -183,6 +183,7 @@ def setup_environment(self) -> None: self.setup_distributed() def setup_model(self, model: Module) -> DistributedDataParallel: + """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def _call_children_scripts(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index bc868559a73e3..dbd479a4bd220 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -149,6 +149,7 @@ def setup(self) -> None: self.mp_queue = smp.SimpleQueue() def setup_model(self, model: Module) -> DistributedDataParallel: + """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) def set_world_ranks(self, process_idx: int = 0) -> None: From 7cd1ba2ab5d5be9db1d06abc9e8921801c8277dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 11:30:36 +0200 Subject: [PATCH 12/17] update comments --- .../plugins/training_type/training_type_plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 4fc4e6d123ad7..e33277970507e 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -69,19 +69,19 @@ def setup_models_and_optimizers( The returned objects are expected to be in the same order they were passed in. The default implementation will call :meth:`setup_model` and :meth:`setup_optimizer` on the input lists. """ - # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. 
Related refactor: #7324 models = [self.setup_model(model) for model in models] optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] return models, optimizers def setup_model(self, model: Module) -> Module: """Performs setup for the model, e.g., by wrapping it by another class.""" - # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return model def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Performs setup for the optimizer, e.g., by wrapping it by another class.""" - # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite + # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return optimizer @property From ce1bfdee16a113ac904991e03643dc8f6443d9ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 11:32:10 +0200 Subject: [PATCH 13/17] mark setup methods protected --- CHANGELOG.md | 2 +- pytorch_lightning/plugins/training_type/ddp.py | 4 ++-- pytorch_lightning/plugins/training_type/ddp_spawn.py | 4 ++-- .../plugins/training_type/training_type_plugin.py | 12 ++++++------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c804d01193bb..3f5783c14a16f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - LightningLite: * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) - * Added `TrainingTypePlugin.{setup_model, setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994)) + * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994)) ### Changed diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index e9442bc15ad3a..64fc1a5a97277 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -182,7 +182,7 @@ def setup_environment(self) -> None: self.setup_distributed() - def setup_model(self, model: Module) -> DistributedDataParallel: + def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) @@ -360,7 +360,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int): def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = self.setup_model(LightningDistributedModule(self.model)) + self._model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index dbd479a4bd220..b9122aa062ade 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -148,7 +148,7 @@ def setup(self) -> None: smp = mp.get_context("spawn") self.mp_queue = smp.SimpleQueue() - def setup_model(self, model: Module) -> 
DistributedDataParallel: + def _setup_model(self, model: Module) -> DistributedDataParallel: """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.""" return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs) @@ -261,7 +261,7 @@ def _register_ddp_hooks(self) -> None: def configure_ddp(self) -> None: self.pre_configure_ddp() - self._model = self.setup_model(LightningDistributedModule(self.model)) + self._model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() def determine_ddp_device_ids(self): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e33277970507e..b0b9312e9927b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -61,25 +61,25 @@ def setup_environment(self) -> None: def setup(self) -> None: """Called by the accelerator to finish setup.""" - def setup_models_and_optimizers( + def _setup_models_and_optimizers( self, models: List[Module], optimizers: List[Optimizer] ) -> Tuple[List[Module], List[Optimizer]]: """Setup multiple models and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. - The default implementation will call :meth:`setup_model` and :meth:`setup_optimizer` on the input lists. + The default implementation will call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists. """ # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 - models = [self.setup_model(model) for model in models] - optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] + models = [self._setup_model(model) for model in models] + optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers] return models, optimizers - def setup_model(self, model: Module) -> Module: + def _setup_model(self, model: Module) -> Module: """Performs setup for the model, e.g., by wrapping it by another class.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return model - def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: + def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: """Performs setup for the optimizer, e.g., by wrapping it by another class.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return optimizer From 6a13cf70083e340434ab602b19bc80d0fe8c52a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 11:34:46 +0200 Subject: [PATCH 14/17] add comment mentioning use by Lite --- .../plugins/training_type/training_type_plugin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b0b9312e9927b..fc8c3a94af49f 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -66,6 +66,8 @@ def _setup_models_and_optimizers( ) -> Tuple[List[Module], List[Optimizer]]: """Setup multiple models and multiple optimizers together. + Primarily used by Lightning Lite. + The returned objects are expected to be in the same order they were passed in. 
The default implementation will call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists. """ @@ -75,12 +77,12 @@ def _setup_models_and_optimizers( return models, optimizers def _setup_model(self, model: Module) -> Module: - """Performs setup for the model, e.g., by wrapping it by another class.""" + """Performs setup for the model, e.g., by wrapping it by another class. Primarily used by Lightning Lite.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return model def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - """Performs setup for the optimizer, e.g., by wrapping it by another class.""" + """Performs setup for the optimizer, e.g., by wrapping it by another class. Primarily used by Lightning Lite.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return optimizer From 35785f824a03a211d1ee3c89a5c7d1260cd0db96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Oct 2021 09:36:31 +0000 Subject: [PATCH 15/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/training_type_plugin.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index fc8c3a94af49f..d1f48dcb13591 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -77,12 +77,18 @@ def _setup_models_and_optimizers( return models, optimizers def _setup_model(self, model: Module) -> Module: - """Performs setup for the model, e.g., by wrapping it by another class. Primarily used by Lightning Lite.""" + """Performs setup for the model, e.g., by wrapping it by another class. + + Primarily used by Lightning Lite. + """ # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return model def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - """Performs setup for the optimizer, e.g., by wrapping it by another class. Primarily used by Lightning Lite.""" + """Performs setup for the optimizer, e.g., by wrapping it by another class. + + Primarily used by Lightning Lite. + """ # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return optimizer From 4e2fa4e93bae5a59ebed062e64be748e755f0b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 15:40:47 +0200 Subject: [PATCH 16/17] remove a comment While currently the methods are primarily used by Lite, they will become more important ingredients later on for all Lightning and all plugins --- .../plugins/training_type/training_type_plugin.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index d1f48dcb13591..b0b9312e9927b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -66,8 +66,6 @@ def _setup_models_and_optimizers( ) -> Tuple[List[Module], List[Optimizer]]: """Setup multiple models and multiple optimizers together. - Primarily used by Lightning Lite. 
- The returned objects are expected to be in the same order they were passed in. The default implementation will call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists. """ @@ -77,18 +75,12 @@ def _setup_models_and_optimizers( return models, optimizers def _setup_model(self, model: Module) -> Module: - """Performs setup for the model, e.g., by wrapping it by another class. - - Primarily used by Lightning Lite. - """ + """Performs setup for the model, e.g., by wrapping it by another class.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return model def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - """Performs setup for the optimizer, e.g., by wrapping it by another class. - - Primarily used by Lightning Lite. - """ + """Performs setup for the optimizer, e.g., by wrapping it by another class.""" # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 return optimizer From 2661cbf3f1eb627c840d6d2199fc6211f95d5507 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Oct 2021 13:42:52 +0000 Subject: [PATCH 17/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/training_type_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b0b9312e9927b..481b9ee1c4087 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -66,8 +66,8 @@ def _setup_models_and_optimizers( ) -> Tuple[List[Module], List[Optimizer]]: """Setup multiple models and multiple optimizers together. - The returned objects are expected to be in the same order they were passed in. - The default implementation will call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists. + The returned objects are expected to be in the same order they were passed in. The default implementation will + call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists. """ # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324 models = [self._setup_model(model) for model in models]
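
Usage sketch (not part of the patches above; class and variable names are illustrative): with the hooks introduced in this series, a custom training type plugin can override `_setup_model` / `_setup_optimizer` to control how the objects handed to `_setup_models_and_optimizers` get wrapped. A minimal sketch, assuming the plugin classes exported by `pytorch_lightning.plugins` at the time of this PR and an already initialized distributed environment:

    # Hypothetical plugin overriding the hooks added in this patch series.
    # `LoggingDDPPlugin` is an illustrative name, not part of the PR.
    from torch.nn import Module
    from torch.nn.parallel import DistributedDataParallel
    from torch.optim import Optimizer

    from pytorch_lightning.plugins import DDPPlugin


    class LoggingDDPPlugin(DDPPlugin):
        def _setup_model(self, model: Module) -> DistributedDataParallel:
            # defer the actual DistributedDataParallel wrapping to the DDP hook
            # introduced in this series, then report what was wrapped
            wrapped = super()._setup_model(model)
            print(f"wrapped {type(model).__name__} into {type(wrapped).__name__}")
            return wrapped

        def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
            # the base implementation is a no-op; a sharded plugin could
            # replace or wrap the optimizer here
            return optimizer

Lightning Lite is then expected to call the combined hook, e.g. `models, optimizers = plugin._setup_models_and_optimizers([module], [optimizer])`, and receives the wrapped objects back in the same order they were passed in.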