Skip to content

Commit 833d4ac

Browse files
committed
avoid expanding types in wrapper
1 parent cb2ba76 commit 833d4ac

File tree

2 files changed: 4 additions, 3 deletions

src/pytorch_lightning/overrides/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
5454

5555

5656
class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
57-
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase, nn.Module]) -> None:
57+
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
5858
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
5959
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
6060

src/pytorch_lightning/strategies/sharded_spawn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torch.optim import Optimizer
2020

2121
import pytorch_lightning as pl
22+
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
2223
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
2324
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
2425
from pytorch_lightning.trainer.states import TrainerFn
@@ -42,9 +43,9 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
4243

4344
def configure_ddp(self) -> None:
4445
# set up optimizers after the wrapped module has been moved to the device
45-
assert self.lightning_module
46-
assert self.model
46+
assert self.lightning_module is not None
4747
self.setup_optimizers(self.lightning_module.trainer)
48+
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
4849
self.model, self.optimizers = self._setup_model_and_optimizers(
4950
model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
5051
)

0 commit comments

Comments (0)