From 62bd29ebbf908c90d0e14c927aeb8294de0fd9d0 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 12:18:46 +0100 Subject: [PATCH 01/35] =?UTF-8?q?Add=20Trainer.validate(=E2=80=A6)=20to=20?= =?UTF-8?q?run=20one=20validation=20epoch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Trainer.validate` follows the same semantics as `Trainer.test` and shares part of the implementation --- pytorch_lightning/accelerators/accelerator.py | 1 + .../trainer/configuration_validator.py | 6 +- pytorch_lightning/trainer/evaluation_loop.py | 1 + pytorch_lightning/trainer/trainer.py | 106 +++++++++++++++--- 4 files changed, 97 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index e765c2ab626df..396d9936bab50 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,6 +59,7 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): return obj + # TODO: rename train_or_evaluate def train_or_test(self): if self.trainer.testing: results = self.trainer.run_test() diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 01c0119e857ec..23967dc1bc2a9 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. 
""" - if not self.trainer.testing: + if not self.trainer.evaluating: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'validation') else: - # check test loop configuration - self.__verify_eval_loop_configuration(model, 'test') + # check evaluation loop configurations + self.__verify_eval_loop_configuration(model, self.trainer.evaluating) def __verify_train_loop_configuration(self, model): # ----------------------------------- diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index e3a0f1108f1f9..a0605fd860913 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -39,6 +39,7 @@ def on_trainer_init(self): self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False self.trainer.testing = False + self.trainer.evaluating = False # when .test() is called, it sets this self.trainer.tested_ckpt_path = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 46e4abbe584ae..db44dd6719ecb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -659,10 +659,12 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs + # TODO: rename run_test_or_validate? 
def run_test(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) - eval_loop_results, _ = self.run_evaluation(test_mode=True) + test_mode = True if self.evaluating == 'test' else False + eval_loop_results, _ = self.run_evaluation(test_mode=test_mode) if len(eval_loop_results) == 0: return 1 @@ -710,6 +712,60 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_end() self.running_sanity_check = False + def validate( + self, + model: Optional[LightningModule] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + # TODO: docstring + r""" + + Separates from fit to make sure you never run on your validation set until you want to. + + Args: + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the weights from the last epoch to validate. Default to ``best``. + + datamodule: An instance of :class:`LightningDataModule`. + + model: The model to validate. + + val_dataloaders: Either a single + Pytorch Dataloader or a list of them, specifying validation samples. + + verbose: If True, prints the validation results + + Returns: + The final validation result dictionary. If no validation_epoch_end is defined returns a list of dictionaries + """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_test = verbose # TODO: rename / else? 
+ + self.logger_connector.set_stage("validation") + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass val_dataloaders to trainer.validate if you supply a datamodule' + ) + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'validation') + + if model is not None: + results = self.__evaluate_given_model(model, val_dataloaders, 'validation') + else: + results = self.__evaluate_using_best_weights(ckpt_path, val_dataloaders, 'validation') + + self.teardown('validation') + + return results + def test( self, model: Optional[LightningModule] = None, @@ -745,7 +801,7 @@ def test( self.logger_connector.set_stage("test") - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' @@ -755,15 +811,15 @@ def test( self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') if model is not None: - results = self.__test_given_model(model, test_dataloaders) + results = self.__evaluate_given_model(model, test_dataloaders, 'test') else: - results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test') self.teardown('test') return results - def __test_using_best_weights(self, ckpt_path, test_dataloaders): + def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): model = self.get_model() # if user requests the best checkpoint but we don't have it, error @@ -791,41 +847,63 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model.load_state_dict(ckpt['state_dict']) # attach dataloaders 
- if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + loaders_arg = { + 'validation': 'val_dataloaders', + 'test': 'test_dataloaders' + }[stage] + + kwargs = { + loaders_arg: dataloaders + } + + self.data_connector.attach_dataloaders(model, **kwargs) # run tests + self.evaluating = stage self.tested_ckpt_path = ckpt_path self.testing = True os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) - self.testing = False del os.environ['PL_TESTING_MODE'] + self.testing = False + self.evaluating = False # teardown if self.is_function_implemented('teardown'): model_ref = self.get_model() - model_ref.teardown('test') + model_ref.teardown(stage) return results - def __test_given_model(self, model, test_dataloaders): + def __evaluate_given_model(self, model, dataloaders, stage: str): # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + loaders_arg = { + 'validation': 'val_dataloaders', + 'test': 'test_dataloaders' + }[stage] + + kwargs = { + loaders_arg: dataloaders + } + + self.data_connector.attach_dataloaders(model, **kwargs) # run test # sets up testing so we short circuit to eval - self.testing = True + self.evaluating = stage + self.testing = True # TODO: remove, keep only evaluating self.model = model results = self.fit(model) self.testing = False + self.evaluating = False # teardown if self.is_function_implemented('teardown'): - model.teardown('test') + model.teardown(stage) return results From 055e1ba7b5ee60a6ccd4cf3c5e361127a216a2af Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 12:21:33 +0100 Subject: [PATCH 02/35] Support val_progress_bar without main_progress_bar in ProgressBar --- pytorch_lightning/callbacks/progress.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git 
a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index de0c91f6983bd..2ccea5276ccac 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm: def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + + # The main progress bar doesn't exist in trainer.validate(...) + has_main_bar = 1 if self.main_progress_bar else 0 + bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -348,11 +352,18 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0: self.val_progress_bar.update(self.refresh_rate) - self.main_progress_bar.update(self.refresh_rate) + + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar: + self.main_progress_bar.update(self.refresh_rate) def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + + # The main progress bar doesn't exist in trainer.validate(...) 
+ if self.main_progress_bar: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): From 156b6698bbf2d50d52cf7e5a1f3b03200478bc9d Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 12:39:19 +0100 Subject: [PATCH 03/35] Fix PEP 8 issue --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index db44dd6719ecb..5944498d63ed5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -744,7 +744,7 @@ def validate( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose # TODO: rename / else? + self.verbose_test = verbose # TODO: rename / else? self.logger_connector.set_stage("validation") From 1429548e083a2a61a931df91400263ebe1b746fd Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 13:11:17 +0100 Subject: [PATCH 04/35] Use `main_progress_bar is not None` to test if the bar is present in ProgressBar MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It seems that tqdm doesn’t support `__bool__` on its instances, so it was raising an exception. --- pytorch_lightning/callbacks/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 2ccea5276ccac..9ee973c67e4ea 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -284,7 +284,7 @@ def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ # The main progress bar doesn't exist in trainer.validate(...) 
- has_main_bar = 1 if self.main_progress_bar else 0 + has_main_bar = 1 if self.main_progress_bar is not None else 0 bar = tqdm( desc='Validating', @@ -354,14 +354,14 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.val_progress_bar.update(self.refresh_rate) # The main progress bar doesn't exist in trainer.validate(...) - if self.main_progress_bar: + if self.main_progress_bar is not None: self.main_progress_bar.update(self.refresh_rate) def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) # The main progress bar doesn't exist in trainer.validate(...) - if self.main_progress_bar: + if self.main_progress_bar is not None: self.main_progress_bar.set_postfix(trainer.progress_bar_dict) self.val_progress_bar.close() From 50427e76e4aa37b2f6064fa2b18ada9519007c90 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 16:58:23 +0100 Subject: [PATCH 05/35] Simplify selection of dataloaders arg to be set --- pytorch_lightning/trainer/trainer.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5944498d63ed5..70334e0fdebcb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -848,15 +848,7 @@ def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): # attach dataloaders if dataloaders is not None: - loaders_arg = { - 'validation': 'val_dataloaders', - 'test': 'test_dataloaders' - }[stage] - - kwargs = { - loaders_arg: dataloaders - } - + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} self.data_connector.attach_dataloaders(model, **kwargs) # run tests @@ -881,15 +873,7 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): # attach data if dataloaders is not None: - loaders_arg = { - 'validation': 'val_dataloaders', - 'test': 'test_dataloaders' - }[stage] - - 
kwargs = { - loaders_arg: dataloaders - } - + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} self.data_connector.attach_dataloaders(model, **kwargs) # run test From d1988e03ebefb65c049f6fc527627673c3955d6d Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 18:54:51 +0100 Subject: [PATCH 06/35] =?UTF-8?q?Call=20setup(=E2=80=A6)=20with=20stage=20?= =?UTF-8?q?=E2=80=98validation=E2=80=99=20when=20running=20Trainer.validat?= =?UTF-8?q?e(=E2=80=A6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/core/datamodule.py | 10 ++++++++++ pytorch_lightning/trainer/trainer.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index fe81d641c86d6..cf224eb7ea1f8 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -155,6 +155,7 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validation = False self._has_setup_test = False @property @@ -230,6 +231,15 @@ def has_setup_fit(self): """ return self._has_setup_fit + @property + def has_setup_validation(self): + """Return bool letting you know if datamodule.setup('validation') has been called or not. + + Returns: + bool: True if datamodule.setup('validation') has been called. False by default. + """ + return self._has_setup_validation + @property def has_setup_test(self): """Return bool letting you know if datamodule.setup('test') has been called or not. 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 70334e0fdebcb..4bbf01de86381 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -917,11 +917,18 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = self.evaluating or 'fit' + if self.datamodule is not None: - called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit + called = { + False: self.datamodule.has_setup_fit, + 'validation': self.datamodule.has_setup_validation, + 'test': self.datamodule.has_setup_test, + }[self.evaluating] + if not called: self.datamodule.setup(stage_name) + self.setup(model, stage_name) model.setup(stage_name) From ae03c6b6bb955e4da343398bae7a2a3859ea88be Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 19:11:05 +0100 Subject: [PATCH 07/35] Check self.trainer.evaluating instead of self.trainer.testing in Accelerator, in view of its future deprecation --- pytorch_lightning/accelerators/accelerator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 396d9936bab50..4a22d275fd89f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -61,7 +61,7 @@ def broadcast(self, obj, src=0): # TODO: rename train_or_evaluate def train_or_test(self): - if self.trainer.testing: + if self.trainer.evaluating: results = self.trainer.run_test() else: results = self.trainer.train() @@ -161,7 +161,7 @@ def early_stopping_should_stop(self, pl_module): return self.trainer.should_stop def setup_optimizers(self, model): - if self.trainer.testing is True: + if self.trainer.evaluating: return optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model) From 
5493a5b29b7efb21a4ef454e2667cbff4cf24a09 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Tue, 17 Nov 2020 21:32:32 +0100 Subject: [PATCH 08/35] Set Trainer.evaluating to None by default Co-authored-by: Rohit Gupta --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a0605fd860913..3dfa7dc87d514 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -39,7 +39,7 @@ def on_trainer_init(self): self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False self.trainer.testing = False - self.trainer.evaluating = False + self.trainer.evaluating = None # when .test() is called, it sets this self.trainer.tested_ckpt_path = None From 860fef5fdd3883510cc6fbf4ecdb7bf89df83761 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 18 Nov 2020 09:16:32 +0100 Subject: [PATCH 09/35] Replace the remaining instances of self.evaluating = False with None --- pytorch_lightning/trainer/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4bbf01de86381..7935fea55a96e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -860,7 +860,7 @@ def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): results = self.fit(model) del os.environ['PL_TESTING_MODE'] self.testing = False - self.evaluating = False + self.evaluating = None # teardown if self.is_function_implemented('teardown'): @@ -883,7 +883,7 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): self.model = model results = self.fit(model) self.testing = False - self.evaluating = False + self.evaluating = None # teardown if self.is_function_implemented('teardown'): @@ -921,7 +921,7 @@ def call_setup_hook(self, model): if 
self.datamodule is not None: called = { - False: self.datamodule.has_setup_fit, + None: self.datamodule.has_setup_fit, 'validation': self.datamodule.has_setup_validation, 'test': self.datamodule.has_setup_test, }[self.evaluating] From 99a61612df6484c32bffe82aa6e8ac9f0ee1a119 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 18 Nov 2020 09:54:11 +0100 Subject: [PATCH 10/35] =?UTF-8?q?Add=20a=20first=20batch=20of=20tests=20fo?= =?UTF-8?q?r=20Trainer.validate(=E2=80=A6)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/trainer/test_config_validator.py | 20 ++++++ tests/trainer/test_dataloaders.py | 42 ++++++++++++ tests/trainer/test_optimizers.py | 18 +++++ tests/trainer/test_states.py | 40 ++++++++++- tests/trainer/test_trainer.py | 45 ++++++++++++ tests/trainer/test_trainer_validate_loop.py | 76 +++++++++++++++++++++ 6 files changed, 240 insertions(+), 1 deletion(-) create mode 100644 tests/trainer/test_trainer_validate_loop.py diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 1ab97304f2338..d11bdf1d64e5a 100755 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -92,3 +92,23 @@ def test_test_loop_config(tmpdir): model = EvalModelTemplate(**hparams) model.test_step = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) + + +def test_validation_loop_config(tmpdir): + """" + When either validation loop or validation data are missing + """ + hparams = EvalModelTemplate.get_default_hparams() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning): + model = EvalModelTemplate(**hparams) + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning): + model = EvalModelTemplate(**hparams) + model.validation_step = None + trainer.validate(model, 
val_dataloaders=model.dataloader(train=False)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1ca34101a9141..ced93d3649060 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -170,6 +170,48 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.test(ckpt_path=ckpt_path) +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +def test_multiple_validate_dataloader(tmpdir, ckpt_path): + """Verify multiple val_dataloaders.""" + + model_template = EvalModelTemplate() + + class MultipleValDataloaderModel(EvalModelTemplate): + def val_dataloader(self): + return model_template.val_dataloader__multiple() + + def validation_step(self, batch, batch_idx, *args, **kwargs): + return model_template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + + def validation_epoch_end(self, outputs): + return model_template.validation_epoch_end__multiple_dataloaders(outputs) + + model = MultipleValDataloaderModel() + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + trainer.fit(model) + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer.validate(ckpt_path=ckpt_path) + + # verify there are 2 test loaders + assert len(trainer.val_dataloaders) == 2, \ + 'Multiple val_dataloaders not initiated properly' + + # make sure predictions are good for each test set + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction(dataloader, trainer.model) + + # run the test method + trainer.validate(ckpt_path=ckpt_path) + + def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index c28e626f2eec0..52dc14eb97994 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -335,6 +335,24 
@@ def test_init_optimizers_during_testing(tmpdir): assert len(trainer.optimizer_frequencies) == 0 +def test_init_optimizers_during_validation(tmpdir): + """ + Test that optimizers is an empty list during validation. + """ + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__multiple_schedulers + + trainer = Trainer( + default_root_dir=tmpdir, + limit_test_batches=10 + ) + trainer.validate(model, ckpt_path=None) + + assert len(trainer.lr_schedulers) == 0 + assert len(trainer.optimizers) == 0 + assert len(trainer.optimizer_frequencies) == 0 + + def test_multiple_optimizers_callbacks(tmpdir): """ Tests that multiple optimizers can be used with callbacks diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 0244f654227a2..f6e29b7187d61 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -23,7 +23,7 @@ class StateSnapshotCallback(Callback): def __init__(self, snapshot_method: str): super().__init__() - assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] + assert snapshot_method in ['on_batch_start', 'on_validation_batch_start', 'on_test_batch_start'] self.snapshot_method = snapshot_method self.trainer_state = None @@ -31,6 +31,10 @@ def on_batch_start(self, trainer, pl_module): if self.snapshot_method == 'on_batch_start': self.trainer_state = trainer.state + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if self.snapshot_method == 'on_validation_batch_start': + self.trainer_state = trainer.state + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): if self.snapshot_method == 'on_test_batch_start': self.trainer_state = trainer.state @@ -191,6 +195,40 @@ def test_finished_state_after_test(tmpdir): assert trainer.state == TrainerState.FINISHED +def test_running_state_during_validation(tmpdir): + """ Tests that state is set to RUNNING during test """ + + hparams = 
EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + snapshot_callback = StateSnapshotCallback(snapshot_method='on_validation_batch_start') + + trainer = Trainer( + callbacks=[snapshot_callback], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert snapshot_callback.trainer_state == TrainerState.RUNNING + + +def test_finished_state_after_validation(tmpdir): + """ Tests that state is FINISHED after fit """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert trainer.state == TrainerState.FINISHED + + @pytest.mark.parametrize("extra_params", [ pytest.param(dict(fast_dev_run=True), id='Fast-Run'), pytest.param(dict(max_steps=1), id='Single-Step'), diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1fbfbb7eb137a..94c26a9c23462 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -747,6 +747,47 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): assert trainer.tested_ckpt_path == ckpt_path +@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) +@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) +def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k): + hparams = EvalModelTemplate.get_default_hparams() + + model = EvalModelTemplate(**hparams) + trainer = Trainer( + max_epochs=2, + progress_bar_refresh_rate=0, + default_root_dir=tmpdir, + checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k), + ) + trainer.fit(model) + if ckpt_path == "best": + # ckpt_path is 'best', meaning we load the best weights + if save_top_k == 0: + with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + trainer.validate(ckpt_path=ckpt_path) + else: + trainer.validate(ckpt_path=ckpt_path) + assert 
trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + elif ckpt_path is None: + # ckpt_path is None, meaning we don't load any checkpoints and + # use the weights from the end of training + trainer.validate(ckpt_path=ckpt_path) + assert trainer.tested_ckpt_path is None + else: + # specific checkpoint, pick one from saved ones + if save_top_k == 0: + with pytest.raises(FileNotFoundError): + trainer.validate(ckpt_path="random.ckpt") + else: + ckpt_path = str( + list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir())[ + 0 + ].absolute() + ) + trainer.validate(ckpt_path=ckpt_path) + assert trainer.tested_ckpt_path == ckpt_path + + def test_disabled_training(tmpdir): """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`.""" @@ -1448,6 +1489,10 @@ def setup(self, model, stage): assert trainer.stage == "test" assert trainer.get_model().stage == "test" + trainer.validate(ckpt_path=None) + assert trainer.stage == "validation" + assert trainer.get_model().stage == "validation" + @pytest.mark.parametrize( "train_batches, max_steps, log_interval", diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py new file mode 100644 index 0000000000000..a2205a4b50dc2 --- /dev/null +++ b/tests/trainer/test_trainer_validate_loop.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +import pytorch_lightning as pl +import tests.base.develop_utils as tutils +from tests.base import EvalModelTemplate + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_single_gpu_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0], + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_spawn_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) From 307c89a2b358c70a6ecdc59187f4de8b5235b8a9 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 18 Nov 2020 17:56:02 +0100 Subject: [PATCH 11/35] Avoid an if/else in ProgressBar MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Co-authored-by: Carlos Mocholí --- pytorch_lightning/callbacks/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 9ee973c67e4ea..5e99b38520759 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -284,7 +284,7 @@ def init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ # The main progress bar doesn't exist in trainer.validate(...) - has_main_bar = 1 if self.main_progress_bar is not None else 0 + has_main_bar = int(self.main_progress_bar is not None) bar = tqdm( desc='Validating', From 9e59e6d893cf70ec203601df13c129665e55dd37 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 18 Nov 2020 18:21:18 +0100 Subject: [PATCH 12/35] Modify ModelCheckpoint to never save a checkpoint automatically when evaluating Without this, ModelCheckpoint might delete the very checkpoint being evaluated. Furthermore, the model will not change during evaluation anyway. 
--- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d257e1ea7cc0d..bbc295591b69a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -208,6 +208,7 @@ def save_checkpoint(self, trainer, pl_module): or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # don't save anything during sanity check + or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated or self.last_global_step_saved == global_step # already saved at the last step ): return From a844f409aea58b9e252827b24fe3337ad6808ca4 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 18 Nov 2020 18:35:10 +0100 Subject: [PATCH 13/35] Update test_config_validator.py to match the messages of expected errors and warnings --- tests/trainer/test_config_validator.py | 21 +++++++++------------ tests/trainer/test_dataloaders.py | 2 +- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index d11bdf1d64e5a..b724fc8587e24 100755 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -19,9 +19,6 @@ from tests.base import EvalModelTemplate -# TODO: add matching messages - - def test_wrong_train_setting(tmpdir): """ * Test that an error is thrown when no `train_dataloader()` is defined @@ -31,12 +28,12 @@ def test_wrong_train_setting(tmpdir): hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.train_dataloader = None 
trainer.fit(model) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.training_step = None trainer.fit(model) @@ -47,7 +44,7 @@ def test_wrong_configure_optimizers(tmpdir): tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): model = EvalModelTemplate() model.configure_optimizers = None trainer.fit(model) @@ -62,13 +59,13 @@ def test_val_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): model = EvalModelTemplate(**hparams) model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): model = EvalModelTemplate(**hparams) model.val_dataloader = None trainer.fit(model) @@ -82,13 +79,13 @@ def test_test_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): model = EvalModelTemplate(**hparams) model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): model = EvalModelTemplate(**hparams) model.test_step = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) @@ -102,13 +99,13 @@ def test_validation_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, 
max_epochs=1) # has val loop but no val data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): model = EvalModelTemplate(**hparams) model.val_dataloader = None trainer.validate(model) # has val data but no val loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): model = EvalModelTemplate(**hparams) model.validation_step = None trainer.validate(model, val_dataloaders=model.dataloader(train=False)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ced93d3649060..06da361f6c2eb 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -208,7 +208,7 @@ def validation_epoch_end(self, outputs): for dataloader in trainer.val_dataloaders: tpipes.run_prediction(dataloader, trainer.model) - # run the test method + # run the validate method trainer.validate(ckpt_path=ckpt_path) From 3f9f9279c7d8506a15a0a9c2946e3f9283fefe55 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 19 Nov 2020 11:25:31 +0100 Subject: [PATCH 14/35] =?UTF-8?q?Fix=20Trainer.validate(=E2=80=A6,=20verbo?= =?UTF-8?q?se=3DTrue)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../trainer/connectors/logger_connector/logger_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 9386d428b1f07..5e4361affacf0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -270,7 +270,7 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() # log results of test - if test_mode and self.trainer.is_global_zero and 
self.trainer.verbose_test: + if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_test: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): print(f'DATALOADER:{result_idx} TEST RESULTS') From db22f2bcbd79d1abdbb0276faf049fb8ced4f3ee Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 19 Nov 2020 11:21:34 +0100 Subject: [PATCH 15/35] Transform Trainer.testing to a read-only deprecated property, remove PL_TESTING_MODE env variable It appears that CI tests pass successfully even without PL_TESTING_MODE --- pytorch_lightning/trainer/evaluation_loop.py | 10 ++++++---- pytorch_lightning/trainer/trainer.py | 15 +++++---------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 3dfa7dc87d514..867ea0c9c601f 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +import pytorch_lightning as pl from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -21,7 +22,7 @@ class EvaluationLoop(object): - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer'): self.trainer = trainer self.testing = False self.outputs = [] @@ -38,13 +39,14 @@ def on_trainer_init(self): self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False - self.trainer.testing = False + + # .validate() sets this to 'validation' and .test() sets this to 'test' self.trainer.evaluating = None - # when .test() is called, it sets this + # .validate() and .test() set this when they load a checkpoint self.trainer.tested_ckpt_path = None - # when true, prints test results + # when true, print evaluation results in .validate() and .test() self.trainer.verbose_test = True def get_evaluation_dataloaders(self, max_batches): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7935fea55a96e..dd96db2310492 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -442,10 +442,6 @@ def fit( # hook self.data_connector.prepare_data(model) - # bookkeeping - # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -854,12 +850,8 @@ def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): # run tests self.evaluating = stage self.tested_ckpt_path = ckpt_path - self.testing = True - os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) - del os.environ['PL_TESTING_MODE'] - self.testing = False self.evaluating = None # teardown @@ -879,10 +871,8 @@ def 
__evaluate_given_model(self, model, dataloaders, stage: str): # run test # sets up testing so we short circuit to eval self.evaluating = stage - self.testing = True # TODO: remove, keep only evaluating self.model = model results = self.fit(model) - self.testing = False self.evaluating = None # teardown @@ -891,6 +881,11 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): return results + @property + def testing(self): + warnings.warn('Trainer.testing is deprecated, use Trainer.evaluating instead.', FutureWarning, stacklevel=2) + return bool(self.evaluating) + def tune( self, model: LightningModule, From f8647c549572702f7a2028194d0e62b82ac23254 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Thu, 19 Nov 2020 11:49:15 +0100 Subject: [PATCH 16/35] Update docs for Trainer.validate and Trainer.test --- docs/source/trainer.rst | 16 +++++++++++-- pytorch_lightning/trainer/trainer.py | 34 ++++++++++++++-------------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index e82b0871ef85b..83f699616f2e1 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -54,7 +54,7 @@ Under the hood -------------- Under the hood, the Lightning Trainer handles the training loop details for you, some examples include: -- Automatically eenabling/disabling grads +- Automatically enabling/disabling grads - Running the training, validation and test dataloaders - Calling the Callbacks at the appropriate times - Putting batches and computations on the correct devices @@ -148,6 +148,18 @@ So you can run it like so: ------------ +Validation +---------- +After training, you can perform a new evaluation epoch over the validation set +with :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model you already trained. + +.. 
code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -155,7 +167,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloader=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index dd96db2310492..81e58e51b78f0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -716,26 +716,25 @@ def validate( verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ): - # TODO: docstring r""" - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the validation set. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. + ckpt_path: Either ``best`` or path to the checkpoint you wish to evaluate. + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. - model: The model to test. + model: The model to evaluate. - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. - verbose: If True, prints the test results + verbose: If True, prints the evaluation results Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + The final validation results dictionary. If no validation_epoch_end is defined, returns a list of dictionaries """ # -------------------- # SETUP HOOK @@ -772,23 +771,24 @@ def test( ): r""" - Separates from fit to make sure you never run on your test set until you want to. 
+ Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. + ckpt_path: Either ``best`` or path to the checkpoint you wish to evaluate. + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. - model: The model to test. + model: The model to evaluate. - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. - verbose: If True, prints the test results + verbose: If True, prints the evaluation results Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + The final test results dictionary. 
If no test_epoch_end is defined, returns a list of dictionaries """ # -------------------- # SETUP HOOK From 99281a06d94efbdce4841b3a74c305dd0a4cc4f2 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 12:58:24 +0100 Subject: [PATCH 17/35] Remove usages of deprecated Trainer.testing --- pytorch_lightning/trainer/connectors/model_connector.py | 5 ++++- pytorch_lightning/trainer/training_loop.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index dbdceb1532288..caef10c67bf2e 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -45,7 +45,10 @@ def copy_trainer_model_properties(self, model): m.use_ddp2 = self.trainer.use_ddp2 m.use_ddp = self.trainer.use_ddp m.use_amp = self.trainer.amp_backend is not None - m.testing = self.trainer.testing + # TODO: I only find usages of m.testing in DDP, where it's used to + # discriminate test from validation, as opposed to test from fit in + # Trainer. Still need to fully determine if it's correct. 
+ m.testing = self.trainer.evaluating == 'test' m.use_single_gpu = self.trainer.use_single_gpu m.use_tpu = self.trainer.use_tpu m.tpu_local_core_rank = self.trainer.tpu_local_core_rank diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8af55f64715f2..a4b6f9767b999 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -160,7 +160,7 @@ def setup_training(self, model: LightningModule): ref_model.on_pretrain_routine_start() # print model summary - if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: + if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: From 58d1c368dc3ec61bb732caa869a8f57b394f160f Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 13:01:55 +0100 Subject: [PATCH 18/35] Rename methods and attributes to reflect their new behavior --- pytorch_lightning/accelerators/accelerator.py | 5 ++--- pytorch_lightning/accelerators/cpu_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_accelerator.py | 4 ++-- .../accelerators/ddp_cpu_spawn_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 4 ++-- pytorch_lightning/accelerators/dp_accelerator.py | 4 ++-- pytorch_lightning/accelerators/gpu_accelerator.py | 5 +++-- pytorch_lightning/accelerators/horovod_accelerator.py | 4 ++-- pytorch_lightning/accelerators/tpu_accelerator.py | 4 ++-- .../connectors/logger_connector/logger_connector.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 7 +++---- 14 files changed, 28 insertions(+), 29 deletions(-) diff --git 
a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 4a22d275fd89f..fb654fe31836f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,10 +59,9 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): return obj - # TODO: rename train_or_evaluate - def train_or_test(self): + def train_or_evaluate(self): if self.trainer.evaluating: - results = self.trainer.run_test() + results = self.trainer.run_test_or_validate() else: results = self.trainer.train() return results diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 66f9e4f0201b2..3c44fbc981ca7 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -55,8 +55,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 2da9747a9be92..9b2be72329bea 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -181,8 +181,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index f99cd1149e5ae..3ef3ddc60b0b8 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -278,8 +278,8 @@ 
def ddp_train(self, process_idx, model): self.barrier('ddp_setup') self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 221ed5769c35e..7f92c5455db08 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -143,8 +143,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index b6d813f978943..2bbff3cfb4739 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -175,8 +175,8 @@ def ddp_train(self, process_idx, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a30d266ec1b2f..5e828a178ce09 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -154,8 +154,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + 
results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 2f6c5dce97c46..62c1a1cbcec7d 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -102,8 +102,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 1a52c4037c8d3..2e8221119e429 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -60,8 +60,9 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() + return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 3d9191914566d..c80acde9acdea 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -109,8 +109,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(self.trainer.model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index cc7da4dc10781..c053c2d927cbb 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ 
-133,8 +133,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # save weights at the end of training self.__save_end_of_training_weights(model, trainer) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 5e4361affacf0..0e9e1664d1023 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -270,7 +270,7 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() # log results of test - if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_test: + if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): print(f'DATALOADER:{result_idx} TEST RESULTS') diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 867ea0c9c601f..7117b4b74cacc 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -47,7 +47,7 @@ def on_trainer_init(self): self.trainer.tested_ckpt_path = None # when true, print evaluation results in .validate() and .test() - self.trainer.verbose_test = True + self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self, max_batches): # select dataloaders diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 81e58e51b78f0..ead4dd978c055 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -655,8 +655,7 @@ def track_output_for_epoch_end(self, outputs, 
output): outputs.append(output) return outputs - # TODO: rename run_test_or_validate? - def run_test(self): + def run_test_or_validate(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) test_mode = True if self.evaluating == 'test' else False @@ -739,7 +738,7 @@ def validate( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose # TODO: rename / else? + self.verbose_evaluate = verbose self.logger_connector.set_stage("validation") @@ -793,7 +792,7 @@ def test( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose + self.verbose_evaluate = verbose self.logger_connector.set_stage("test") From 7330ad4e6c63a4f359a0fe894b55cc7a49be4543 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 15:59:16 +0100 Subject: [PATCH 19/35] =?UTF-8?q?Rename=20Trainer.tested=5Fckpt=5Fpath=20t?= =?UTF-8?q?o=20Trainer.evaluated=5Fckpt=5Fpath=20since=20it=E2=80=99s=20us?= =?UTF-8?q?ed=20by=20both=20.validate()=20and=20.test()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 7 ++++++- tests/trainer/test_trainer.py | 12 ++++++------ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 7117b4b74cacc..2a9167b06d83a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -44,7 +44,7 @@ def on_trainer_init(self): self.trainer.evaluating = None # .validate() and .test() set this when they load a checkpoint - self.trainer.tested_ckpt_path = None + self.trainer.evaluated_ckpt_path = None # when true, print evaluation results in .validate() and .test() self.trainer.verbose_evaluate = True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 
ead4dd978c055..03d210506845e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -848,7 +848,7 @@ def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): # run tests self.evaluating = stage - self.tested_ckpt_path = ckpt_path + self.evaluated_ckpt_path = ckpt_path self.model = model results = self.fit(model) self.evaluating = None @@ -885,6 +885,11 @@ def testing(self): warnings.warn('Trainer.testing is deprecated, use Trainer.evaluating instead.', FutureWarning, stacklevel=2) return bool(self.evaluating) + @property + def tested_ckpt_path(self): + warnings.warn('Trainer.tested_ckpt_path is deprecated and has been replaced by Trainer.evaluated_ckpt_path.', FutureWarning, stacklevel=2) + return self.evaluated_ckpt_path + def tune( self, model: LightningModule, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 94c26a9c23462..3647b6ecbd2f4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -726,12 +726,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): trainer.test(ckpt_path=ckpt_path) else: trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path is None + assert trainer.evaluated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -744,7 +744,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): ].absolute() ) trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + assert trainer.evaluated_ckpt_path == ckpt_path @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) @@ -767,12 +767,12 @@ def 
test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k): trainer.validate(ckpt_path=ckpt_path) else: trainer.validate(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training trainer.validate(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path is None + assert trainer.evaluated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -785,7 +785,7 @@ def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k): ].absolute() ) trainer.validate(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + assert trainer.evaluated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): From 7abc67deb24d534026803274ec7d890b0c05bc46 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 16:05:52 +0100 Subject: [PATCH 20/35] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 236784afec84e..eacfb531450cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
[#4495](https://github.com/PyTorchLightning/pytorch-lightning/pull/4495), [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ( + [#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707)) ### Changed From 14799da5283575b58202ce180d919e0587def851 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 16:12:58 +0100 Subject: [PATCH 21/35] Fix PEP 8 issues --- pytorch_lightning/trainer/trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 03d210506845e..8e15c17374c66 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -733,7 +733,8 @@ def validate( verbose: If True, prints the evaluation results Returns: - The final validation results dictionary. If no validation_epoch_end is defined, returns a list of dictionaries + The final validation results dictionary. If no validation_epoch_end + is defined, returns a list of dictionaries. """ # -------------------- # SETUP HOOK @@ -787,7 +788,8 @@ def test( verbose: If True, prints the evaluation results Returns: - The final test results dictionary. If no test_epoch_end is defined, returns a list of dictionaries + The final test results dictionary. If no test_epoch_end is defined, + returns a list of dictionaries. 
""" # -------------------- # SETUP HOOK @@ -882,12 +884,18 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): @property def testing(self): - warnings.warn('Trainer.testing is deprecated, use Trainer.evaluating instead.', FutureWarning, stacklevel=2) + warnings.warn( + 'Trainer.testing is deprecated, use Trainer.evaluating instead.', + FutureWarning, stacklevel=2 + ) return bool(self.evaluating) @property def tested_ckpt_path(self): - warnings.warn('Trainer.tested_ckpt_path is deprecated and has been replaced by Trainer.evaluated_ckpt_path.', FutureWarning, stacklevel=2) + warnings.warn( + 'Trainer.tested_ckpt_path is deprecated and has been replaced by Trainer.evaluated_ckpt_path.', + FutureWarning, stacklevel=2 + ) return self.evaluated_ckpt_path def tune( From f8f4d3b235da1a95b62e96eac90e6a301d71dfbb Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 16:25:26 +0100 Subject: [PATCH 22/35] =?UTF-8?q?Update=20documentation=20of=20.setup(stag?= =?UTF-8?q?e)=20methods=20to=20mention=20the=20new=20=E2=80=98validation?= =?UTF-8?q?=E2=80=99=20stage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/callbacks/base.py | 4 ++-- pytorch_lightning/core/hooks.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 3f6b4ffe9622a..4528409595207 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -28,11 +28,11 @@ class Callback(abc.ABC): """ def setup(self, trainer, pl_module, stage: str): - """Called when fit or test begins""" + """Called when fit, validate or test begins""" pass def teardown(self, trainer, pl_module, stage: str): - """Called when fit or test ends""" + """Called when fit, validate or test ends""" pass def on_init_start(self, trainer): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 
4272c0823bb19..38db4bdfc1a7e 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -33,12 +33,12 @@ class ModelHooks: """Hooks to be used in LightningModule.""" def setup(self, stage: str): """ - Called at the beginning of fit and test. + Called at the beginning of fit, validate and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation' or 'test' Example:: @@ -61,10 +61,10 @@ def setup(stage): def teardown(self, stage: str): """ - Called at the end of fit and test. + Called at the end of fit, validate and test. Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation' or 'test' """ def on_fit_start(self): From 1818f22f62ff8842a427678bc85b72fb97eba9fd Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 16:37:41 +0100 Subject: [PATCH 23/35] Added more tests for Trainer.validate --- tests/callbacks/test_callbacks.py | 26 +++++++++++++++-- tests/callbacks/test_progress_bar.py | 30 ++++++++++++++++++-- tests/checkpointing/test_model_checkpoint.py | 8 ++++++ 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index cf88f52436576..81a34880e9f27 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -60,7 +60,7 @@ def __init__(self): def setup(self, trainer, pl_module, stage: str): assert isinstance(trainer, Trainer) - self.setup_called = True + self.setup_called = stage def teardown(self, trainer, pl_module, step: str): assert isinstance(trainer, Trainer) @@ -245,7 +245,7 @@ def on_before_zero_grad(self, trainer, pl_module, optimizer): trainer.fit(model) - assert test_callback.setup_called + assert test_callback.setup_called == 'fit' assert test_callback.teardown_called assert test_callback.on_init_start_called assert 
test_callback.on_init_end_called @@ -278,12 +278,32 @@ def on_before_zero_grad(self, trainer, pl_module, optimizer): test_callback.teardown_called = False test_callback.setup_called = False + # validate model + test_callback = TestCallback() + trainer_options.update(callbacks=[test_callback]) + trainer = Trainer(**trainer_options) + trainer.validate(model) + + assert test_callback.setup_called == 'validation' + assert test_callback.teardown_called + assert test_callback.on_validation_start_called + assert test_callback.on_validation_end_called + assert test_callback.on_validation_batch_end_called + assert test_callback.on_validation_batch_start_called + assert not test_callback.on_test_batch_start_called + assert not test_callback.on_test_batch_end_called + assert not test_callback.on_test_start_called + assert not test_callback.on_test_end_called + assert not test_callback.on_after_backward_called + assert not test_callback.on_before_zero_grad_called + + # test model test_callback = TestCallback() trainer_options.update(callbacks=[test_callback]) trainer = Trainer(**trainer_options) trainer.test(model) - assert test_callback.setup_called + assert test_callback.setup_called == 'test' assert test_callback.teardown_called assert test_callback.on_test_batch_start_called assert test_callback.on_test_batch_end_called diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 221844244ad75..8235384eed6b2 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -80,7 +80,7 @@ def test_progress_bar_totals(tmpdir): limit_val_batches=1.0, max_epochs=1, ) - bar = trainer.progress_bar_callback + bar: ProgressBar = trainer.progress_bar_callback assert 0 == bar.total_train_batches assert 0 == bar.total_val_batches assert 0 == bar.total_test_batches @@ -109,6 +109,17 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + 
+ # check validation progress bar total + k = bar.total_val_batches + assert sum(len(loader) for loader in trainer.val_dataloaders) == k + assert bar.val_progress_bar.total == k + + # validation progress bar should have reached the end + assert bar.val_progress_bar.n == k + assert bar.val_batch_idx == k + trainer.test(model) # check test progress bar total @@ -131,7 +142,7 @@ def test_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) - progress_bar = trainer.progress_bar_callback + progress_bar: ProgressBar = trainer.progress_bar_callback assert 1 == progress_bar.total_train_batches # total val batches are known only after val dataloaders have reloaded @@ -146,6 +157,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -203,8 +221,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == 
progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 1ac81feaef6de..c909f1e931333 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -743,6 +743,9 @@ def get_model(): assert trainer.current_epoch == epochs - 1 assert_checkpoint_log_dir(0) + trainer.validate(model) + assert trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -763,6 +766,11 @@ def get_model(): ) assert_trainer_init(trainer) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + trainer.test(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches From d0cd34a4a035c1be1174edac0368946b1b2c755e Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 17:08:39 +0100 Subject: [PATCH 24/35] =?UTF-8?q?Fix=20hook=20that=20tracks=20LightningDat?= =?UTF-8?q?aModule.setup(=E2=80=98validation=E2=80=99)=20calls,=20add=20mo?= =?UTF-8?q?re=20tests=20on=20DMs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytorch_lightning/core/datamodule.py | 5 ++- tests/base/datamodules.py | 4 +- tests/core/test_datamodules.py | 56 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index cf224eb7ea1f8..f72ab422e66b5 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. 
+ # If not provided, set call status of 'fit', 'validation' and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) if stage == "fit" or stage is None: obj._has_setup_fit = True + if stage == "validation" or stage is None: + obj._has_setup_validation = True + if stage == "test" or stage is None: obj._has_setup_test = True diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index e4d0b4bff89d7..df583a193de23 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -33,7 +33,7 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: + if stage in ["fit", "validation"] or stage is None: mnist_full = TrialMNIST( root=self.data_dir, train=True, num_samples=64, download=True ) @@ -88,7 +88,7 @@ def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 - if stage == "fit" or stage is None: + if stage in ["fit", "validate"] or stage is None: self.mnist_train = MNIST( self.data_dir, train=True, normalize=(0.1307, 0.3081) ) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3e683025e8867..32f4aebe445d4 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -111,6 +111,7 @@ def test_base_datamodule_with_verbose_setup(tmpdir): dm = TrialMNISTDataModule() dm.prepare_data() dm.setup('fit') + dm.setup('validation') dm.setup('test') @@ -118,16 +119,19 @@ def test_data_hooks_called(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert 
dm.has_setup_test is False dm.setup() assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -135,21 +139,31 @@ def test_data_hooks_called_verbose(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.setup('fit') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup('validation') + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup('test') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -160,10 +174,17 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir): dm.setup(stage='fit') assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup(stage='validation') + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup(stage='test') assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -254,6 +275,21 @@ def test_dm_checkpoint_save(tmpdir): assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ +def test_validate_loop_only(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + trainer.validate(model, datamodule=dm) + + def test_test_loop_only(tmpdir): reset_seed() @@ -287,6 
+323,11 @@ def test_full_loop(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -312,6 +353,11 @@ def test_trainer_attached_to_dm(tmpdir): assert result == 1 assert dm.trainer is not None + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert dm.trainer is not None + # test result = trainer.test(datamodule=dm) result = result[0] @@ -338,6 +384,11 @@ def test_full_loop_single_gpu(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -365,6 +416,11 @@ def test_full_loop_dp(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] From 0209cfc2f20534330546d054870d8ddbc824a483 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Fri, 20 Nov 2020 17:57:15 +0100 Subject: [PATCH 25/35] Add a test for Trainer.validate on DataParallel --- tests/backends/test_dp.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/backends/test_dp.py b/tests/backends/test_dp.py index c051b442cb7a7..b697440280f80 100644 --- a/tests/backends/test_dp.py +++ b/tests/backends/test_dp.py @@ -67,7 +67,7 @@ def test_multi_gpu_model_dp(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_dp_test(tmpdir): +def test_dp_evaluate(tmpdir): tutils.set_random_master_port() import os @@ -84,6 +84,22 @@ def test_dp_test(tmpdir): ) trainer.fit(model) assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + # validate + results = 
trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + # test results = trainer.test() assert 'test_acc' in results[0] From 6a0428002b9bf15a133ae98c2415abe3c2d9aa81 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 16:19:02 +0100 Subject: [PATCH 26/35] Disable EarlyStopping in evaluation mode --- pytorch_lightning/callbacks/early_stopping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 972d16fd705a8..562aefe9b7ec9 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -138,13 +138,13 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return self._run_early_stopping_check(trainer, pl_module) def on_validation_epoch_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return if self._validate_condition_metric(trainer.logger_connector.callback_metrics): From 211535071ddc5e2f8f35dc643e7f42dfd025d7d4 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 16:29:25 +0100 Subject: [PATCH 27/35] Clean up LoggerConnector.get_evaluate_epoch_results * Remove unused test_mode parameter * Differentiate validation/test results in the printed message --- .../trainer/connectors/logger_connector/logger_connector.py | 6 +++--- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 0e9e1664d1023..240d55229ca11 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -260,7 +260,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self, test_mode): + def get_evaluate_epoch_results(self): if not self.trainer.running_sanity_check: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() @@ -269,11 +269,11 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() - # log results of test + # log results of evaluation if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') + print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS') pprint(results) print('-' * 80) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2a9167b06d83a..5a93035ff49fa 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -217,7 +217,7 @@ def evaluation_epoch_end(self): def log_epoch_metrics_on_evaluation_end(self): # get the final loop results - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): From 92acb1256469fadd59e3a09f5d3ae78ee9ca575f Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: 
Sat, 21 Nov 2020 16:34:38 +0100 Subject: [PATCH 28/35] Improve description of Trainer.validate in docs/source/trainer.rst --- docs/source/trainer.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 9890c9bab119e..b935b742340d7 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -150,9 +150,10 @@ So you can run it like so: Validation ---------- -After training, you can perform a new evaluation epoch over the validation set -with :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be -useful if you want to collect new metrics from a model you already trained. +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or from one that has already been trained. .. code-block:: python From 80901932a0590a3633f6b304146f1fd0e610c1d5 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 16:42:03 +0100 Subject: [PATCH 29/35] Clean up setup() methods in tests/base/datamodules.py Co-authored-by: Rohit Gupta --- tests/base/datamodules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index df583a193de23..94e4ba9c1efe9 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -33,7 +33,7 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): - if stage in ["fit", "validation"] or stage is None: + if stage != 'test': mnist_full = TrialMNIST( root=self.data_dir, train=True, num_samples=64, download=True ) @@ -88,7 +88,7 @@ def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 - if stage in ["fit", "validate"] or stage is None: + if stage != 
'test': self.mnist_train = MNIST( self.data_dir, train=True, normalize=(0.1307, 0.3081) ) From a0984891df8f76f8e6533bb62c9dd446f27948e6 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 17:50:36 +0100 Subject: [PATCH 30/35] Update deprecation warnings --- pytorch_lightning/trainer/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8e15c17374c66..f88f8f8639bdd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -885,16 +885,16 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): @property def testing(self): warnings.warn( - 'Trainer.testing is deprecated, use Trainer.evaluating instead.', - FutureWarning, stacklevel=2 + 'Trainer.testing has been deprecated in v1.1 and will be removed in v1.3, use Trainer.evaluating instead.', + DeprecationWarning, stacklevel=2 ) return bool(self.evaluating) @property def tested_ckpt_path(self): warnings.warn( - 'Trainer.tested_ckpt_path is deprecated and has been replaced by Trainer.evaluated_ckpt_path.', - FutureWarning, stacklevel=2 + 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path in v1.1 and will be removed in v1.3.', + DeprecationWarning, stacklevel=2 ) return self.evaluated_ckpt_path From f8ab3910a7c7f684b74d31aa87d6d77931f96441 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 17:54:39 +0100 Subject: [PATCH 31/35] Update Trainer.{validate, test} docstrings --- pytorch_lightning/trainer/trainer.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f88f8f8639bdd..2116d624ba36a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -716,25 +716,21 @@ def validate( datamodule: Optional[LightningDataModule] = None, ): r""" - Perform one 
evaluation epoch over the validation set. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to evaluate. + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. If ``None``, use the current weights of the model. Default to ``best``. - datamodule: A instance of :class:`LightningDataModule`. - model: The model to evaluate. - val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples. - - verbose: If True, prints the evaluation results + verbose: If True, prints the validation results. Returns: - The final validation results dictionary. If no validation_epoch_end - is defined, returns a list of dictionaries. + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. """ # -------------------- # SETUP HOOK @@ -770,26 +766,22 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" - Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to evaluate. + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the current weights of the model. Default to ``best``. - datamodule: A instance of :class:`LightningDataModule`. - model: The model to evaluate. - test_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying test samples. - - verbose: If True, prints the evaluation results + verbose: If True, prints the test results. Returns: - The final test results dictionary. If no test_epoch_end is defined, - returns a list of dictionaries. + The dictionary with final test results returned by test_epoch_end. + If test_epoch_end is not defined, the output is a list of the dictionaries + returned by test_step. 
""" # -------------------- # SETUP HOOK From 605e7b07535239f3054069425558798e2913e504 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Sat, 21 Nov 2020 17:56:22 +0100 Subject: [PATCH 32/35] Fix PEP 8 issue --- pytorch_lightning/trainer/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2116d624ba36a..9cc7069a77063 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -877,7 +877,8 @@ def __evaluate_given_model(self, model, dataloaders, stage: str): @property def testing(self): warnings.warn( - 'Trainer.testing has been deprecated in v1.1 and will be removed in v1.3, use Trainer.evaluating instead.', + 'Trainer.testing has been deprecated in v1.1 and will be removed ' + 'in v1.3, use Trainer.evaluating instead.', DeprecationWarning, stacklevel=2 ) return bool(self.evaluating) @@ -885,7 +886,8 @@ def testing(self): @property def tested_ckpt_path(self): warnings.warn( - 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path in v1.1 and will be removed in v1.3.', + 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path ' + 'in v1.1 and will be removed in v1.3.', DeprecationWarning, stacklevel=2 ) return self.evaluated_ckpt_path From 14a77670242f72a14edfa738f8b0e79cd07cf582 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Mon, 23 Nov 2020 10:57:25 +0100 Subject: [PATCH 33/35] Consistently use the serial comma in docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- pytorch_lightning/callbacks/base.py | 4 ++-- pytorch_lightning/core/datamodule.py | 2 +- pytorch_lightning/core/hooks.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 4528409595207..8ca0ef301c260 100644 --- 
a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -28,11 +28,11 @@ class Callback(abc.ABC): """ def setup(self, trainer, pl_module, stage: str): - """Called when fit, validate or test begins""" + """Called when fit, validate, or test begins""" pass def teardown(self, trainer, pl_module, stage: str): - """Called when fit, validate or test ends""" + """Called when fit, validate, or test ends""" pass def on_init_start(self, trainer): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index f72ab422e66b5..3ff9f4cf889d4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -76,7 +76,7 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit', 'validation' and 'test' to True. + # If not provided, set call status of 'fit', 'validation', and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 38db4bdfc1a7e..7dc6402c316b1 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -33,12 +33,12 @@ class ModelHooks: """Hooks to be used in LightningModule.""" def setup(self, stage: str): """ - Called at the beginning of fit, validate and test. + Called at the beginning of fit (training + validation), validation, and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit', 'validation' or 'test' + stage: either 'fit', 'validation', or 'test' Example:: @@ -61,10 +61,10 @@ def setup(stage): def teardown(self, stage: str): """ - Called at the end of fit, validate and test. 
+ Called at the end of fit (training + validation), validation, and test. Args: - stage: either 'fit', 'validation' or 'test' + stage: either 'fit', 'validation', or 'test' """ def on_fit_start(self): From 6f2ce28d48edc8eddc5df634bc2a7f379e5b8114 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 25 Nov 2020 12:05:06 +0100 Subject: [PATCH 34/35] Fix PEP 8 issue --- pytorch_lightning/callbacks/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 40ed147fc8bda..b00dca548671f 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -356,7 +356,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx, self.total_val_batches): self._update_bar(self.val_progress_bar) - + # The main progress bar doesn't exist in trainer.validate(...) 
if self.main_progress_bar is not None: self._update_bar(self.main_progress_bar) From d4cb1b052bed11ef223478396a2a2b4938a635b6 Mon Sep 17 00:00:00 2001 From: Elia Cereda Date: Wed, 2 Dec 2020 11:52:32 +0100 Subject: [PATCH 35/35] Rewrite assertions for Trainer.validate in test_callbacks.py using MagicMock --- tests/callbacks/test_callbacks.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index bb740b1dcbb1c..6f427afef7728 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -101,6 +101,28 @@ def test_trainer_callback_system(torch_save): call.teardown(trainer, model, 'fit'), ] + callback_mock.reset_mock() + trainer = Trainer(**trainer_options) + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validation'), + call.on_fit_start(trainer, model), + call.on_pretrain_routine_start(trainer, model), + call.on_pretrain_routine_end(trainer, model), + call.on_validation_start(trainer, model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.on_fit_end(trainer, model), + call.teardown(trainer, model, 'fit'), + call.teardown(trainer, model, 'validation'), + ] + callback_mock.reset_mock() trainer = Trainer(**trainer_options) trainer.test(model)