
Commit c391170

awaelchli and rohitgr7 authored
Fix error handling in learning rate finder (#13845)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent a90ef3b commit c391170

File tree: 3 files changed (+89 −18 lines)


src/pytorch_lightning/CHANGELOG.md

Lines changed: 10 additions & 3 deletions
@@ -379,6 +379,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))
 
 
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/Lightning-AI/lightning/pull/13645))
+
+
+- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
+
+
+- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
+
+
+
 ## [1.6.5] - 2022-07-13
 
 ### Fixed
@@ -389,9 +399,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
 
 
-- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
-
-
 ## [1.6.4] - 2022-06-01
 
 ### Added

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 25 additions & 15 deletions
@@ -174,24 +174,33 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figure"]:
         return fig
 
     def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
-        """This will propose a suggestion for choice of initial learning rate as the point with the steepest
+        """This will propose a suggestion for an initial learning rate based on the point with the steepest
         negative gradient.
 
+        Args:
+            skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
+            skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates
+
         Returns:
-            lr: suggested initial learning rate to use
-            skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
-            skip_end: how many samples to skip in the end. Prevent too optimistic estimates
+            The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
+            loss samples.
         """
-        try:
-            loss = np.array(self.results["loss"][skip_begin:-skip_end])
-            loss = loss[np.isfinite(loss)]
-            min_grad = np.gradient(loss).argmin()
-            self._optimal_idx = min_grad + skip_begin
-            return self.results["lr"][self._optimal_idx]
-        # todo: specify the possible exception
-        except Exception:
-            log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")
+        losses = np.array(self.results["loss"][skip_begin:-skip_end])
+        losses = losses[np.isfinite(losses)]
+        if len(losses) < 2:
+            # computing np.gradient requires at least 2 points
+            log.error(
+                "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
+                " iteration limits or the size of your dataset/dataloader."
+            )
             self._optimal_idx = None
+            return None
+
+        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
+        #   incorrectly shifted by an offset
+        min_grad = np.gradient(losses).argmin()
+        self._optimal_idx = min_grad + skip_begin
+        return self.results["lr"][self._optimal_idx]
 
 
 def lr_find(
@@ -252,8 +261,9 @@ def lr_find(
        lr = lr_finder.suggestion()
 
        # TODO: log lr.results to self.logger
-       lightning_setattr(model, lr_attr_name, lr)
-       log.info(f"Learning rate set to {lr}")
+       if lr is not None:
+           lightning_setattr(model, lr_attr_name, lr)
+           log.info(f"Learning rate set to {lr}")
 
    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
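To see the new control flow in isolation, below is a minimal standalone sketch of the suggestion logic introduced above. The helper name `suggest_lr` is hypothetical and used only for illustration; it mirrors the gradient-argmin approach and the early `None` return of the diff, not the actual `_LRFinder` class.

import logging

import numpy as np

log = logging.getLogger(__name__)


def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
    """Return the lr at the steepest negative loss gradient, or None if too few finite points remain."""
    trimmed = np.array(losses[skip_begin:-skip_end])
    trimmed = trimmed[np.isfinite(trimmed)]
    if len(trimmed) < 2:
        # np.gradient needs at least two points, so no suggestion is possible
        log.error("Failed to compute suggestion for learning rate: not enough points.")
        return None
    return lrs[np.gradient(trimmed).argmin() + skip_begin]


# A run with too few points now yields None instead of raising:
print(suggest_lr(lrs=[1e-5], losses=[0.7]))  # -> None
# With enough finite points, the lr at the steepest loss drop is returned:
lrs = np.logspace(-6, -1, 30).tolist()
losses = np.linspace(1.0, 0.2, 30).tolist()
print(suggest_lr(lrs, losses))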

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 54 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from copy import deepcopy
 
@@ -19,6 +20,7 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -359,3 +361,55 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
         for curr_lr_finder in all_res[1:]
         for k in all_res[0].keys()
     )
+
+
+@pytest.mark.parametrize(
+    "skip_begin,skip_end,losses,expected_error",
+    [
+        (0, 0, [], True),
+        (10, 1, [], True),
+        (0, 2, [0, 1, 2], True),
+        (0, 1, [0, 1, 2], False),
+        (1, 1, [0, 1, 2], True),
+        (1, 1, [0, 1, 2, 3], False),
+        (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+        (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+    ],
+)
+def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
+    """Tests the error handling when not enough finite points are available to make a suggestion."""
+    caplog.clear()
+    lr_finder = _LRFinder(
+        mode="exponential",
+        lr_min=1e-8,
+        lr_max=1,
+        num_training=100,
+    )
+    lrs = list(torch.arange(len(losses)))
+    lr_finder.results = {
+        "lr": lrs,
+        "loss": losses,
+    }
+    with caplog.at_level(logging.ERROR, logger="root.tuner.lr_finder"):
+        lr = lr_finder.suggestion(skip_begin=skip_begin, skip_end=skip_end)
+
+    if expected_error:
+        assert lr is None
+        assert "Failed to compute suggestion for learning rate" in caplog.text
+    else:
+        assert lr is not None
+
+
+def test_lr_attribute_when_suggestion_invalid(tmpdir):
+    """Tests that the learning rate attribute is left unchanged when no suggestion is possible."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.123
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+    lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1)  # force insufficient data points
+    assert lr_finder.suggestion() is None
+    assert model.learning_rate == 0.123  # must remain unchanged because suggestion is not possible
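For context, here is a minimal usage sketch of how the fix surfaces to a user of the tuner. The script below is illustrative, not part of the commit: it only uses APIs that the tests above already exercise (`Trainer`, `BoringModel`, `trainer.tuner.lr_find`, `suggestion`). With this change, `suggestion()` can return `None` instead of raising when the range test produced too few points, so callers that apply the learning rate themselves should guard for it.

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1)

# A very short range test: with only a handful of loss points, no suggestion is possible.
lr_finder = trainer.tuner.lr_find(model, num_training=5)

suggested_lr = lr_finder.suggestion()
if suggested_lr is None:
    # The commit also ensures the model's lr attribute is left untouched in this case.
    print("Not enough finite loss points; keeping the existing learning rate.")
else:
    print(f"Suggested initial learning rate: {suggested_lr}")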
