From 41d58031f94f5621ae9e7a126d6ff4d75be464b8 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Mon, 8 Aug 2022 18:34:40 -0400
Subject: [PATCH 1/9] init modification

---
 pyproject.toml                                    |  1 -
 src/pytorch_lightning/strategies/sharded_spawn.py | 14 +++++++-------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8db782df357d8..c442ed6359580 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.sharded",
-    "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 4550e397ded80..8dfa613d0e707 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 from torch import Tensor
 from torch.nn import Module
@@ -42,9 +42,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
 
     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
-        self.setup_optimizers(self.lightning_module.trainer)
+        self.setup_optimizers(self.lightning_module.trainer)  # type: ignore
         self.model, self.optimizers = self._setup_model_and_optimizers(
-            model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
+            model=LightningShardedDataParallel(self.model), optimizers=self.optimizers  # type: ignore
         )
         optimizers_to_device(self.optimizers, self.root_device)
 
@@ -69,12 +69,12 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         return optimizers
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:  # type: ignore
             return optimizers
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Tensor]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
@@ -93,7 +93,7 @@ def block_backward_sync(self) -> Generator:
             yield None
 
     @rank_zero_only
-    def _optim_state_dict(self, optimizer):
+    def _optim_state_dict(self, optimizer: Any) -> Any:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
@@ -112,7 +112,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass
 
     @classmethod

From 0f27fd03f1cb4a5262b3d16eb75dd2a39f02b3fb Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 9 Aug 2022 12:37:25 +0530
Subject: [PATCH 2/9] fix

---
 src/pytorch_lightning/overrides/base.py           |  2 +-
 src/pytorch_lightning/strategies/sharded_spawn.py | 13 ++++++++-----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 26c2837bda7e3..1c5be085a770a 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -54,7 +54,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase, nn.Module]) -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
         ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
 
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 8dfa613d0e707..eabcd7d7ce333 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import Dict, Generator, List, Optional, Tuple
 
 from torch import Tensor
 from torch.nn import Module
@@ -42,9 +42,11 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
 
     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
-        self.setup_optimizers(self.lightning_module.trainer)  # type: ignore
+        assert self.lightning_module
+        assert self.model
+        self.setup_optimizers(self.lightning_module.trainer)
         self.model, self.optimizers = self._setup_model_and_optimizers(
-            model=LightningShardedDataParallel(self.model), optimizers=self.optimizers  # type: ignore
+            model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
         optimizers_to_device(self.optimizers, self.root_device)
 
@@ -69,7 +71,8 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         return optimizers
 
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:  # type: ignore
+        assert self.lightning_module
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
 
         return self._reinit_optimizers_with_oss(optimizers)
@@ -93,7 +96,7 @@ def block_backward_sync(self) -> Generator:
             yield None
 
     @rank_zero_only
-    def _optim_state_dict(self, optimizer: Any) -> Any:
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.
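
Note on PATCH 2/9: it drops the "# type: ignore" comments introduced in PATCH 1/9 and instead asserts that the Optional attributes are set, letting mypy narrow the types on its own. Below is a minimal, self-contained sketch of that narrowing pattern; the Trainer and Strategy names here are illustrative stand-ins, not the actual Lightning classes.

from typing import Optional


class Trainer:
    def fit(self) -> None:
        print("fitting")


class Strategy:
    def __init__(self) -> None:
        # populated later during setup, hence Optional
        self.trainer: Optional[Trainer] = None

    def run(self) -> None:
        # Without this assert, mypy reports that "None" has no attribute "fit".
        # The assert narrows Optional[Trainer] to Trainer for the rest of the
        # method and also fails loudly at runtime if setup was skipped.
        assert self.trainer is not None
        self.trainer.fit()
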
From a66d6928f27d1a94fcabc3379dff949fa73c91a0 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 9 Aug 2022 13:47:39 +0530
Subject: [PATCH 3/9] try fix

---
 src/pytorch_lightning/overrides/base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 1c5be085a770a..abcec7ea637e3 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -20,6 +20,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
+from pytorch_lightning.core.module import LightningModule
 
 
 class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -75,6 +76,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         trainer = pl_module._trainer
 
         if trainer is not None:
+            assert isinstance(self.module, LightningModule)
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as

From a974d17799191dfddcfb45c5553ef75347587221 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Tue, 9 Aug 2022 14:53:00 +0530
Subject: [PATCH 4/9] fix

---
 src/pytorch_lightning/overrides/base.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index abcec7ea637e3..550d94d7b4556 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -20,7 +20,6 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
-from pytorch_lightning.core.module import LightningModule
 
 
 class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -76,7 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         trainer = pl_module._trainer
 
         if trainer is not None:
-            assert isinstance(self.module, LightningModule)
+            assert isinstance(self.module, pl.LightningModule)
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as

From 5e22aef8212512d494d16758512a4c6c901ec1ec Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 9 Aug 2022 06:45:30 -0400
Subject: [PATCH 5/9] fix

---
 src/pytorch_lightning/overrides/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 550d94d7b4556..660948b0d0082 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -75,7 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         trainer = pl_module._trainer
 
         if trainer is not None:
-            assert isinstance(self.module, pl.LightningModule)
+            assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
             if trainer.training:
                 output = self.module.training_step(*inputs, **kwargs)
                 # In manual_optimization, we need to prevent DDP reducer as

From 833d4ac9f3a69e85d3d6ccec2b32af8b6ee5c189 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 11 Aug 2022 02:14:17 +0200
Subject: [PATCH 6/9] avoid expanding types in wrapper

---
 src/pytorch_lightning/overrides/base.py           | 2 +-
 src/pytorch_lightning/strategies/sharded_spawn.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py
index 660948b0d0082..3e9fda2f966f5 100644
--- a/src/pytorch_lightning/overrides/base.py
+++ b/src/pytorch_lightning/overrides/base.py
@@ -54,7 +54,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
 
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase, nn.Module]) -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
         ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
 
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index eabcd7d7ce333..a45079b6e55f9 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -19,6 +19,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -42,9 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
 
     def configure_ddp(self) -> None:
         # set up optimizers after the wrapped module has been moved to the device
-        assert self.lightning_module
-        assert self.model
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
         optimizers_to_device(self.optimizers, self.root_device)

From e3ca7d34d40e8146275b168ab9a091ec0cc45b72 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Thu, 11 Aug 2022 10:05:07 +0200
Subject: [PATCH 7/9] Update src/pytorch_lightning/strategies/sharded_spawn.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 src/pytorch_lightning/strategies/sharded_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index a45079b6e55f9..4a8bff4f6398f 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -78,7 +78,7 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Tensor]:
+    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
         if isinstance(optimizer, OSS):
             optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)

From 775a8aab13f7a73d67eb72e6d10029efd0adec13 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Thu, 11 Aug 2022 10:05:16 +0200
Subject: [PATCH 8/9] Update src/pytorch_lightning/strategies/sharded_spawn.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 src/pytorch_lightning/strategies/sharded_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 4a8bff4f6398f..00e53b83ab0a1 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -97,7 +97,7 @@ def block_backward_sync(self) -> Generator:
             yield None
 
     @rank_zero_only
-    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Tensor]:
+    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
         """
         Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
         :meth:`consolidate_state_dict`.

From 45de7db3d0311c166cdb85265472e69a65e1983d Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 11 Aug 2022 17:20:19 +0200
Subject: [PATCH 9/9] add missing import

---
 src/pytorch_lightning/strategies/sharded_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 00e53b83ab0a1..882302e101cb6 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 from torch import Tensor
 from torch.nn import Module
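
Note on PATCH 6/9 ("avoid expanding types in wrapper"): rather than widening the wrapper's accepted Union to nn.Module, the series keeps the Union narrow and asserts the concrete type at the call site before wrapping. A minimal sketch of that call-site narrowing follows; all class names below are illustrative stand-ins, not the real Lightning or fairscale types.

from typing import Union


class LightningLike:
    def training_step(self) -> float:
        return 0.0


class PrecisionWrapper:
    def __init__(self, module: LightningLike) -> None:
        self.module = module


class ShardedWrapper:
    # the accepted type stays a narrow Union, mirroring _LightningModuleWrapperBase
    def __init__(self, module: Union[LightningLike, PrecisionWrapper]) -> None:
        self.module = module


def configure(model: object) -> ShardedWrapper:
    # mypy only accepts ShardedWrapper(model) after this isinstance check,
    # which narrows `model` from object to the Union of allowed types
    assert isinstance(model, (LightningLike, PrecisionWrapper))
    return ShardedWrapper(model)


if __name__ == "__main__":
    wrapped = configure(LightningLike())
    print(type(wrapped.module).__name__)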