@@ -441,10 +441,6 @@ def fit(
441441 # hook
442442 self .data_connector .prepare_data (model )
443443
444- # bookkeeping
445- # we reuse fit in .test() but change its behavior using this flag
446- self .testing = os .environ .get ('PL_TESTING_MODE' , self .testing )
447-
448444 # ----------------------------
449445 # SET UP TRAINING
450446 # ----------------------------
@@ -720,33 +716,31 @@ def test(
720716 datamodule : Optional [LightningDataModule ] = None ,
721717 ):
722718 r"""
723-
724- Separates from fit to make sure you never run on your test set until you want to.
719+ Perform one evaluation epoch over the test set. It's separated from
720+ fit to make sure you never run on your test set until you want to.
725721
726722 Args:
727723 ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
728- If ``None``, use the weights from the last epoch to test. Default to ``best``.
729-
724+ If ``None``, use the current weights of the model. Defaults to ``best``.
730725 datamodule: An instance of :class:`LightningDataModule`.
731-
732- model: The model to test.
733-
734- test_dataloaders: Either a single
735- Pytorch Dataloader or a list of them, specifying validation samples.
736-
737- verbose: If True, prints the test results
726+ model: The model to evaluate.
727+ test_dataloaders: Either a single PyTorch DataLoader or a list of them,
728+ specifying test samples.
729+ verbose: If True, prints the test results.
738730
739731 Returns:
740- The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
732+ The dictionary with final test results returned by test_epoch_end.
733+ If test_epoch_end is not defined, the output is a list of the dictionaries
734+ returned by test_step.
741735 """
742736 # --------------------
743737 # SETUP HOOK
744738 # --------------------
745- self .verbose_test = verbose
739+ self .verbose_evaluate = verbose
746740
747741 self .logger_connector .set_stage ("test" )
748742
749- # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
743+ # If you supply a datamodule you can't supply test_dataloaders
750744 if test_dataloaders and datamodule :
751745 raise MisconfigurationException (
752746 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
@@ -756,15 +750,15 @@ def test(
756750 self .data_connector .attach_datamodule (model or self .get_model (), datamodule , 'test' )
757751
758752 if model is not None :
759- results = self .__test_given_model (model , test_dataloaders )
753+ results = self .__evaluate_given_model (model , test_dataloaders , 'test' )
760754 else :
761- results = self .__test_using_best_weights (ckpt_path , test_dataloaders )
755+ results = self .__evaluate_using_best_weights (ckpt_path , test_dataloaders , 'test' )
762756
763757 self .teardown ('test' )
764758
765759 return results
766760
767- def __test_using_best_weights (self , ckpt_path , test_dataloaders ):
761+ def __evaluate_using_best_weights (self , ckpt_path , test_dataloaders , stage : str ):
768762 model = self .get_model ()
769763
770764 # if user requests the best checkpoint but we don't have it, error
@@ -796,40 +790,56 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
796790 self .data_connector .attach_dataloaders (model , test_dataloaders = test_dataloaders )
797791
798792 # run tests
799- self .tested_ckpt_path = ckpt_path
800- self .testing = True
801- os .environ ['PL_TESTING_MODE' ] = '1'
793+ self .evaluating = stage
794+ self .evaluated_ckpt_path = ckpt_path
802795 self .model = model
803796 results = self .fit (model )
804- self .testing = False
805- del os .environ ['PL_TESTING_MODE' ]
797+ self .evaluating = None
806798
807799 # teardown
808800 if self .is_function_implemented ('teardown' ):
809801 model_ref = self .get_model ()
810- model_ref .teardown ('test' )
802+ model_ref .teardown (stage )
811803
812804 return results
813805
814- def __test_given_model (self , model , test_dataloaders ):
806+ def __evaluate_given_model (self , model , test_dataloaders , stage : str ):
815807
816808 # attach data
817809 if test_dataloaders is not None :
818810 self .data_connector .attach_dataloaders (model , test_dataloaders = test_dataloaders )
819811
820812 # run test
821813 # sets up testing so we short circuit to eval
822- self .testing = True
814+ self .evaluating = stage
823815 self .model = model
824816 results = self .fit (model )
825- self .testing = False
817+ self .evaluating = None
826818
827819 # teardown
828820 if self .is_function_implemented ('teardown' ):
829- model .teardown ('test' )
821+ model .teardown (stage )
830822
831823 return results
832824
825+ @property
826+ def testing (self ):
827+ warnings .warn (
828+ 'Trainer.testing has been deprecated in v1.1 and will be removed '
829+ 'in v1.3, use Trainer.evaluating instead.' ,
830+ DeprecationWarning , stacklevel = 2
831+ )
832+ return bool (self .evaluating )
833+
834+ @property
835+ def tested_ckpt_path (self ):
836+ warnings .warn (
837+ 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path '
838+ 'in v1.1 and will be removed in v1.3.' ,
839+ DeprecationWarning , stacklevel = 2
840+ )
841+ return self .evaluated_ckpt_path
842+
833843 def tune (
834844 self ,
835845 model : LightningModule ,
@@ -856,11 +866,17 @@ def tune(
856866
857867 def call_setup_hook (self , model ):
858868 # call setup after the ddp process has connected
859- stage_name = 'test' if self .testing else 'fit'
869+ stage_name = self .evaluating or 'fit'
870+
860871 if self .datamodule is not None :
861- called = self .datamodule .has_setup_test if self .testing else self .datamodule .has_setup_fit
872+ called = {
873+ None : self .datamodule .has_setup_fit ,
874+ 'test' : self .datamodule .has_setup_test ,
875+ }[self .evaluating ]
876+
862877 if not called :
863878 self .datamodule .setup (stage_name )
879+
864880 self .setup (model , stage_name )
865881 model .setup (stage_name )
866882
0 commit comments