From b15a75f729a0885c0380b6f270eb8bcdc12f5fe5 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 25 May 2021 14:24:00 +0200
Subject: [PATCH 1/2] Remove on epoch guard from the should stop validation check

---
 .../callbacks/gpu_stats_monitor.py         |  4 +--
 pytorch_lightning/callbacks/lr_monitor.py  |  4 +--
 pytorch_lightning/trainer/training_loop.py | 20 +++---
 tests/trainer/loops/test_training_loop.py  | 32 +++++++++++++++++++
 4 files changed, 38 insertions(+), 22 deletions(-)

diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index ffd39e9af4c16..794165fe60812 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -211,6 +211,4 @@ def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]:
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index 7530bfaa9d21e..410f8b319c239 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -202,6 +202,4 @@ def _find_names(self, lr_schedulers) -> List[str]:
 
     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 1906679a2be8b..ef2245de908e1 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -529,21 +529,9 @@ def run_training_epoch(self):
 
             self.total_batch_idx += 1
 
-            # max steps reached, end training
-            if (
-                self.max_steps is not None and self.max_steps <= self.global_step + 1
-                and self._accumulated_batches_reached()
-            ):
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.trainer.should_stop:
-                break
-
-            # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached(is_last_batch):
+            max_steps_reached = self.max_steps is not None and self.max_steps <= self.global_step + 1 and self._accumulated_batches_reached(
+            )
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
                 break
 
             # progress global step according to grads progress
@@ -906,7 +894,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bo
         if on_epoch and is_last_batch and is_infinite_dataset:
             return True
 
-        if on_epoch and self.trainer.should_stop:
+        if self.trainer.should_stop:
             return True
 
         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index b89909a40fdf3..2e17f57ec9479 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -110,3 +110,35 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
         else:
             assert trainer.train_loop.batch_idx == batch_idx_
             assert trainer.global_step == batch_idx_ * max_epochs
+
+
+def test_should_stop_mid_epoch(tmpdir):
+    """Test that training correctly stops mid epoch and that validation is still called at the right time"""
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at = None
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 4:
+                self.trainer.should_stop = True
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, *args):
+            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
+            return super().validation_step(*args)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 5
+    assert model.validation_called_at == (0, 4)  # TODO(@carmocca): should be 5 - will be fixed in next PR

From f4b305c48c7d1b80d85c01d92d14ae53ad10c6d0 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 25 May 2021 14:25:37 +0200
Subject: [PATCH 2/2] Formatting

---
 pytorch_lightning/trainer/training_loop.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ef2245de908e1..62138790138ee 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -529,7 +529,9 @@ def run_training_epoch(self):
 
             self.total_batch_idx += 1
 
-            max_steps_reached = self.max_steps is not None and self.max_steps <= self.global_step + 1 and self._accumulated_batches_reached(
+            max_steps_reached = (
+                self.max_steps is not None and self.max_steps <= self.global_step + 1
+                and self._accumulated_batches_reached()
             )
             if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
                 break
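
Note on the behavioral change in this series: `_should_check_val_fx` previously honored `trainer.should_stop` only when the `on_epoch` path was taken, so a stop flag raised mid-epoch did not trigger validation until the epoch boundary; after PATCH 1/2 the flag forces the check on any batch. Below is a minimal, self-contained sketch isolating just that one condition (illustrative function names, not the actual PyTorch Lightning API; the real method also checks infinite datasets, `limit_*_batches`, and `val_check_batch`):

```python
def should_check_val_before(on_epoch: bool, should_stop: bool) -> bool:
    # Old behavior: should_stop only forced validation on the epoch-end
    # path, so a flag raised mid-epoch was ignored until the epoch finished.
    return on_epoch and should_stop


def should_check_val_after(on_epoch: bool, should_stop: bool) -> bool:
    # New behavior (this series): should_stop forces validation immediately,
    # whether or not the check runs at the epoch boundary.
    return should_stop


# A user requests a stop mid-epoch, e.g. by setting
# `self.trainer.should_stop = True` inside training_step, as the new test does:
assert should_check_val_before(on_epoch=False, should_stop=True) is False
assert should_check_val_after(on_epoch=False, should_stop=True) is True
```

This is why `test_should_stop_mid_epoch` can assert that `validation_step` ran during the stopped epoch rather than being skipped entirely.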