diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6ccb83f41ca84..14dbee08920be 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -127,6 +127,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
 
 
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))
+
+
 - Fixed broadcasting in multi-node, multi-gpu DDP using torch 1.7 ([#7592](https://github.com/PyTorchLightning/pytorch-lightning/pull/7592))
 
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ca50b088c665b..10b727edbc93a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -529,6 +529,8 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True
 
+            self.total_batch_idx += 1
+
             # max steps reached, end training
             if (
                 self.max_steps is not None and self.max_steps <= self.global_step + 1
@@ -542,8 +544,6 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break
 
-            self.total_batch_idx += 1
-
             # stop epoch if we limited the number of training batches
             if self._num_training_batches_reached(is_last_batch):
                 break
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index 608cb8c6778bf..a74af3862c473 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -197,31 +197,24 @@ def test_datamodule_parameter(tmpdir):
 
 
 def test_accumulation_and_early_stopping(tmpdir):
-    """ Test that early stopping of learning rate finder works, and that
-    accumulation also works for this feature """
+    """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """
 
-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
+    class TestModel(BoringModel):
 
-    before_lr = hparams.get('learning_rate')
-    # logger file to get meta
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         accumulate_grad_batches=2,
     )
-
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
-    after_lr = lrfinder.suggestion()
-
-    expected_num_lrs = 100
-    expected_batch_idx = 200 - 1
-
-    assert before_lr != after_lr, \
-        'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == expected_num_lrs, \
-        'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == expected_batch_idx, \
-        'Accumulation parameter did not work'
+
+    assert lrfinder.suggestion() != 1e-3
+    assert len(lrfinder.results['lr']) == 100
+    assert lrfinder._total_batch_idx == 200
 
 
 def test_suggestion_parameters_work(tmpdir):
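
For context (not part of the patch): a minimal standalone sketch of the counter arithmetic behind the updated test expectation. With `accumulate_grad_batches=2` and the 100 learning-rate candidates asserted in the test, the lr finder processes 200 batches; because the patched loop increments `total_batch_idx` once per batch before the max-steps check, the counter ends at 200 rather than `200 - 1`. The `count_batches` helper below is hypothetical and only mimics that counting logic; it is not Lightning's actual training loop.

```python
# Hypothetical sketch of the counting logic, not Lightning's real loop.
def count_batches(num_lr_values: int = 100, accumulate_grad_batches: int = 2) -> int:
    total_batch_idx = 0
    global_step = 0
    max_steps = num_lr_values  # the lr finder schedules one optimizer step per candidate lr
    for batch_idx in range(1_000_000):  # arbitrary safety bound on batches
        # an optimizer step happens only every `accumulate_grad_batches` batches
        if (batch_idx + 1) % accumulate_grad_batches == 0:
            global_step += 1
        # patched behaviour: count the batch before evaluating the stop condition
        total_batch_idx += 1
        if global_step >= max_steps:
            break
    return total_batch_idx

assert count_batches() == 200  # matches `lrfinder._total_batch_idx == 200` in the new test
```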