@@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model):
820820
821821 self ._running_stage = stage
822822
823+ def validate (
824+ self ,
825+ model : Optional [LightningModule ] = None ,
826+ val_dataloaders : Optional [Union [DataLoader , List [DataLoader ]]] = None ,
827+ ckpt_path : Optional [str ] = 'best' ,
828+ verbose : bool = True ,
829+ datamodule : Optional [LightningDataModule ] = None ,
830+ ):
831+ r"""
832+ Perform one evaluation epoch over the validation set.
833+
834+ Args:
835+ model: The model to validate.
836+
837+ val_dataloaders: Either a single PyTorch DataLoader or a list of them,
838+ specifying validation samples.
839+
840+ ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
841+ If ``None``, use the current weights of the model.
842+ When the model is given as argument, this parameter will not apply.
843+
844+ verbose: If True, prints the validation results.
845+
846+ datamodule: A instance of :class:`LightningDataModule`.
847+
848+ Returns:
849+ The dictionary with final validation results returned by validation_epoch_end.
850+ If validation_epoch_end is not defined, the output is a list of the dictionaries
851+ returned by validation_step.
852+ """
853+ # --------------------
854+ # SETUP HOOK
855+ # --------------------
856+ self .verbose_evaluate = verbose
857+
858+ self .state = TrainerState .VALIDATING
859+ self .validating = True
860+
861+ # If you supply a datamodule you can't supply val_dataloaders
862+ if val_dataloaders and datamodule :
863+ raise MisconfigurationException (
864+ 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
865+ )
866+
867+ model_provided = model is not None
868+ model = model or self .lightning_module
869+
870+ # Attach datamodule to get setup/prepare_data added to model before the call to it below
871+ self .data_connector .attach_datamodule (model , datamodule )
872+ # Attach dataloaders (if given)
873+ self .data_connector .attach_dataloaders (model , val_dataloaders = val_dataloaders )
874+
875+ if not model_provided :
876+ self .validated_ckpt_path = self .__load_ckpt_weights (model , ckpt_path = ckpt_path )
877+
878+ # run validate
879+ results = self .fit (model )
880+
881+ assert self .state .stopped
882+ self .validating = False
883+
884+ return results
885+
823886 def test (
824887 self ,
825888 model : Optional [LightningModule ] = None ,
@@ -833,17 +896,19 @@ def test(
833896 fit to make sure you never run on your test set until you want to.
834897
835898 Args:
836- ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
837- If ``None``, use the current weights of the model. Default to ``best``.
838- datamodule: A instance of :class:`LightningDataModule`.
839-
840899 model: The model to test.
841900
842901 test_dataloaders: Either a single PyTorch DataLoader or a list of them,
843902 specifying test samples.
844903
904+ ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
905+ If ``None``, use the current weights of the model.
906+ When the model is given as argument, this parameter will not apply.
907+
845908 verbose: If True, prints the test results.
846909
910+ datamodule: A instance of :class:`LightningDataModule`.
911+
847912 Returns:
848913 Returns a list of dictionaries, one for each test dataloader containing their respective metrics.
849914 """
@@ -858,30 +923,33 @@ def test(
858923 # If you supply a datamodule you can't supply test_dataloaders
859924 if test_dataloaders and datamodule :
860925 raise MisconfigurationException (
861- 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
926+ 'You cannot pass both ` trainer.test(test_dataloaders=..., datamodule=...)` '
862927 )
863928
864929 model_provided = model is not None
865930 model = model or self .lightning_module
866931
867932 # Attach datamodule to get setup/prepare_data added to model before the call to it below
868933 self .data_connector .attach_datamodule (model , datamodule )
869- results = (
870- self .__evaluate_given_model (model , dataloaders = test_dataloaders ) if model_provided else
871- self .__evaluate_using_weights (model , ckpt_path = ckpt_path , dataloaders = test_dataloaders )
872- )
934+ # Attach dataloaders (if given)
935+ self .data_connector .attach_dataloaders (model , test_dataloaders = test_dataloaders )
936+
937+ if not model_provided :
938+ self .tested_ckpt_path = self .__load_ckpt_weights (model , ckpt_path = ckpt_path )
939+
940+ # run test
941+ results = self .fit (model )
873942
874943 assert self .state .stopped
875944 self .testing = False
876945
877946 return results
878947
879- def __evaluate_using_weights (
948+ def __load_ckpt_weights (
880949 self ,
881950 model ,
882951 ckpt_path : Optional [str ] = None ,
883- dataloaders : Optional [Union [DataLoader , List [DataLoader ]]] = None
884- ):
952+ ) -> Optional [str ]:
885953 # if user requests the best checkpoint but we don't have it, error
886954 if ckpt_path == 'best' and not self .checkpoint_callback .best_model_path :
887955 raise MisconfigurationException (
@@ -894,42 +962,18 @@ def __evaluate_using_weights(
894962 if ckpt_path == 'best' :
895963 ckpt_path = self .checkpoint_callback .best_model_path
896964
897- if len (ckpt_path ) == 0 :
898- rank_zero_warn (
899- f'`.test()` found no path for the best weights, { ckpt_path } . Please'
900- ' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
965+ if not ckpt_path :
966+ fn = self .state .value
967+ raise MisconfigurationException (
968+ f'`.{ fn } ()` found no path for the best weights: "{ ckpt_path } ". Please'
969+ ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
901970 )
902- return {}
903971
904972 self .training_type_plugin .barrier ()
905973
906974 ckpt = pl_load (ckpt_path , map_location = lambda storage , loc : storage )
907975 model .load_state_dict (ckpt ['state_dict' ])
908-
909- # attach dataloaders
910- if dataloaders is not None :
911- self .data_connector .attach_dataloaders (model , test_dataloaders = dataloaders )
912-
913- if self .validating :
914- self .validated_ckpt_path = ckpt_path
915- else :
916- self .tested_ckpt_path = ckpt_path
917-
918- # run test
919- results = self .fit (model )
920-
921- return results
922-
923- def __evaluate_given_model (self , model , dataloaders : Optional [Union [DataLoader , List [DataLoader ]]] = None ):
924- # attach data
925- if dataloaders is not None :
926- self .data_connector .attach_dataloaders (model , test_dataloaders = dataloaders )
927-
928- # run test
929- # sets up testing so we short circuit to eval
930- results = self .fit (model )
931-
932- return results
976+ return ckpt_path
933977
934978 def predict (
935979 self ,
@@ -970,15 +1014,11 @@ def predict(
9701014 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
9711015 )
9721016
973- if datamodule is not None :
974- # Attach datamodule to get setup/prepare_data added to model before the call to it below
975- self .data_connector .attach_datamodule (model , datamodule )
976-
977- # attach data
978- if dataloaders is not None :
979- self .data_connector .attach_dataloaders (model , predict_dataloaders = dataloaders )
1017+ # Attach datamodule to get setup/prepare_data added to model before the call to it below
1018+ self .data_connector .attach_datamodule (model , datamodule )
1019+ # Attach dataloaders (if given)
1020+ self .data_connector .attach_dataloaders (model , predict_dataloaders = dataloaders )
9801021
981- self .model = model
9821022 results = self .fit (model )
9831023
9841024 assert self .state .stopped
0 commit comments