From 6be90fda5335d27cb26358c702b344a21b4c710f Mon Sep 17 00:00:00 2001
From: ydcjeff
Date: Mon, 28 Sep 2020 16:29:19 +0630
Subject: [PATCH 1/5] init

---
 pytorch_lightning/trainer/training_loop.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3efb51953d026..035743156a70b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -517,6 +517,10 @@ def run_training_epoch(self):
             # ------------------------------------
             batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
 
+            self.trainer.should_stop = batch_output.signal == -1
+
+            if self.trainer.should_stop:
+                break
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             epoch_end_outputs = self.process_train_step_outputs(
@@ -530,7 +534,6 @@ def run_training_epoch(self):
             self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # when returning -1 from train_step, we end epoch early
-            self.trainer.should_stop = batch_output.signal == -1
 
             # -----------------------------------------
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
@@ -561,8 +564,6 @@ def run_training_epoch(self):
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
             # requested in the batches
-            if self.trainer.should_stop:
-                break
 
             self.trainer.total_batch_idx += 1

From a0702c7bfa7e47ca8f48ab941af33e2033ac5fe9 Mon Sep 17 00:00:00 2001
From: ydcjeff
Date: Mon, 28 Sep 2020 19:06:52 +0630
Subject: [PATCH 2/5] add test

---
 pytorch_lightning/trainer/training_loop.py | 10 ++++-----
 tests/models/test_hooks.py                 | 25 ++++++++++++++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 035743156a70b..1d1cb3ad93e03 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -517,10 +517,10 @@ def run_training_epoch(self):
             # ------------------------------------
             batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
 
-            self.trainer.should_stop = batch_output.signal == -1
-
-            if self.trainer.should_stop:
+            # when returning -1 from train_step, we end epoch early
+            if batch_output.signal == -1:
                 break
+
             # only track outputs when user implements training_epoch_end
             # otherwise we will build up unnecessary memory
             epoch_end_outputs = self.process_train_step_outputs(
@@ -533,8 +533,6 @@ def run_training_epoch(self):
             # TODO: add outputs to batches
             self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
 
-            # when returning -1 from train_step, we end epoch early
-
             # -----------------------------------------
             # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
             # -----------------------------------------
@@ -564,6 +562,8 @@ def run_training_epoch(self):
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
             # requested in the batches
+            if self.trainer.should_stop:
+                break
 
             self.trainer.total_batch_idx += 1

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 0edd1b9350c60..4441c2014035e 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -108,3 +108,28 @@ def transfer_batch_to_device(self, data, device):
     expected = torch.device('cuda', 0)
     assert model.hook_called
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
+
+
+@pytest.mark.parametrize(
+    ['max_epochs', 'batch_idx_'],
+    [
+        pytest.param(2, 5),
+        pytest.param(3, 8),
+        pytest.param(4, 12)
+    ]
+)
+def test_on_train_batch_start_hook(max_epochs, batch_idx_):
+    class CurrentModel(EvalModelTemplate):
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            if batch_idx == batch_idx_:
+                return -1
+
+    model = CurrentModel()
+    trainer = Trainer(max_epochs=max_epochs)
+    trainer.fit(model)
+    if batch_idx_ > 9:
+        assert trainer.batch_idx == len(model.val_dataloader()) - 1
+        assert trainer.global_step == len(model.val_dataloader()) * max_epochs
+    else:
+        assert trainer.batch_idx == batch_idx_
+        assert trainer.global_step == (batch_idx_ + 1) * max_epochs

From e3462c50e33d1bfd61d152f2f9c47931932aa4e3 Mon Sep 17 00:00:00 2001
From: ydcjeff
Date: Mon, 28 Sep 2020 19:12:28 +0630
Subject: [PATCH 3/5] changelog and docs

---
 CHANGELOG.md                   | 2 ++
 docs/source/early_stopping.rst | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10022219857ee..c0850e24d5c99 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `on_train_batch_start` hook to end epoch early ([#3700](https://github.com/PyTorchLightning/pytorch-lightning/pull/3700))
+
 - Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))
 
 - Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))

diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst
index 3ee903b7a3179..4d5d5e692bb87 100644
--- a/docs/source/early_stopping.rst
+++ b/docs/source/early_stopping.rst
@@ -10,7 +10,7 @@ Early stopping
 
 Stopping an epoch early
 -----------------------
 
-You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.lightning.LightningModule.on_batch_start` to return ``-1`` when some condition is met.
+You can stop an epoch early by overriding :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_train_batch_start` to return ``-1`` when some condition is met.
 
 If you do this repeatedly, for every epoch you had originally requested, then this will stop your entire run.
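
For reference, the behaviour documented in the early_stopping.rst change above looks roughly like the sketch below when used from a LightningModule. This is illustrative only and not part of the patch series: the toy model, data, optimizer, and the batch threshold are made up; only the `on_train_batch_start` override (using the hook signature of this era of Lightning, as in the test above) is the piece the docs describe.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class EarlyEpochExitModel(pl.LightningModule):
    """Hypothetical toy module, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # plain MSE loss on random data; the loss value itself is irrelevant here
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=8)

    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # returning -1 ends the current epoch early (the code path this PR fixes)
        if batch_idx >= 3:
            return -1


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=2)
    trainer.fit(EarlyEpochExitModel())

With the assumed threshold above, each epoch processes batches 0 to 2, the hook then returns -1, and the loop breaks out of the epoch; training still moves on to the next epoch until max_epochs is reached, which matches the docs wording that only repeated early exits stop the entire run.
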
From 9d2b0a7728c3c8f51fd38ffca121c1edbf7d8c5b Mon Sep 17 00:00:00 2001
From: ydcjeff
Date: Mon, 28 Sep 2020 19:14:12 +0630
Subject: [PATCH 4/5] fix test

---
 tests/models/test_hooks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 4441c2014035e..65914a8ad7f75 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -127,7 +127,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
     model = CurrentModel()
     trainer = Trainer(max_epochs=max_epochs)
     trainer.fit(model)
-    if batch_idx_ > 9:
+    if batch_idx_ > len(model.val_dataloader()) - 1:
         assert trainer.batch_idx == len(model.val_dataloader()) - 1
         assert trainer.global_step == len(model.val_dataloader()) * max_epochs
     else:

From dcdbaee878cdebef7ae5bb6fd084ec19f7cebf8a Mon Sep 17 00:00:00 2001
From: Jeff Yang
Date: Sat, 3 Oct 2020 00:17:22 +0630
Subject: [PATCH 5/5] Apply suggestion from code review

Co-authored-by: Jirka Borovec
---
 tests/models/test_hooks.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 65914a8ad7f75..3681484a8bcc8 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -111,12 +111,8 @@ def transfer_batch_to_device(self, data, device):
 
 
 @pytest.mark.parametrize(
-    ['max_epochs', 'batch_idx_'],
-    [
-        pytest.param(2, 5),
-        pytest.param(3, 8),
-        pytest.param(4, 12)
-    ]
+    'max_epochs,batch_idx_',
+    [(2, 5), (3, 8), (4, 12)]
 )
 def test_on_train_batch_start_hook(max_epochs, batch_idx_):
     class CurrentModel(EvalModelTemplate):
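
A side note on the final commit: the review suggestion is purely cosmetic. `pytest.param` and plain tuples produce the same parametrized cases unless a case needs its own id or mark. The snippet below is a hypothetical illustration, not from this PR, showing the two spellings side by side.

import pytest


# Plain tuples are enough when no per-case marks or ids are required
# (the form adopted in PATCH 5/5).
@pytest.mark.parametrize('max_epochs,batch_idx_', [(2, 5), (3, 8), (4, 12)])
def test_plain_tuples(max_epochs, batch_idx_):
    assert max_epochs > 0 and batch_idx_ > 0


# pytest.param becomes useful once a case needs its own id or mark.
@pytest.mark.parametrize(
    'max_epochs,batch_idx_',
    [
        pytest.param(2, 5, id='short-run'),
        pytest.param(3, 8),
        pytest.param(4, 12, marks=pytest.mark.skip(reason='example only')),
    ],
)
def test_pytest_param(max_epochs, batch_idx_):
    assert max_epochs > 0 and batch_idx_ > 0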