 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import (
     _get_process_group_backend_from_env,
@@ -57,7 +59,7 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -83,12 +85,12 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
-        ddp_comm_hook: Optional[callable] = None,
-        ddp_comm_wrapper: Optional[callable] = None,
+        ddp_comm_hook: Optional[Callable] = None,
+        ddp_comm_wrapper: Optional[Callable] = None,
         model_averaging_period: Optional[int] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
-        **kwargs: Union[Any, Dict[str, Any]],
+        **kwargs: Any,
     ) -> None:
         super().__init__(
             accelerator=accelerator,
@@ -105,7 +107,7 @@ def __init__(
         self._ddp_comm_wrapper = ddp_comm_wrapper
         self._model_averaging_period = model_averaging_period
         self._model_averager: Optional[ModelAverager] = None
-        self._pids: Optional[List[int]] = None
+        self._pids: List[int] = []
         self._sync_dir: Optional[str] = None
         self._rank_0_will_call_children_scripts: bool = False
         self._process_group_backend: Optional[str] = process_group_backend
@@ -117,6 +119,7 @@ def is_distributed(self) -> bool:
 
     @property
     def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
@@ -129,11 +132,11 @@ def num_nodes(self, num_nodes: int) -> None:
         self._num_nodes = num_nodes
 
     @property
-    def num_processes(self):
+    def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
@@ -146,6 +149,7 @@ def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
     def _configure_launcher(self) -> None:
+        assert self.cluster_environment is not None
         if not self.cluster_environment.creates_processes_externally:
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
             self._rank_0_will_call_children_scripts = True
@@ -156,10 +160,11 @@ def setup_environment(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         # share ddp pids to all processes
-        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()
 
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
         # move the model to the correct device
@@ -170,6 +175,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
@@ -193,7 +199,7 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
         return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
 
-    def setup_distributed(self):
+    def setup_distributed(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
 
@@ -204,6 +210,7 @@ def setup_distributed(self):
         rank_zero_only.rank = self.global_rank
 
         self._process_group_backend = self._get_process_group_backend()
+        assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
 
     def _get_process_group_backend(self) -> str:
@@ -230,6 +237,7 @@ def pre_configure_ddp(self) -> None:
     def _register_ddp_hooks(self) -> None:
         log.detail(f"{self.__class__.__name__}: registering ddp hooks")
         if self.root_device.type == "cuda" and self._is_single_process_single_device:
+            assert isinstance(self.model, DistributedDataParallel)
             register_ddp_comm_hook(
                 model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
@@ -262,6 +270,7 @@ def _enable_model_averaging(self) -> None:
                     f"{optimizer.__class__.__name__}."
                 )
 
+        assert self._ddp_comm_state is not None
         self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
             period=self._model_averaging_period, warmup_steps=self._ddp_comm_state.start_localSGD_iter
         )
@@ -296,39 +305,46 @@ def optimizer_step(
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
-    def determine_ddp_device_ids(self):
+    def determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
             return None
         return [self.root_device.index]
 
-    def barrier(self, *args, **kwargs) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not distributed_available():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
             torch.distributed.barrier()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = [obj]
         if self.global_rank != src:
-            obj = [None]
+            obj = [None]  # type: ignore[list-item]
         torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
+        if not isinstance(self.model, DistributedDataParallel):
+            return
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -344,30 +360,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.train_step_context():
             return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
+            assert self.lightning_module is not None
+            assert self.model is not None
             if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                 # used when calling `trainer.fit`
                 return self.model(*args, **kwargs)
             else:
                 # used when calling `trainer.validate`
+                assert isinstance(self.model, ValidationStep)
                 return self.model.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
+            assert isinstance(self.model, PredictStep)
             return self.model.predict_step(*args, **kwargs)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
-            self.model.require_backward_grad_sync = True
+            assert self.model is not None
+            self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -458,7 +482,7 @@ def teardown(self) -> None:
             if (
                 _TORCH_GREATER_EQUAL_1_11
                 and not self.model.static_graph
-                and self.model._get_ddp_logging_data().get("can_set_static_graph")
+                and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
             ):
                 rank_zero_info(
                     "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -475,6 +499,7 @@ def teardown(self) -> None:
             and pl_module._trainer.state.fn == TrainerFn.FITTING
             and self._layer_sync
         ):
+            assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
 
         super().teardown()
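
For reference, a minimal usage sketch (not part of the diff) of how the `ddp_comm_hook` and `process_group_backend` constructor arguments annotated above are typically passed in. The specific hook, backend, and trainer settings below are illustrative assumptions, not taken from this commit.

```python
# A minimal sketch, assuming pytorch_lightning with the DDPStrategy shown above
# and a 2-GPU node (illustrative values).
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# `ddp_comm_hook` is a plain callable, matching the `Optional[Callable]` annotation above.
strategy = DDPStrategy(
    ddp_comm_hook=default_hooks.fp16_compress_hook,
    process_group_backend="nccl",
)

trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)
```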