
Commit ff39cb7

fix mypy errors

1 parent: 9fb3b46

File tree

1 file changed: +20 −11 lines

  • src/pytorch_lightning/strategies/ddp.py

src/pytorch_lightning/strategies/ddp.py

Lines changed: 20 additions & 11 deletions
@@ -59,7 +59,7 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT, TestStep, TrainingStep, ValidationStep
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
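The new `TrainingStep`, `ValidationStep`, and `TestStep` imports exist so the step methods further down can narrow `self.model` with `isinstance` asserts. That pattern relies on the types being runtime-checkable protocols, where `isinstance` performs a structural has-this-method check that mypy also uses for type narrowing. A minimal sketch of the idiom, assuming the Lightning types are defined roughly this way (the names below are illustrative, not the real definitions):

from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class TrainingStep(Protocol):
    # Structural type: anything with a `training_step` method matches.
    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        ...

class WrappedModel:  # hypothetical stand-in for the DDP-wrapped module
    def training_step(self, *args: Any, **kwargs: Any) -> float:
        return 0.0

model: object = WrappedModel()
# The assert is both a runtime guard and a narrowing hint: after it, mypy
# types `model` as `TrainingStep`, so the method call below type-checks.
assert isinstance(model, TrainingStep)
model.training_step()

Note that a runtime-checkable protocol only checks that the method exists, not that its signature matches.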
@@ -333,13 +333,15 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
         if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+            assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
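The `assert self.model is not None` added to `model_to_device` (and the matching assert in `reconciliate_processes` below) is the standard mypy idiom for narrowing an `Optional` attribute: the model is only attached during setup, so the declared type admits `None` and every use must rule it out first. A self-contained sketch of what the checker sees, with a stub `Module` standing in for the real class:

from typing import Optional

class Module:  # stub, illustrative only
    def to(self, device: str) -> "Module":
        return self

class Strategy:
    def __init__(self) -> None:
        # Not populated until setup runs, hence Optional.
        self.model: Optional[Module] = None

    def model_to_device(self) -> None:
        # Without the assert, mypy reports:
        #   Item "None" of "Optional[Module]" has no attribute "to"  [union-attr]
        assert self.model is not None
        self.model.to("cpu")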
@@ -355,30 +357,35 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.train_step_context():
+            assert isinstance(self.model, TrainingStep)
             return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
+            assert isinstance(self.model, ValidationStep)
             if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                 # used when calling `trainer.fit`
                 return self.model(*args, **kwargs)
             else:
                 # used when calling `trainer.validate`
                 return self.model.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.predict_step(*args, **kwargs)
 
-    def post_training_step(self):
-        if not self.lightning_module.automatic_optimization:
-            self.model.require_backward_grad_sync = True
+    def post_training_step(self) -> None:
+        if isinstance(self.lightning_module, LightningModule) and not self.lightning_module.automatic_optimization:
+            assert self.model is not None
+            self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
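The two remaining suppressions in this commit are error-code-scoped rather than bare: `# type: ignore[assignment]` on the `require_backward_grad_sync` write above, and `# type: ignore[operator]` in the `teardown` hunk below. The write does not type-check against the declared model type, so only that specific error code is silenced; a bare `# type: ignore` would also hide any unrelated error mypy later finds on the same line. A tiny illustration:

x: int = 0

# Scoped: silences exactly one error code on this line.
x = "mismatched"  # type: ignore[assignment]

# A bare ignore would hide every error on the line, not just this one:
# x = "mismatched"  # type: ignore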
@@ -455,6 +462,7 @@ def reconciliate_processes(self, trace: str) -> None:
         if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
             return
 
+        assert self._pids is not None
         for pid in self._pids:
             if pid != os.getpid():
                 os.kill(pid, signal.SIGKILL)
@@ -469,7 +477,7 @@ def teardown(self) -> None:
         if (
             _TORCH_GREATER_EQUAL_1_11
             and not self.model.static_graph
-            and self.model._get_ddp_logging_data().get("can_set_static_graph")
+            and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
         ):
             rank_zero_info(
                 "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -486,6 +494,7 @@ def teardown(self) -> None:
             and pl_module._trainer.state.fn == TrainerFn.FITTING
             and self._layer_sync
         ):
+            assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
         super().teardown()
