
Commit 523c59b

Palzert, chaton, carmocca, and awaelchli authored
fixed bug where tuner would not tune lr if also tuning batch_size (#4688)
* fixed bug where tuner would not tune lr if also tuning batch_size
* added a '+1' to computing the smoothed loss. This maintains the behavior for the smoothed loss as before the bug fix
* pep8 fix
* add changelog

Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent: 9eded7f

2 files changed (+5, -2 lines)

CHANGELOG.md (3 additions, 0 deletions)

@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


+- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
+
+
 ## [1.2.2] - 2021-03-02

 ### Added

pytorch_lightning/tuner/lr_finder.py (2 additions, 2 deletions)

@@ -418,11 +418,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
         self.progress_bar.update()

         current_loss = trainer.train_loop.running_loss.last().item()
-        current_step = trainer.global_step + 1  # remove the +1 in 1.0
+        current_step = trainer.global_step

         # Avg loss (loss with momentum) + smoothing
         self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
-        smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
+        smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1))

         # Check if we diverging
         if self.early_stop_threshold is not None:
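For context, the smoothing in this hunk is a bias-corrected exponential moving average of the loss: the running average is divided by `1 - beta**(step + 1)` to compensate for its zero initialisation, and moving the '+1' into the exponent keeps that correction identical to the pre-fix behaviour now that `current_step` starts at `trainer.global_step`. Below is a minimal standalone sketch of this smoothing; the `beta` value, the example losses, and the loop counter are illustrative assumptions, not taken from the Lightning source.

```python
# Minimal sketch of bias-corrected loss smoothing as used during an LR sweep.
# Standalone illustration, not the Lightning callback itself; `beta` and the
# example losses below are assumptions chosen for demonstration.
beta = 0.98          # smoothing momentum
avg_loss = 0.0       # running (biased) exponential moving average
example_losses = [2.3, 2.1, 1.9, 2.0, 1.7]

for step, current_loss in enumerate(example_losses):  # `step` plays the role of trainer.global_step
    # Biased EMA: early values are pulled toward the zero initialisation.
    avg_loss = beta * avg_loss + (1 - beta) * current_loss
    # Bias correction: dividing by 1 - beta**(step + 1) rescales the average so
    # the very first step already reflects the true loss magnitude (the "+1"
    # that the fix moves into the exponent).
    smoothed_loss = avg_loss / (1 - beta ** (step + 1))
    print(f"step={step} raw={current_loss:.3f} smoothed={smoothed_loss:.3f}")
```

At `step=0` the corrected value equals the raw loss (2.3 in this sketch), which is the behaviour the commit message says is preserved.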
