
Commit 03f2f32

lijm1358 and otaj authored
Fix mypy errors in pytorch_lightning/strategies/sharded.py (#14184)
Co-authored-by: otaj <[email protected]>
1 parent af688de commit 03f2f32
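
The pattern behind most of the fixes below is explicit type narrowing: an `assert x is not None` placed before attribute access on an Optional so mypy can prove the access is safe. A minimal, self-contained sketch of that pattern (the `Engine`/`Runner` names are invented for illustration, not taken from the patch):

from typing import Optional


class Engine:
    def start(self) -> None:
        print("engine started")


class Runner:
    def __init__(self, engine: Optional[Engine] = None) -> None:
        self.engine = engine

    def run(self) -> None:
        # Without the assert, mypy flags self.engine.start() because
        # self.engine may still be None here; the assert narrows
        # Optional[Engine] to Engine for the rest of the method.
        assert self.engine is not None
        self.engine.start()

At runtime the assert also fails fast with a clear error if the value really is unset, instead of an AttributeError further down.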

2 files changed: +12 -8 lines changed


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
-    "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",

src/pytorch_lightning/strategies/sharded.py

Lines changed: 12 additions & 7 deletions
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple

 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.enums import PrecisionType
@@ -51,10 +51,11 @@ def connect(self, model: "pl.LightningModule") -> None:

     def setup(self, trainer: "pl.Trainer") -> None:
         # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()

+        assert self.accelerator is not None
         self.accelerator.setup(trainer)

         # move the model to the correct device
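
Besides the `assert self.accelerator is not None` guard, this hunk annotates the broadcast result explicitly. A small sketch of why such an annotation helps, using a hypothetical `broadcast` helper typed as returning `Any` (the real `Strategy.broadcast` signature may differ):

from typing import Any


def broadcast(obj: Any) -> Any:
    # Stand-in for a collective call whose static return type is Any.
    return obj


class Strategy:
    def setup(self) -> None:
        # The explicit annotation pins the attribute to bool even though
        # broadcast() returns Any, so later boolean uses keep type-checking.
        self._ready: bool = broadcast(True)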
@@ -64,6 +65,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         trainer_fn = trainer.state.fn
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)

         self.setup_precision_plugin()
@@ -73,7 +75,9 @@ def setup(self, trainer: "pl.Trainer") -> None:

     def configure_ddp(self) -> None:
         self._set_ddp_kwargs()
-        self.setup_optimizers(self.model.trainer)
+        assert self.lightning_module is not None
+        self.setup_optimizers(self.lightning_module.trainer)
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=_LightningModuleWrapperBase(self.model),
             optimizers=self.optimizers,
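
The `assert isinstance(...)` above narrows `self.model`, which the base strategy declares quite broadly, down to the kinds of module the wrapper presumably accepts. A minimal sketch of that isinstance-narrowing pattern with invented stand-in classes:

from typing import Optional, Union


class CoreModule:
    ...


class PrecisionModule:
    ...


class ModuleWrapper:
    # Accepts only the two unwrapped kinds, mirroring the constraint that
    # forces the isinstance assert in the diff above.
    def __init__(self, module: Union[CoreModule, PrecisionModule]) -> None:
        self.module = module


class Strategy:
    model: Optional[object] = None  # deliberately broad, as in a shared base class

    def configure(self) -> None:
        # Narrow the broad attribute before handing it to the wrapper;
        # without the assert, mypy rejects the ModuleWrapper(...) call.
        assert isinstance(self.model, (CoreModule, PrecisionModule))
        self.model = ModuleWrapper(self.model)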
@@ -97,12 +101,13 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         return model, optimizers

     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+        assert self.lightning_module is not None
+        if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers

         return self._reinit_optimizers_with_oss(optimizers)

-    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
@@ -135,7 +140,7 @@ def block_backward_sync(self) -> Generator:
         else:
             yield None

-    def post_training_step(self):
+    def post_training_step(self) -> None:
         pass

     @classmethod
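
The `-> None` added to `post_training_step` matters under strict per-module settings: with `disallow_untyped_defs` in effect, an unannotated def is itself an error, and mypy skips checking unannotated function bodies unless `--check-untyped-defs` is enabled. Annotating the return type resolves both, e.g.:

class Hooks:
    def post_training_step(self) -> None:
        # Fully annotated, so mypy accepts the definition and checks the body.
        pass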

0 commit comments
