From fa31597073a56e6d86cd928cb19dd4d2d54566f2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 25 May 2021 01:47:00 +0200
Subject: [PATCH 1/2] Increment the total batch idx before the accumulation early exit

---
 pytorch_lightning/trainer/training_loop.py |  4 ++--
 tests/tuner/test_lr_finder.py              | 27 ++++++++--------------
 2 files changed, 12 insertions(+), 19 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ca50b088c665b..10b727edbc93a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -529,6 +529,8 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True
 
+            self.total_batch_idx += 1
+
             # max steps reached, end training
             if (
                 self.max_steps is not None and self.max_steps <= self.global_step + 1
@@ -542,8 +544,6 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break
 
-            self.total_batch_idx += 1
-
             # stop epoch if we limited the number of training batches
             if self._num_training_batches_reached(is_last_batch):
                 break
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index 608cb8c6778bf..d4e5ef1862020 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -197,31 +197,24 @@ def test_datamodule_parameter(tmpdir):
 
 
 def test_accumulation_and_early_stopping(tmpdir):
-    """ Test that early stopping of learning rate finder works, and that
-    accumulation also works for this feature """
+    """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """
 
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
+    class TestModel(BoringModel):
 
-    before_lr = hparams.get('learning_rate')
-    # logger file to get meta
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         accumulate_grad_batches=2,
     )
-
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
-    after_lr = lrfinder.suggestion()
 
-    expected_num_lrs = 100
-    expected_batch_idx = 200 - 1
-
-    assert before_lr != after_lr, \
-        'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == expected_num_lrs, \
-        'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == expected_batch_idx, \
-        'Accumulation parameter did not work'
+    assert 1e-3 != lrfinder.suggestion()
+    assert len(lrfinder.results['lr']) == 100
+    assert lrfinder._total_batch_idx == 200
 
 
 def test_suggestion_parameters_work(tmpdir):

From 64c49c11346cc383af36c70eca8ba9424e45e663 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 25 May 2021 01:52:12 +0200
Subject: [PATCH 2/2] Update CHANGELOG

---
 CHANGELOG.md                  | 3 +++
 tests/tuner/test_lr_finder.py | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6ccb83f41ca84..14dbee08920be 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
 
 
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+
+
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 
 
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index d4e5ef1862020..a74af3862c473 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -212,7 +212,7 @@ def __init__(self):
     )
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
 
-    assert 1e-3 != lrfinder.suggestion()
+    assert lrfinder.suggestion() != 1e-3
     assert len(lrfinder.results['lr']) == 100
     assert lrfinder._total_batch_idx == 200
 
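The effect of the reordering can be reproduced with a standalone sketch; this is not Lightning's actual run_training_epoch, and the run_epoch helper, its parameters, and the simplified step/early-exit conditions are assumptions made for illustration only. With accumulate_grad_batches=2, the lr finder's 100 optimizer steps consume 200 batches; incrementing the counter before the early exit counts the final batch (200), while incrementing after it does not (199), which is exactly what the updated `assert lrfinder._total_batch_idx == 200` checks.

# Standalone illustration of the ordering fix; `run_epoch` and its arguments are
# hypothetical and only mimic the relevant control flow, they are not Lightning APIs.

def run_epoch(num_batches: int, max_steps: int, accumulate: int, increment_before_exit: bool) -> int:
    """Return the value of the total batch counter when the loop exits."""
    total_batch_idx = 0
    global_step = 0
    for batch_idx in range(num_batches):
        # an optimizer step only happens once every `accumulate` batches
        if (batch_idx + 1) % accumulate == 0:
            global_step += 1

        if increment_before_exit:
            total_batch_idx += 1  # new ordering: count the batch before any early exit

        # early exit once the requested number of optimizer steps has been taken
        if global_step >= max_steps:
            break

        if not increment_before_exit:
            total_batch_idx += 1  # old ordering: the batch that triggers the exit is never counted

    return total_batch_idx


# 100 lr-finder steps with accumulate_grad_batches=2 process 200 batches
assert run_epoch(1000, max_steps=100, accumulate=2, increment_before_exit=False) == 199  # old behaviour
assert run_epoch(1000, max_steps=100, accumulate=2, increment_before_exit=True) == 200   # behaviour after this patch

The previous test expectation of 200 - 1 reflects the old ordering, where the batch that triggers the early exit was skipped by the counter.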