Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ module = [
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
"pytorch_lightning.strategies.sharded",
"pytorch_lightning.strategies.sharded_spawn",
"pytorch_lightning.trainer.callback_hook",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.supporters",
Expand Down
1 change: 1 addition & 0 deletions src/pytorch_lightning/overrides/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
trainer = pl_module._trainer

if trainer is not None:
assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as
Expand Down
14 changes: 9 additions & 5 deletions src/pytorch_lightning/strategies/sharded_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.states import TrainerFn
Expand All @@ -42,7 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):

def configure_ddp(self) -> None:
    """Set up optimizers and wrap the module for FairScale sharded data-parallel training.

    Order matters here: optimizers are created first, because the wrapped
    module has already been moved to its target device by this point, and
    the sharded wrapper re-creates them afterwards.
    """
    # set up optimizers after the wrapped module has been moved to the device
    assert self.lightning_module is not None
    self.setup_optimizers(self.lightning_module.trainer)
    # mypy narrowing: the model handed to the sharded wrapper must be the raw
    # LightningModule or its precision wrapper, not an already-wrapped module
    assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
    # re-wraps the model in LightningShardedDataParallel and rebuilds the
    # optimizers — presumably as sharded OSS optimizers (see
    # _reinit_optimizers_with_oss below); confirm against _setup_model_and_optimizers
    self.model, self.optimizers = self._setup_model_and_optimizers(
        model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
    )
Expand All @@ -69,12 +72,13 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"
return optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
    """Return sharded ``OSS`` optimizers while fitting; otherwise pass through unchanged."""
    assert self.lightning_module
    fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
    # Re-wrap as OSS only when we are actually fitting (or no model is set yet).
    if self.model is None or fitting:
        return self._reinit_optimizers_with_oss(optimizers)
    return optimizers

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
    """Return the optimizer's state dict, consolidating sharded state first.

    For ``OSS`` optimizers the shards are gathered before reading, so that
    the returned dict contains the full optimizer state.
    """
    is_sharded = isinstance(optimizer, OSS)
    if is_sharded:
        # gather the per-rank shards before reading the state dict
        optimizer.consolidate_state_dict()
    return self._optim_state_dict(optimizer)
Expand All @@ -93,7 +97,7 @@ def block_backward_sync(self) -> Generator:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer):
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
Expand All @@ -112,7 +116,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
def pre_backward(self, closure_loss: Tensor) -> None:
    """Intentionally a no-op: this strategy requires no pre-backward handling."""

def post_training_step(self):
def post_training_step(self) -> None:
    """Intentionally a no-op: nothing to do after the training step for this strategy."""

@classmethod
Expand Down