Commit 0631827

Palzer, tchaton, carmocca, and awaelchli authored and committed
fixed bug where tuner would not tune lr if also tuning batch_size (#4688)
* fixed bug where tuner would not tune lr if also tuning batch_size
* added a '+1' to computing the smoothed loss. This maintains the behavior for the smoothed loss as before the bug fix
* pep8 fix
* add changelog

Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
(cherry picked from commit 523c59b)
1 parent e719d60 commit 0631827
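
For context, the bug only shows up when both tuners are requested on the same `Trainer`. A minimal sketch of that scenario, assuming the 1.2-era `auto_scale_batch_size`/`auto_lr_find` flags and a hypothetical `MyLitModel` that exposes `lr` and `batch_size` attributes for the tuner to overwrite:

    import pytorch_lightning as pl

    # MyLitModel is a placeholder LightningModule with `lr` and `batch_size`
    # hyperparameters that the tuner writes its results back into.
    model = MyLitModel(lr=1e-3, batch_size=32)

    trainer = pl.Trainer(
        auto_scale_batch_size="power",  # find the largest batch size first
        auto_lr_find=True,              # then run the learning-rate finder
        max_epochs=10,
    )

    # Before this fix, tune() would scale the batch size but skip the
    # learning-rate search; with the fix, both run in a single call.
    trainer.tune(model)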

2 files changed (+5, -2 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
+- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added

pytorch_lightning/tuner/lr_finder.py

Lines changed: 2 additions & 2 deletions
@@ -412,11 +412,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
             self.progress_bar.update()
 
         current_loss = trainer.train_loop.running_loss.last().item()
-        current_step = trainer.global_step + 1  # remove the +1 in 1.0
+        current_step = trainer.global_step
 
         # Avg loss (loss with momentum) + smoothing
         self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
-        smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
+        smoothed_loss = self.avg_loss / (1 - self.beta**(current_step + 1))
 
         # Check if we diverging
         if self.early_stop_threshold is not None:
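
For readers unfamiliar with the smoothing step: `self.avg_loss` is an exponential moving average of the loss, and dividing by `1 - beta**(current_step + 1)` is the usual zero-initialization bias correction. Since `current_step` now starts at 0 instead of 1, moving the `+1` into the exponent keeps the smoothed values identical to the pre-fix behavior. A minimal, self-contained sketch of the same computation (the helper name and sample losses are illustrative, not taken from the module):

    # Bias-corrected exponential moving average of the loss, as used by the LR finder.
    # beta is the smoothing momentum; step is 0-based, hence the "+ 1" in the correction.
    def smoothed(losses, beta=0.98):
        avg_loss = 0.0
        out = []
        for step, loss in enumerate(losses):
            avg_loss = beta * avg_loss + (1 - beta) * loss
            out.append(avg_loss / (1 - beta ** (step + 1)))  # divide out the startup bias
        return out

    # The first corrected value equals the first raw loss, because the correction
    # exactly removes the bias from initializing avg_loss at zero.
    print(smoothed([2.0, 1.5, 1.2]))  # approximately [2.0, 1.747, 1.561]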
