
Commit 8bf5b71

awaelchli authored and lexierule committed
Increment the total batch idx before the accumulation early exit (#7692)
* Increment the total batch idx before the accumulation early exit
* Update CHANGELOG
1 parent a1376ed commit 8bf5b71
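
Why the ordering matters, in a minimal sketch (illustrative only, not the actual Lightning training loop; the counter name is reused purely for clarity): when the counter is incremented at the bottom of the batch loop, any early exit taken above it skips the increment, so the last batch of the run is processed but never counted.

```python
# Minimal sketch of the off-by-one this commit fixes (illustrative, not Lightning code).

def run_epoch_old(num_batches: int, stop_at: int) -> int:
    total_batch_idx = 0
    for batch_idx in range(num_batches):
        # ... training step / gradient accumulation would happen here ...
        if batch_idx + 1 >= stop_at:   # early exit, e.g. max_steps reached or should_stop
            break                      # the increment below is skipped for the last batch
        total_batch_idx += 1
    return total_batch_idx


def run_epoch_new(num_batches: int, stop_at: int) -> int:
    total_batch_idx = 0
    for batch_idx in range(num_batches):
        total_batch_idx += 1           # count the batch before any early exit can skip it
        if batch_idx + 1 >= stop_at:
            break
    return total_batch_idx


print(run_epoch_old(200, 200))  # 199: the last processed batch is not counted
print(run_epoch_new(200, 200))  # 200: every processed batch is counted
```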

File tree: 3 files changed (+13, -19 lines)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
 - Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674))
 - Fixed global step update when the epoch is skipped ([#7677](https://github.com/PyTorchLightning/pytorch-lightning/pull/7677))
+- Fixed training loop total batch counter when accumulate grad batches was enabled ([#7692](https://github.com/PyTorchLightning/pytorch-lightning/pull/7692))

 ## [1.3.2] - 2021-05-18

pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 2 deletions
@@ -526,6 +526,8 @@ def run_training_epoch(self):
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
             self.trainer.checkpoint_connector.has_trained = True

+            self.trainer.total_batch_idx += 1
+
             # max steps reached, end training
             if (
                 self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
@@ -539,8 +541,6 @@ def run_training_epoch(self):
             if self.trainer.should_stop:
                 break

-            self.trainer.total_batch_idx += 1
-
             # stop epoch if we limited the number of training batches
             if self._num_training_batches_reached(is_last_batch):
                 break
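
With this ordering, `total_batch_idx` counts every batch the loop consumes, including the one on which `max_steps`, `should_stop`, or the batch limit triggers the `break`. Under gradient accumulation the raw batch counter runs ahead of the optimizer-step counter that the `max_steps` check is based on; a rough sketch of that relationship (an illustration of the bookkeeping, not Lightning's actual implementation):

```python
# Rough sketch of raw batch counting vs. optimizer steps under gradient accumulation
# (an illustration, not Lightning internals).
accumulate_grad_batches = 2
total_batch_idx = 0   # advances once per raw batch
global_step = 0       # advances once per optimizer step

for _ in range(200):  # 200 raw batches
    total_batch_idx += 1
    if total_batch_idx % accumulate_grad_batches == 0:
        global_step += 1  # optimizer steps only after a full accumulation window

print(total_batch_idx, global_step)  # 200 100
```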

tests/tuner/test_lr_finder.py

Lines changed: 10 additions & 17 deletions
@@ -197,31 +197,24 @@ def test_datamodule_parameter(tmpdir):


 def test_accumulation_and_early_stopping(tmpdir):
-    """ Test that early stopping of learning rate finder works, and that
-    accumulation also works for this feature """
+    """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """

-    hparams = EvalModelTemplate.get_default_hparams()
-    model = EvalModelTemplate(**hparams)
+    class TestModel(BoringModel):

-    before_lr = hparams.get('learning_rate')
-    # logger file to get meta
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         accumulate_grad_batches=2,
     )
-
     lrfinder = trainer.tuner.lr_find(model, early_stop_threshold=None)
-    after_lr = lrfinder.suggestion()

-    expected_num_lrs = 100
-    expected_batch_idx = 200 - 1
-
-    assert before_lr != after_lr, \
-        'Learning rate was not altered after running learning rate finder'
-    assert len(lrfinder.results['lr']) == expected_num_lrs, \
-        'Early stopping for learning rate finder did not work'
-    assert lrfinder._total_batch_idx == expected_batch_idx, \
-        'Accumulation parameter did not work'
+    assert lrfinder.suggestion() != 1e-3
+    assert len(lrfinder.results['lr']) == 100
+    assert lrfinder._total_batch_idx == 200


 def test_suggestion_parameters_work(tmpdir):
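
The updated expectations line up with that counting: assuming `lr_find`'s default of 100 candidate learning rates (`num_training=100`), and with `accumulate_grad_batches=2` from the test's Trainer, each candidate corresponds to one optimizer step and therefore two raw batches, so the run consumes 200 batches. With the counter now incremented before the early exit, `_total_batch_idx` ends at 200 instead of the old off-by-one `200 - 1`. A back-of-the-envelope check:

```python
# Back-of-the-envelope check of the new test expectations
# (assumes lr_find's default num_training=100; accumulate_grad_batches=2 comes from the test).
num_training = 100            # candidate learning rates tried by the finder
accumulate_grad_batches = 2   # from the Trainer in the updated test

batches_processed = num_training * accumulate_grad_batches
print(num_training)           # 100 -> len(lrfinder.results['lr'])
print(batches_processed)      # 200 -> lrfinder._total_batch_idx with the fix
print(batches_processed - 1)  # 199 -> the old expectation (200 - 1), one batch short
```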
