13 changes: 10 additions & 3 deletions src/pytorch_lightning/CHANGELOG.md
@@ -376,6 +376,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/Lightning-AI/lightning/pull/13645))


- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))


- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))



## [1.6.5] - 2022-07-13

### Fixed
@@ -386,9 +396,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))


- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))


## [1.6.4] - 2022-06-01

### Added
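To make the two learning-rate-finder entries above concrete, here is a minimal sketch of the user-facing behaviour after this change. The `LitModel` class and its `learning_rate` attribute are illustrative stand-ins that mirror the new test added below; they are not part of the changelog itself:

```python
# Sketch only: when no suggestion is possible, lr_find now logs an error,
# `suggestion()` returns None, and the model attribute is left untouched.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class LitModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.learning_rate = 0.123  # the value the tuner would normally overwrite


model = LitModel()
trainer = Trainer(max_epochs=1)

# A single lr-finder step yields too few loss samples for a suggestion.
lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1)

assert lr_finder.suggestion() is None  # a descriptive error is logged instead of a bare traceback
assert model.learning_rate == 0.123    # the fix: the attribute is no longer overwritten with None
```

Before this PR the second assertion would fail, because `lightning_setattr` was called unconditionally with the `None` suggestion.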
40 changes: 25 additions & 15 deletions src/pytorch_lightning/tuner/lr_finder.py
@@ -174,24 +174,33 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
return fig

def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
"""This will propose a suggestion for choice of initial learning rate as the point with the steepest
"""This will propose a suggestion for an initial learning rate based on the point with the steepest
negative gradient.

Args:
skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
skip_end: how many samples to skip in the end. Prevent too optimistic estimates
skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates

Returns:
lr: suggested initial learning rate to use
The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
loss samples.
"""
try:
loss = np.array(self.results["loss"][skip_begin:-skip_end])
loss = loss[np.isfinite(loss)]
min_grad = np.gradient(loss).argmin()
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
# todo: specify the possible exception
except Exception:
log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")
losses = np.array(self.results["loss"][skip_begin:-skip_end])
losses = losses[np.isfinite(losses)]
if len(losses) < 2:
# computing np.gradient requires at least 2 points
log.error(
"Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
" iteration limits or the size of your dataset/dataloader."
)
self._optimal_idx = None
return None

# TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
# incorrectly shifted by an offset
min_grad = np.gradient(losses).argmin()
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]


def lr_find(
@@ -252,8 +261,9 @@ def lr_find(
lr = lr_finder.suggestion()

# TODO: log lr.results to self.logger
lightning_setattr(model, lr_attr_name, lr)
log.info(f"Learning rate set to {lr}")
if lr is not None:
lightning_setattr(model, lr_attr_name, lr)
log.info(f"Learning rate set to {lr}")

# Restore initial state of model
trainer._checkpoint_connector.restore(ckpt_path)
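For reviewers who want to see the new guard in `suggestion()` in isolation, the snippet below drives the internal, underscore-prefixed `_LRFinder` directly. The constructor arguments are copied from the tests added in this PR; the loss values are made up for illustration:

```python
import torch

from pytorch_lightning.tuner.lr_finder import _LRFinder

lr_finder = _LRFinder(mode="exponential", lr_min=1e-8, lr_max=1, num_training=100)

# Fewer than two finite losses survive the slicing/filtering -> no suggestion, an error is logged.
lr_finder.results = {"lr": list(torch.arange(3)), "loss": [float("nan"), float("inf"), 0.7]}
assert lr_finder.suggestion(skip_begin=0, skip_end=1) is None
assert lr_finder._optimal_idx is None

# With enough finite points, the steepest-negative-gradient index is returned as before.
# (As the in-code TODO notes, non-finite losses can still shift that index.)
lr_finder.results = {"lr": list(torch.arange(5)), "loss": [5.0, 4.0, 2.0, 1.5, 1.4]}
assert lr_finder.suggestion(skip_begin=0, skip_end=1) is not None
```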
54 changes: 54 additions & 0 deletions tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from copy import deepcopy

@@ -19,6 +20,7 @@

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -359,3 +361,55 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
for curr_lr_finder in all_res[1:]
for k in all_res[0].keys()
)


@pytest.mark.parametrize(
"skip_begin,skip_end,losses,expected_error",
[
(0, 0, [], True),
(10, 1, [], True),
(0, 2, [0, 1, 2], True),
(0, 1, [0, 1, 2], False),
(1, 1, [0, 1, 2], True),
(1, 1, [0, 1, 2, 3], False),
(0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
(4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
],
)
def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
"""Tests the error handling when not enough finite points are available to make a suggestion."""
caplog.clear()
lr_finder = _LRFinder(
mode="exponential",
lr_min=1e-8,
lr_max=1,
num_training=100,
)
lrs = list(torch.arange(len(losses)))
lr_finder.results = {
"lr": lrs,
"loss": losses,
}
with caplog.at_level(logging.ERROR, logger="root.tuner.lr_finder"):
lr = lr_finder.suggestion(skip_begin=skip_begin, skip_end=skip_end)

if expected_error:
assert lr is None
assert "Failed to compute suggestion for learning rate" in caplog.text
else:
assert lr is not None


def test_lr_attribute_when_suggestion_invalid(tmpdir):
"""Tests learning rate finder ends before `num_training` steps."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.learning_rate = 0.123

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir)
lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1) # force insufficient data points
assert lr_finder.suggestion() is None
assert model.learning_rate == 0.123 # must remain unchanged because suggestion is not possible
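
To exercise just the tests added here, something along these lines should work; the path and `-k` expression assume the repository's standard pytest layout and are illustrative only:

```python
# Illustrative invocation of the two new tests via pytest's Python entry point.
import pytest

pytest.main(
    [
        "tests/tests_pytorch/tuner/test_lr_finder.py",
        "-k", "not_enough_finite_points or suggestion_invalid",
        "-v",
    ]
)
```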