
Commit f4cc745

EliaCereda authored
Add Trainer.validate(…) method to run one validation epoch (#4948)
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent d1db604 commit f4cc745

20 files changed, +446 -483 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
 
 
+- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))
+
+
 - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
 
 

docs/source/common/trainer.rst

Lines changed: 14 additions & 1 deletion
@@ -151,14 +151,27 @@ So you can run it like so:
 
 ------------
 
+Validation
+----------
+You can perform an evaluation epoch over the validation set, outside of the training loop,
+using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be
+useful if you want to collect new metrics from a model right at its initialization
+or after it has already been trained.
+
+.. code-block:: python
+
+    trainer.validate(val_dataloaders=val_dataloaders)
+
+------------
+
 Testing
 -------
 Once you're done training, feel free to run the test set!
 (Only right before publishing your paper or pushing to production)
 
 .. code-block:: python
 
-    trainer.test(test_dataloaders=test_dataloader)
+    trainer.test(test_dataloaders=test_dataloaders)
 
 ------------
 
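
To make the new snippet concrete, here is a minimal end-to-end sketch; the `LitModel` module, the random tensors, and the loader are illustrative stand-ins, not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative module; any LightningModule with `validation_step` works."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', torch.nn.functional.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
val_loader = DataLoader(dataset, batch_size=8)

# Collect metrics from a freshly initialized model, without ever calling fit().
trainer = pl.Trainer()
results = trainer.validate(LitModel(), val_dataloaders=val_loader)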

pytorch_lightning/callbacks/progress.py

Lines changed: 8 additions & 3 deletions
@@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm:
 
     def init_validation_tqdm(self) -> tqdm:
         """ Override this to customize the tqdm bar for validation. """
+        # The main progress bar doesn't exist in `trainer.validate()`
+        has_main_bar = self.main_progress_bar is not None
         bar = tqdm(
             desc='Validating',
-            position=(2 * self.process_position + 1),
+            position=(2 * self.process_position + has_main_bar),
             disable=self.is_disabled,
             leave=False,
             dynamic_ncols=True,
@@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
 
     def on_validation_end(self, trainer, pl_module):
         super().on_validation_end(trainer, pl_module)
-        self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
+        if self.main_progress_bar is not None:
+            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
         self.val_progress_bar.close()
 
     def on_train_end(self, trainer, pl_module):
@@ -479,8 +482,10 @@ def print(
     def _should_update(self, current, total):
         return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
 
-    def _update_bar(self, bar):
+    def _update_bar(self, bar: Optional[tqdm]) -> None:
         """ Updates the bar by the refresh rate without overshooting. """
+        if bar is None:
+            return
         if bar.total is not None:
             delta = min(self.refresh_rate, bar.total - bar.n)
         else:
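
A note on the `position` arithmetic above: it relies on `bool` being an `int` subclass in Python, so `has_main_bar` contributes 1 to the offset when the main bar exists (during `fit`) and 0 when it does not (during `trainer.validate()`). A standalone illustration:

# bool is an int subclass, so True/False behave as 1/0 in arithmetic.
process_position = 0
for has_main_bar in (True, False):
    position = 2 * process_position + has_main_bar
    print(position)  # 1 when a main bar exists (fit), 0 when it doesn't (validate)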

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 13 additions & 8 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -22,18 +23,24 @@ class ConfigValidator(object):
     def __init__(self, trainer):
         self.trainer = trainer
 
-    def verify_loop_configurations(self, model: LightningModule):
+    def verify_loop_configurations(self, model: LightningModule) -> None:
         r"""
         Checks that the model is configured correctly before the run is started.
 
         Args:
             model: The model to check the configuration.
 
         """
-        if self.trainer.training:
+        if self.trainer.state == TrainerState.FITTING:
             self.__verify_train_loop_configuration(model)
-        elif self.trainer.evaluating:
-            self.__verify_eval_loop_configuration(model)
+            self.__verify_eval_loop_configuration(model, 'val')
+        elif self.trainer.state == TrainerState.TUNING:
+            self.__verify_train_loop_configuration(model)
+        elif self.trainer.state == TrainerState.VALIDATING:
+            self.__verify_eval_loop_configuration(model, 'val')
+        elif self.trainer.state == TrainerState.TESTING:
+            self.__verify_eval_loop_configuration(model, 'test')
+        # TODO: add predict
 
     def __verify_train_loop_configuration(self, model):
         # -----------------------------------
@@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model):
                 ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
             )
 
-    def __verify_eval_loop_configuration(self, model):
-        stage = "val" if self.trainer.validating else "test"
-
+    def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
         loader_name = f'{stage}_dataloader'
-        step_name = f'{stage}_step'
+        step_name = 'validation_step' if stage == 'val' else 'test_step'
 
         has_loader = is_overridden(loader_name, model)
         has_step = is_overridden(step_name, model)
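
The explicit mapping in the new `__verify_eval_loop_configuration` is needed because the hook names are not uniform: the `val` prefix yields `val_dataloader`, but the matching step hook is `validation_step`, so the old f'{stage}_step' pattern cannot serve both stages. A standalone sketch of the mapping (the helper name is illustrative):

def eval_hook_names(stage: str):
    # 'val' -> ('val_dataloader', 'validation_step'); 'test' -> ('test_dataloader', 'test_step')
    loader_name = f'{stage}_dataloader'
    step_name = 'validation_step' if stage == 'val' else 'test_step'
    return loader_name, step_name


assert eval_hook_names('val') == ('val_dataloader', 'validation_step')
assert eval_hook_names('test') == ('test_dataloader', 'test_step')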

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 5 additions & 5 deletions
@@ -100,10 +100,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
     def attach_dataloaders(
         self,
         model,
-        train_dataloader=None,
-        val_dataloaders=None,
-        test_dataloaders=None,
-        predict_dataloaders=None,
+        train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
     ):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
@@ -119,7 +119,7 @@ def attach_dataloaders(
         if predict_dataloaders is not None:
            model.predict_dataloader = _PatchDataLoader(predict_dataloaders)
 
-    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:
+    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
         # We use datamodule if it's been provided, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
 
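
The widened annotations document that each argument accepts either a single DataLoader or a list of them. A sketch of the normalization a caller of such an API might perform (this helper is illustrative, not the library's code):

from typing import List, Optional, Union

from torch.utils.data import DataLoader


def to_dataloader_list(loaders: Optional[Union[DataLoader, List[DataLoader]]]) -> List[DataLoader]:
    # None -> []; a single DataLoader -> [loader]; a list passes through unchanged.
    if loaders is None:
        return []
    return loaders if isinstance(loaders, list) else [loaders]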

pytorch_lightning/trainer/states.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ class RunningStage(LightningEnum):
     """
     TRAINING = 'train'
     SANITY_CHECKING = 'sanity_check'
-    VALIDATING = 'validation'
+    VALIDATING = 'validate'
     TESTING = 'test'
     PREDICTING = 'predict'
     TUNING = 'tune'

pytorch_lightning/trainer/trainer.py

Lines changed: 90 additions & 50 deletions
@@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model):
 
         self._running_stage = stage
 
+    def validate(
+        self,
+        model: Optional[LightningModule] = None,
+        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        ckpt_path: Optional[str] = 'best',
+        verbose: bool = True,
+        datamodule: Optional[LightningDataModule] = None,
+    ):
+        r"""
+        Perform one evaluation epoch over the validation set.
+
+        Args:
+            model: The model to validate.
+
+            val_dataloaders: Either a single PyTorch DataLoader or a list of them,
+                specifying validation samples.
+
+            ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
+                If ``None``, use the current weights of the model.
+                When the model is given as argument, this parameter will not apply.
+
+            verbose: If True, prints the validation results.
+
+            datamodule: A instance of :class:`LightningDataModule`.
+
+        Returns:
+            The dictionary with final validation results returned by validation_epoch_end.
+            If validation_epoch_end is not defined, the output is a list of the dictionaries
+            returned by validation_step.
+        """
+        # --------------------
+        # SETUP HOOK
+        # --------------------
+        self.verbose_evaluate = verbose
+
+        self.state = TrainerState.VALIDATING
+        self.validating = True
+
+        # If you supply a datamodule you can't supply val_dataloaders
+        if val_dataloaders and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
+            )
+
+        model_provided = model is not None
+        model = model or self.lightning_module
+
+        # Attach datamodule to get setup/prepare_data added to model before the call to it below
+        self.data_connector.attach_datamodule(model, datamodule)
+        # Attach dataloaders (if given)
+        self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)
+
+        if not model_provided:
+            self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
+
+        # run validate
+        results = self.fit(model)
+
+        assert self.state.stopped
+        self.validating = False
+
+        return results
+
     def test(
         self,
         model: Optional[LightningModule] = None,
@@ -833,17 +896,19 @@
         fit to make sure you never run on your test set until you want to.
 
         Args:
-            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
-                If ``None``, use the current weights of the model. Default to ``best``.
-            datamodule: A instance of :class:`LightningDataModule`.
-
             model: The model to test.
 
             test_dataloaders: Either a single PyTorch DataLoader or a list of them,
                 specifying test samples.
 
+            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
+                If ``None``, use the current weights of the model.
+                When the model is given as argument, this parameter will not apply.
+
             verbose: If True, prints the test results.
 
+            datamodule: A instance of :class:`LightningDataModule`.
+
         Returns:
             Returns a list of dictionaries, one for each test dataloader containing their respective metrics.
         """
@@ -858,30 +923,33 @@
         # If you supply a datamodule you can't supply test_dataloaders
         if test_dataloaders and datamodule:
             raise MisconfigurationException(
-                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
+                'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
             )
 
         model_provided = model is not None
         model = model or self.lightning_module
 
         # Attach datamodule to get setup/prepare_data added to model before the call to it below
         self.data_connector.attach_datamodule(model, datamodule)
-        results = (
-            self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
-            self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
-        )
+        # Attach dataloaders (if given)
+        self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
+
+        if not model_provided:
+            self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
+
+        # run test
+        results = self.fit(model)
 
         assert self.state.stopped
         self.testing = False
 
         return results
 
-    def __evaluate_using_weights(
+    def __load_ckpt_weights(
         self,
         model,
         ckpt_path: Optional[str] = None,
-        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
-    ):
+    ) -> Optional[str]:
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
             raise MisconfigurationException(
@@ -894,42 +962,18 @@ def __evaluate_using_weights(
         if ckpt_path == 'best':
             ckpt_path = self.checkpoint_callback.best_model_path
 
-        if len(ckpt_path) == 0:
-            rank_zero_warn(
-                f'`.test()` found no path for the best weights, {ckpt_path}. Please'
-                ' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
+        if not ckpt_path:
+            fn = self.state.value
+            raise MisconfigurationException(
+                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
+                ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
             )
-            return {}
 
         self.training_type_plugin.barrier()
 
         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
-
-        # attach dataloaders
-        if dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
-
-        if self.validating:
-            self.validated_ckpt_path = ckpt_path
-        else:
-            self.tested_ckpt_path = ckpt_path
-
-        # run test
-        results = self.fit(model)
-
-        return results
-
-    def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None):
-        # attach data
-        if dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
-
-        # run test
-        # sets up testing so we short circuit to eval
-        results = self.fit(model)
-
-        return results
+        return ckpt_path
 
     def predict(
         self,
@@ -970,15 +1014,11 @@
                 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
             )
 
-        if datamodule is not None:
-            # Attach datamodule to get setup/prepare_data added to model before the call to it below
-            self.data_connector.attach_datamodule(model, datamodule)
-
-        # attach data
-        if dataloaders is not None:
-            self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
+        # Attach datamodule to get setup/prepare_data added to model before the call to it below
+        self.data_connector.attach_datamodule(model, datamodule)
+        # Attach dataloaders (if given)
+        self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
 
-        self.model = model
         results = self.fit(model)
 
         assert self.state.stopped
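
To make the checkpoint semantics concrete, here is a hypothetical workflow built on the new API; `LitModel` and `val_loader` are the illustrative objects from the earlier sketch:

import pytorch_lightning as pl

model = LitModel()
train_loader = val_loader  # stand-in for a real training loader
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

# No model passed: weights are restored via __load_ckpt_weights (ckpt_path
# defaults to 'best') and the path is recorded in trainer.validated_ckpt_path.
trainer.validate(val_dataloaders=val_loader)

# Model passed explicitly: its current weights are used; ckpt_path does not apply.
trainer.validate(model, val_dataloaders=val_loader)

# A specific checkpoint file can be given instead of 'best'.
trainer.validate(ckpt_path='path/to/some.ckpt', val_dataloaders=val_loader)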
