diff --git a/CHANGELOG.md b/CHANGELOG.md
index 30f0d9b1a6e2b..3bdfa4f3b3d80 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `on_train_batch_start` hook to end epoch early ([#3700](https://github.com/PyTorchLightning/pytorch-lightning/pull/3700))
+
 - Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))
 
 - Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))
diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst
index 3ee903b7a3179..4d5d5e692bb87 100644
--- a/docs/source/early_stopping.rst
+++ b/docs/source/early_stopping.rst
@@ -10,7 +10,7 @@ Early stopping
 Stopping an epoch early
 -----------------------
 
-You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return ``-1`` when some condition is met.
+You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.
 
 If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run.
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 428e081a5cab1..e4443b0efd174 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -517,6 +517,10 @@ def run_training_epoch(self):
             # ------------------------------------
             batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
 
+            # when on_train_batch_start returns -1, end the epoch early
+            if batch_output.signal == -1:
+                break
+
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             epoch_end_outputs = self.process_train_step_outputs(
@@ -529,9 +533,6 @@
             # TODO: add outputs to batches
             self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
 
-            # when returning -1 from train_step, we end epoch early
-            self.trainer.should_stop = batch_output.signal == -1
-
             # -----------------------------------------
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 0edd1b9350c60..3681484a8bcc8 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -108,3 +108,24 @@ def transfer_batch_to_device(self, data, device):
     expected = torch.device('cuda', 0)
     assert model.hook_called
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
+
+
+@pytest.mark.parametrize(
+    'max_epochs,batch_idx_',
+    [(2, 5), (3, 8), (4, 12)]
+)
+def test_on_train_batch_start_hook(max_epochs, batch_idx_):
+    class CurrentModel(EvalModelTemplate):
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs)
+    trainer.fit(model)
+    if batch_idx_ > len(model.train_dataloader()) - 1:
+        assert trainer.batch_idx == len(model.train_dataloader()) - 1
+        assert trainer.global_step == len(model.train_dataloader()) * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == (batch_idx_ + 1) * max_epochs
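
For context, here is a minimal sketch of the user-facing behavior this PR restores: returning `-1` from `on_train_batch_start` ends only the current epoch, and training resumes with the next one. The `ToyModel` and `RandomDataset` names, the dataset size, and the stop condition are illustrative and not part of this change.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Illustrative dataset: 64 random feature vectors."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # a dummy loss so the example runs end to end
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # hypothetical stop condition: skip the rest of the epoch
        # after the third batch; -1 ends the epoch, not the run
        if batch_idx >= 3:
            return -1


trainer = pl.Trainer(max_epochs=2)
trainer.fit(ToyModel())
```

Before this fix, the first `-1` set `trainer.should_stop` and ended the entire run rather than just the current epoch; with the `break` above, each of the two epochs runs its first three batches and training continues to `max_epochs`.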