5959from pytorch_lightning .utilities import DeviceType , rank_zero_warn
6060from pytorch_lightning .utilities .cloud_io import load as pl_load
6161from pytorch_lightning .utilities .debugging import InternalDebugger
62- from pytorch_lightning .utilities .enums import LightningEnum
6362from pytorch_lightning .utilities .exceptions import MisconfigurationException
6463from pytorch_lightning .utilities .memory import recursive_detach
6564from pytorch_lightning .utilities .model_helpers import is_overridden
@@ -450,7 +449,7 @@ def fit(
450449 # bookkeeping
451450 # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
452451 if self ._running_stage is None :
453- self ._set_running_stage ( RunningStage .TRAINING , model )
452+ self ._running_stage = RunningStage .TRAINING
454453
455454 # set local properties on the model
456455 self .model_connector .copy_trainer_model_properties (model )
@@ -531,7 +530,7 @@ def fit(
531530 if self ._state != TrainerState .INTERRUPTED :
532531 self ._state = TrainerState .FINISHED
533532
534- self ._set_running_stage ( None , model )
533+ self ._running_stage = None
535534
536535 return self .accelerator .results or 1
537536
@@ -564,14 +563,6 @@ def train_or_test_or_predict(self):
564563
565564 return results
566565
567- def _set_running_stage (self , stage : LightningEnum , model_ref : LightningModule ):
568- """
569- This function is used to set the running_state on both
570- the trainer and the model
571- """
572- model_ref .running_stage = stage
573- self ._running_stage = stage
574-
575566 def _pre_training_routine (self ):
576567 # wait for all to join if on distributed
577568 self .accelerator .barrier ("setup_training" )
@@ -614,7 +605,7 @@ def run_train(self):
614605 self .run_sanity_check (self .lightning_module )
615606
616607 # set stage for logging
617- self ._set_running_stage ( RunningStage .TRAINING , self . lightning_module )
608+ self ._running_stage = RunningStage .TRAINING
618609
619610 self .checkpoint_connector .has_trained = False
620611
@@ -678,9 +669,7 @@ def run_train(self):
678669 def run_evaluation (self , max_batches = None , on_epoch = False ):
679670
680671 # used to know if we are logging for val, test + reset cached results
681- self ._set_running_stage (
682- RunningStage .TESTING if self .testing else RunningStage .EVALUATING , self .lightning_module
683- )
672+ self ._running_stage = RunningStage .TESTING if self .testing else RunningStage .EVALUATING
684673 self .logger_connector .reset ()
685674
686675 # bookkeeping
@@ -907,7 +896,7 @@ def test(
907896 # --------------------
908897 self .verbose_test = verbose
909898
910- self ._set_running_stage ( RunningStage . TESTING , model or self . lightning_module )
899+ self ._running_stage = RunningStage . TESTING
911900
912901 # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
913902 if test_dataloaders and datamodule :
@@ -924,7 +913,7 @@ def test(
924913 results = self .__test_using_best_weights (ckpt_path , test_dataloaders )
925914
926915 self .teardown ('test' )
927- self ._set_running_stage ( None , model or self . lightning_module )
916+ self ._running_stage = None
928917 return results
929918
930919 def __test_using_best_weights (self , ckpt_path , test_dataloaders ):
@@ -1016,7 +1005,7 @@ def predict(
10161005
10171006 model = model or self .lightning_module
10181007
1019- self ._set_running_stage ( RunningStage .PREDICTING , model )
1008+ self ._running_stage = RunningStage .PREDICTING
10201009
10211010 if dataloaders and datamodule :
10221011 raise MisconfigurationException (
@@ -1033,7 +1022,7 @@ def predict(
10331022
10341023 self .model = model
10351024 results = self .fit (model )
1036- self ._set_running_stage ( None , model )
1025+ self ._running_stage = None
10371026
10381027 return results
10391028
0 commit comments