5959from pytorch_lightning .utilities .optimizer import optimizers_to_device
6060from pytorch_lightning .utilities .rank_zero import rank_zero_info , rank_zero_only , rank_zero_warn
6161from pytorch_lightning .utilities .seed import reset_seed
62- from pytorch_lightning .utilities .types import PredictStep , STEP_OUTPUT , TestStep , TrainingStep , ValidationStep
62+ from pytorch_lightning .utilities .types import PredictStep , STEP_OUTPUT , TestStep , ValidationStep
6363
6464if _FAIRSCALE_AVAILABLE :
6565 from fairscale .optim import OSS
@@ -333,7 +333,7 @@ def pre_backward(self, closure_loss: Tensor) -> None:
333333 """Run before precision plugin executes backward."""
334334 assert self .lightning_module is not None
335335 if not self .lightning_module .automatic_optimization :
336- assert isinstance ( self .model , DistributedDataParallel )
336+ assert self .model is not None
337337 prepare_for_backward (self .model , closure_loss )
338338
339339 def model_to_device (self ) -> None :
@@ -360,18 +360,20 @@ def reduce(
360360 return tensor
361361
362362 def training_step (self , * args : Any , ** kwargs : Any ) -> STEP_OUTPUT :
363+ assert self .model is not None
363364 with self .precision_plugin .train_step_context ():
364- assert isinstance (self .model , TrainingStep )
365365 return self .model (* args , ** kwargs )
366366
367367 def validation_step (self , * args : Any , ** kwargs : Any ) -> Optional [STEP_OUTPUT ]:
368368 with self .precision_plugin .val_step_context ():
369- assert isinstance (self .model , ValidationStep )
369+ assert self .lightning_module is not None
370+ assert self .model is not None
370371 if self .lightning_module .trainer .state .fn == TrainerFn .FITTING :
371372 # used when calling `trainer.fit`
372373 return self .model (* args , ** kwargs )
373374 else :
374375 # used when calling `trainer.validate`
376+ assert isinstance (self .model , ValidationStep )
375377 return self .model .validation_step (* args , ** kwargs )
376378
377379 def test_step (self , * args : Any , ** kwargs : Any ) -> Optional [STEP_OUTPUT ]:
0 commit comments