5959from pytorch_lightning .utilities .optimizer import optimizers_to_device
6060from pytorch_lightning .utilities .rank_zero import rank_zero_info , rank_zero_only , rank_zero_warn
6161from pytorch_lightning .utilities .seed import reset_seed
62- from pytorch_lightning .utilities .types import STEP_OUTPUT
62+ from pytorch_lightning .utilities .types import STEP_OUTPUT , TestStep , TrainingStep , ValidationStep
6363
6464if _FAIRSCALE_AVAILABLE :
6565 from fairscale .optim import OSS
@@ -333,13 +333,15 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
333333 def pre_backward (self , closure_loss : Tensor ) -> None :
334334 """Run before precision plugin executes backward."""
335335 if isinstance (self .lightning_module , LightningModule ) and not self .lightning_module .automatic_optimization :
336+ assert isinstance (self .model , DistributedDataParallel )
336337 prepare_for_backward (self .model , closure_loss )
337338
338- def model_to_device (self ):
339+ def model_to_device (self ) -> None :
339340 log .detail (f"{ self .__class__ .__name__ } : moving model to device [{ self .root_device } ]..." )
341+ assert self .model is not None
340342 self .model .to (self .root_device )
341343
342- def reduce (self , tensor , group : Optional [Any ] = None , reduce_op : Union [ReduceOp , str ] = "mean" ) -> Tensor :
344+ def reduce (self , tensor : Tensor , group : Optional [Any ] = None , reduce_op : Optional [ Union [ReduceOp , str ] ] = "mean" ) -> Tensor :
343345 """Reduces a tensor from several distributed processes to one aggregated tensor.
344346
345347 Args:
@@ -355,30 +357,35 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
355357 tensor = sync_ddp_if_available (tensor , group , reduce_op = reduce_op )
356358 return tensor
357359
358- def training_step (self , * args , ** kwargs ) -> STEP_OUTPUT :
360+ def training_step (self , * args : Any , ** kwargs : Any ) -> STEP_OUTPUT :
359361 with self .precision_plugin .train_step_context ():
362+ assert isinstance (self .model , TrainingStep )
360363 return self .model (* args , ** kwargs )
361364
362- def validation_step (self , * args , ** kwargs ) -> Optional [STEP_OUTPUT ]:
365+ def validation_step (self , * args : Any , ** kwargs : Any ) -> Optional [STEP_OUTPUT ]:
363366 with self .precision_plugin .val_step_context ():
367+ assert isinstance (self .model , ValidationStep )
364368 if self .lightning_module .trainer .state .fn == TrainerFn .FITTING :
365369 # used when calling `trainer.fit`
366370 return self .model (* args , ** kwargs )
367371 else :
368372 # used when calling `trainer.validate`
369373 return self .model .validation_step (* args , ** kwargs )
370374
371- def test_step (self , * args , ** kwargs ) -> Optional [STEP_OUTPUT ]:
375+ def test_step (self , * args : Any , ** kwargs : Any ) -> Optional [STEP_OUTPUT ]:
372376 with self .precision_plugin .test_step_context ():
377+ assert isinstance (self .model , TestStep )
373378 return self .model .test_step (* args , ** kwargs )
374379
375- def predict_step (self , * args , ** kwargs ) -> STEP_OUTPUT :
380+ def predict_step (self , * args : Any , ** kwargs : Any ) -> STEP_OUTPUT :
376381 with self .precision_plugin .predict_step_context ():
382+ assert isinstance (self .model , TestStep )
377383 return self .model .predict_step (* args , ** kwargs )
378384
379- def post_training_step (self ):
380- if not self .lightning_module .automatic_optimization :
381- self .model .require_backward_grad_sync = True
385+ def post_training_step (self ) -> None :
386+ if isinstance (self .lightning_module , LightningModule ) and not self .lightning_module .automatic_optimization :
387+ assert self .model is not None
388+ self .model .require_backward_grad_sync = True # type: ignore[assignment]
382389
383390 @classmethod
384391 def register_strategies (cls , strategy_registry : Dict ) -> None :
@@ -455,6 +462,7 @@ def reconciliate_processes(self, trace: str) -> None:
455462 if len (os .listdir (sync_dir )) == (self .world_size // self .num_nodes ):
456463 return
457464
465+ assert self ._pids is not None
458466 for pid in self ._pids :
459467 if pid != os .getpid ():
460468 os .kill (pid , signal .SIGKILL )
@@ -469,7 +477,7 @@ def teardown(self) -> None:
469477 if (
470478 _TORCH_GREATER_EQUAL_1_11
471479 and not self .model .static_graph
472- and self .model ._get_ddp_logging_data ().get ("can_set_static_graph" )
480+ and self .model ._get_ddp_logging_data ().get ("can_set_static_graph" ) # type: ignore[operator]
473481 ):
474482 rank_zero_info (
475483 "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -486,6 +494,7 @@ def teardown(self) -> None:
486494 and pl_module ._trainer .state .fn == TrainerFn .FITTING
487495 and self ._layer_sync
488496 ):
497+ assert self .model is not None
489498 self .model = self ._layer_sync .revert (self .model )
490499
491500 super ().teardown ()
0 commit comments