From bed1fe3a80e3089d7a5059b4efc7cc32a45ab277 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Sat, 27 Nov 2021 04:43:18 +0100
Subject: [PATCH 1/3] Do not sanity check on reload

---
 pytorch_lightning/trainer/trainer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b417d40484028..64abc423c838d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1307,7 +1307,13 @@ def _run_sanity_check(self, ref_model):
         using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
             "validation_step", ref_model
         )
-        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
+        should_sanity_check = (
+            using_val_step
+            and self.num_sanity_val_steps > 0
+            and self.limit_val_batches > 0
+            # do not sanity check if restarting because it would mess up the loaded state
+            and not self._evaluation_loop.restarting
+        )
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val

From 3fb05158d07cb8c9bc830c4d76ad79d964f3b7d8 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 29 Nov 2021 14:54:12 +0100
Subject: [PATCH 2/3] Update and use enable_validation

---
 pytorch_lightning/trainer/trainer.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 64abc423c838d..be3edc05fb792 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1304,13 +1304,9 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         return self.predict_loop.run()
 
     def _run_sanity_check(self, ref_model):
-        using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
-            "validation_step", ref_model
-        )
         should_sanity_check = (
-            using_val_step
+            self.enable_validation
             and self.num_sanity_val_steps > 0
-            and self.limit_val_batches > 0
             # do not sanity check if restarting because it would mess up the loaded state
             and not self._evaluation_loop.restarting
         )
@@ -1787,8 +1783,9 @@ def _should_reload_dl_epoch(self) -> bool:
     def enable_validation(self) -> bool:
         """Check if we should run validation during training."""
         model_ref = self.lightning_module
-        val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0
-        return val_loop_enabled
+        val_dataloader_defined = self._data_connector._val_dataloader_source.is_defined()
+        val_step_overridden = is_overridden("validation_step", model_ref)
+        return val_dataloader_defined and val_step_overridden and self.limit_val_batches > 0
 
     @property
     def default_root_dir(self) -> str:
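As a standalone illustration (not part of the patches), the sketch below mirrors the sanity-check gating as it stands after patch 2: the check only runs when validation is enabled, num_sanity_val_steps > 0, and the evaluation loop is not restarting from a reloaded state. The helper name should_run_sanity_check and its plain boolean/int arguments are made up for the example; they stand in for the Trainer and loop attributes.

# Hedged sketch: mirrors the condition assembled in Trainer._run_sanity_check
# after patch 2. Plain arguments stand in for Trainer/loop attributes.
def should_run_sanity_check(
    enable_validation: bool,
    num_sanity_val_steps: int,
    evaluation_loop_restarting: bool,
) -> bool:
    return (
        enable_validation
        and num_sanity_val_steps > 0
        # skip the sanity check on restart so the reloaded loop state is not disturbed
        and not evaluation_loop_restarting
    )

assert should_run_sanity_check(True, 2, False)       # fresh fit: sanity check runs
assert not should_run_sanity_check(True, 2, True)    # restarting from a checkpoint: skipped
assert not should_run_sanity_check(False, 2, False)  # validation not enabled: skipped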
From b5a5a759b954705883cf6cdb3e9fbbc0799cb3cc Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 29 Nov 2021 14:55:28 +0100
Subject: [PATCH 3/3] Refactor

---
 pytorch_lightning/trainer/trainer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index be3edc05fb792..0376a0f745f6f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1782,10 +1782,11 @@ def _should_reload_dl_epoch(self) -> bool:
     @property
     def enable_validation(self) -> bool:
         """Check if we should run validation during training."""
-        model_ref = self.lightning_module
-        val_dataloader_defined = self._data_connector._val_dataloader_source.is_defined()
-        val_step_overridden = is_overridden("validation_step", model_ref)
-        return val_dataloader_defined and val_step_overridden and self.limit_val_batches > 0
+        return (
+            self._data_connector._val_dataloader_source.is_defined()
+            and is_overridden("validation_step", self.lightning_module)
+            and self.limit_val_batches > 0
+        )
 
     @property
     def default_root_dir(self) -> str:
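For reference, a hedged sketch of what the refactored enable_validation property evaluates (not the Trainer source itself): validation runs during training only if a val dataloader is defined, validation_step is overridden, and limit_val_batches > 0. The helper name and boolean arguments below are illustrative stand-ins for those checks.

# Hedged sketch: the three conditions combined by the refactored
# Trainer.enable_validation property, with booleans standing in for
# the dataloader/override checks.
def validation_enabled(
    val_dataloader_defined: bool,
    validation_step_overridden: bool,
    limit_val_batches: float,
) -> bool:
    return val_dataloader_defined and validation_step_overridden and limit_val_batches > 0

assert validation_enabled(True, True, 1.0)
assert not validation_enabled(True, True, 0)     # e.g. Trainer(limit_val_batches=0) disables validation
assert not validation_enabled(False, True, 1.0)  # no val dataloader provided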