
Commit 890156a

lijm1358 and awaelchli authored
Fix mypy errors in pytorch_lightning/strategies/ddp.py (#13885)
Co-authored-by: awaelchli <[email protected]>
1 parent 61a9f3a commit 890156a

File tree

5 files changed: 51 additions & 28 deletions


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.profilers.simple",
-    "pytorch_lightning.strategies.ddp",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.trainer.callback_hook",

src/pytorch_lightning/overrides/distributed.py

Lines changed: 0 additions & 2 deletions
@@ -45,8 +45,6 @@ def _find_tensors(
 # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
 def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
     # `prepare_for_backward` is `DistributedDataParallel` specific.
-    if not isinstance(model, DistributedDataParallel):
-        return
     if torch.is_grad_enabled() and model.require_backward_grad_sync:
         model.require_forward_param_sync = True  # type: ignore[assignment]
         # We'll return the output object verbatim since it is a freeform
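The isinstance guard deleted here is not lost: as the ddp.py and ddp_spawn.py hunks below show, the same check moves into the callers' `pre_backward`, so `prepare_for_backward` can rely on its `DistributedDataParallel` annotation. A minimal sketch of the resulting pattern, with simplified signatures and a hypothetical `MyStrategy` class standing in for the real strategies:

from typing import Any

from torch.nn.parallel import DistributedDataParallel


def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
    # The annotation guarantees the type here; no runtime re-check is needed.
    if model.require_backward_grad_sync:
        ...


class MyStrategy:  # hypothetical stand-in, not the Lightning API
    model: Any = None

    def pre_backward(self, closure_loss: Any) -> None:
        # The caller narrows the type before delegating, which also satisfies mypy.
        if not isinstance(self.model, DistributedDataParallel):
            return
        prepare_for_backward(self.model, closure_loss)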

src/pytorch_lightning/strategies/ddp.py

Lines changed: 47 additions & 22 deletions
@@ -32,13 +32,15 @@
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import (
     _get_process_group_backend_from_env,
@@ -57,7 +59,7 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -83,12 +85,12 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
-        ddp_comm_hook: Optional[callable] = None,
-        ddp_comm_wrapper: Optional[callable] = None,
+        ddp_comm_hook: Optional[Callable] = None,
+        ddp_comm_wrapper: Optional[Callable] = None,
         model_averaging_period: Optional[int] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
-        **kwargs: Union[Any, Dict[str, Any]],
+        **kwargs: Any,
     ) -> None:
         super().__init__(
             accelerator=accelerator,
@@ -105,7 +107,7 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._model_averaging_period = model_averaging_period
         self._model_averager: Optional[ModelAverager] = None
-        self._pids: Optional[List[int]] = None
+        self._pids: List[int] = []
         self._sync_dir: Optional[str] = None
         self._rank_0_will_call_children_scripts: bool = False
         self._process_group_backend: Optional[str] = process_group_backend
@@ -117,6 +119,7 @@ def is_distributed(self) -> bool:
 
     @property
     def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
@@ -129,11 +132,11 @@ def num_nodes(self, num_nodes: int) -> None:
         self._num_nodes = num_nodes
 
     @property
-    def num_processes(self):
+    def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
@@ -146,6 +149,7 @@ def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
     def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
         if not self.cluster_environment.creates_processes_externally:
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
             self._rank_0_will_call_children_scripts = True
@@ -156,10 +160,11 @@ def setup_environment(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()
 
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
         # move the model to the correct device
@@ -170,6 +175,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
@@ -193,7 +199,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
-    def setup_distributed(self):
+    def setup_distributed(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
 
@@ -204,6 +210,7 @@ def setup_distributed(self):
         rank_zero_only.rank = self.global_rank
 
         self._process_group_backend = self._get_process_group_backend()
+        assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
 
     def _get_process_group_backend(self) -> str:
@@ -230,6 +237,7 @@ def pre_configure_ddp(self) -> None:
     def _register_ddp_hooks(self) -> None:
         log.detail(f"{self.__class__.__name__}: registering ddp hooks")
         if self.root_device.type == "cuda" and self._is_single_process_single_device:
+            assert isinstance(self.model, DistributedDataParallel)
             register_ddp_comm_hook(
                 model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
@@ -262,6 +270,7 @@ def _enable_model_averaging(self) -> None:
                     f"{optimizer.__class__.__name__}."
                 )
 
+        assert self._ddp_comm_state is not None
         self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
             period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
         )
@@ -296,39 +305,46 @@ def optimizer_step(
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
-    def determine_ddp_device_ids(self):
+    def determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
             return None
         return [self.root_device.index]
 
-    def barrier(self, *args, **kwargs) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not distributed_available():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
             torch.distributed.barrier()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = [obj]
         if self.global_rank != src:
-            obj = [None]
+            obj = [None]  # type: ignore[list-item]
         torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
+        if not isinstance(self.model, DistributedDataParallel):
+            return
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -344,30 +360,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.train_step_context():
             return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
+            assert self.lightning_module is not None
+            assert self.model is not None
             if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                 # used when calling `trainer.fit`
                 return self.model(*args, **kwargs)
             else:
                 # used when calling `trainer.validate`
+                assert isinstance(self.model, ValidationStep)
                 return self.model.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
+            assert isinstance(self.model, PredictStep)
             return self.model.predict_step(*args, **kwargs)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
-            self.model.require_backward_grad_sync = True
+            assert self.model is not None
+            self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -458,7 +482,7 @@ def teardown(self) -> None:
             if (
                 _TORCH_GREATER_EQUAL_1_11
                 and not self.model.static_graph
-                and self.model._get_ddp_logging_data().get("can_set_static_graph")
+                and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
             ):
                 rank_zero_info(
                     "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -475,6 +499,7 @@ def teardown(self) -> None:
             and pl_module._trainer.state.fn == TrainerFn.FITTING
             and self._layer_sync
         ):
+            assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
         super().teardown()
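Most of the additions in this file follow one pattern: attributes such as `self.model`, `self.accelerator`, and `self.cluster_environment` are declared `Optional`, so mypy rejects attribute access until the value is narrowed. Adding `assert x is not None` (or `assert isinstance(x, ...)`) immediately before use narrows the type for the checker while being a no-op in the expected case. A small self-contained sketch of the idea, using illustrative names rather than the actual Lightning classes:

from typing import Optional

import torch


class ToyStrategy:  # illustrative only, not pytorch_lightning code
    def __init__(self) -> None:
        # Populated later during setup, hence Optional at construction time.
        self.model: Optional[torch.nn.Module] = None

    def model_to_device(self, device: torch.device) -> None:
        # Without this assert, mypy reports an error along the lines of:
        # Item "None" of "Optional[Module]" has no attribute "to"
        assert self.model is not None
        self.model.to(device)


strategy = ToyStrategy()
strategy.model = torch.nn.Linear(2, 2)
strategy.model_to_device(torch.device("cpu"))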

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 2 additions & 1 deletion
@@ -254,9 +254,10 @@ def model_to_device(self) -> None:
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
+        if not isinstance(self.model, DistributedDataParallel):
+            return
         assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
-            assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
     def reduce(

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 import platform
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, cast, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -831,7 +831,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
-            checkpoint_path = cast(_PATH, self.broadcast(checkpoint_path))
+            checkpoint_path = self.broadcast(checkpoint_path)
             return super().load_checkpoint(checkpoint_path)
 
         # Rely on deepspeed to load the checkpoint and necessary information
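The `cast` removed above becomes unnecessary once `broadcast` is annotated with the `TBroadcast` type variable (see the ddp.py hunk earlier in this commit): mypy then infers the return type from the argument type at the call site. A rough sketch of that generic signature, simplified and not the real implementation:

from typing import TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # Stand-in for the real collective: the caller gets back the same type it passed in.
    return obj


checkpoint_path = "last.ckpt"
# mypy infers `str` here, so no cast() is needed at the call site.
checkpoint_path = broadcast(checkpoint_path)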
