From 4ad89d9b63be85845fec3312718e1bdfe91e3c2b Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 06:27:48 +0200
Subject: [PATCH 01/10] Improve error handling in learning rate finder

---
 src/pytorch_lightning/tuner/lr_finder.py | 34 +++++++++++++++---------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 71d96ef428f35..85ed6a3ee40b4 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -177,21 +177,28 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
         """This will propose a suggestion for choice of initial learning rate as the point with the steepest
         negative gradient.
 
-        Returns:
-            lr: suggested initial learning rate to use
+        Args:
             skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
             skip_end: how many samples to skip in the end. Prevent too optimistic estimates
+
+        Returns:
+            The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
+            loss samples.
         """
-        try:
-            loss = np.array(self.results["loss"][skip_begin:-skip_end])
-            loss = loss[np.isfinite(loss)]
-            min_grad = np.gradient(loss).argmin()
-            self._optimal_idx = min_grad + skip_begin
-            return self.results["lr"][self._optimal_idx]
-        # todo: specify the possible exception
-        except Exception:
-            log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")
+        losses = np.array(self.results["loss"][skip_begin:-skip_end])
+        losses = losses[np.isfinite(losses)]
+        if len(losses) < 2:
+            # computing np.gradient requires at least 2 points
+            log.error(
+                "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
+                " iteration limits or the size of your dataset/dataloader."
+            )
             self._optimal_idx = None
+            return
+
+        min_grad = np.gradient(losses).argmin()
+        self._optimal_idx = min_grad + skip_begin
+        return self.results["lr"][self._optimal_idx]
 
 
 def lr_find(
@@ -252,8 +259,9 @@ def lr_find(
         lr = lr_finder.suggestion()
 
         # TODO: log lr.results to self.logger
-        lightning_setattr(model, lr_attr_name, lr)
-        log.info(f"Learning rate set to {lr}")
+        if lr is not None:
+            lightning_setattr(model, lr_attr_name, lr)
+            log.info(f"Learning rate set to {lr}")
 
     # Restore initial state of model
     trainer._checkpoint_connector.restore(ckpt_path)
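
The heuristic that PATCH 01 hardens can be sketched outside of Lightning with plain NumPy. The following is a minimal, illustrative approximation of what `suggestion()` now does; the function and variable names are invented for the example and are not part of the patch:

    import numpy as np
    from typing import Optional, Sequence

    def suggest_initial_lr(
        lrs: Sequence[float], losses: Sequence[float], skip_begin: int = 10, skip_end: int = 1
    ) -> Optional[float]:
        # Trim the noisy start and end of the sweep, then keep only finite losses.
        trimmed = np.array(losses[skip_begin:-skip_end])
        finite = trimmed[np.isfinite(trimmed)]
        if len(finite) < 2:
            # np.gradient needs at least two points, hence the guard in the patch.
            return None
        # Suggest the lr where the loss falls fastest (steepest negative gradient).
        idx = int(np.gradient(finite).argmin()) + skip_begin
        return lrs[idx]

Note that the guard mirrors the patch: `np.gradient` raises a `ValueError` on arrays with fewer than two elements, which is exactly the failure the old broad `except Exception` used to swallow.
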
From d93fb7a780c4ddac591b89cdc3e4ae32f957d2ac Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 07:13:40 +0200
Subject: [PATCH 02/10] docs

---
 src/pytorch_lightning/tuner/lr_finder.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 85ed6a3ee40b4..6ac4a8aa5f1bb 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -174,12 +174,12 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
         return fig
 
     def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
-        """This will propose a suggestion for choice of initial learning rate as the point with the steepest
-        negative gradient.
+        """This will propose a suggestion for an initial learning rate based on the point with the steepest negative
+        gradient.
 
         Args:
-            skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
-            skip_end: how many samples to skip in the end. Prevent too optimistic estimates
+            skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
+            skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates
 
         Returns:
             The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few

From 64ad0963555cb2ba7ccf73eadf9f92b45009a79f Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 07:49:26 +0200
Subject: [PATCH 03/10] add test

---
 src/pytorch_lightning/tuner/lr_finder.py    |  2 ++
 tests/tests_pytorch/tuner/test_lr_finder.py | 36 +++++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 6ac4a8aa5f1bb..f70553c16edcb 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -196,6 +196,8 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = None
             return
 
+        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
+        # incorrectly shifted by an offset
         min_grad = np.gradient(losses).argmin()
         self._optimal_idx = min_grad + skip_begin
         return self.results["lr"][self._optimal_idx]
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 529ef1c4c08c1..8bf073772db33 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from copy import deepcopy
 
@@ -19,6 +20,7 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -359,3 +361,37 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
         for curr_lr_finder in all_res[1:]
         for k in all_res[0].keys()
     )
+
+
+@pytest.mark.parametrize("skip_begin,skip_end,losses,expected_error", [
+    (0, 0, [], True),
+    (10, 1, [], True),
+    (0, 2, [0, 1, 2], True),
+    (0, 1, [0, 1, 2], False),
+    (1, 1, [0, 1, 2], True),
+    (1, 1, [0, 1, 2, 3], False),
+    (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+    (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+])
+def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
+    """Tests the error handling when not enough finite points are available to make a suggestion."""
+    caplog.clear()
+    lr_finder = _LRFinder(
+        mode="exponential",
+        lr_min=1e-8,
+        lr_max=1,
+        num_training=100,
+    )
+    lrs = list(torch.arange(len(losses)))
+    lr_finder.results = {
+        "lr": lrs,
+        "loss": losses,
+    }
+    with caplog.at_level(logging.ERROR, logger="root.tuner.lr_finder"):
+        lr = lr_finder.suggestion(skip_begin=skip_begin, skip_end=skip_end)
+
+    if expected_error:
+        assert lr is None
+        assert "Failed to compute suggestion for learning rate" in caplog.text
+    else:
+        assert lr is not None
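
The TODO added in PATCH 03 points at a genuine subtlety: dropping non-finite losses re-indexes the array, so the `argmin` no longer lines up with `results["lr"]`. A small standalone illustration of the offset (all values made up for the example):

    import numpy as np

    lrs = [1e-5, 1e-4, 1e-3, 1e-2]
    losses = np.array([float("nan"), 3.0, 1.0, 0.5])
    finite = losses[np.isfinite(losses)]     # the NaN at index 0 is dropped -> [3.0, 1.0, 0.5]
    idx = int(np.gradient(finite).argmin())  # indexes `finite`, not the original series
    print(lrs[idx])  # prints lrs[0], although the loss it found actually sits at lrs[1]

Every non-finite point dropped before the argmin shifts the suggested learning rate one slot too early, and adding `skip_begin` back does not compensate for that.
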
From d0549789ffd32cb3eba6f0b262233ec97604d600 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 07:50:39 +0200
Subject: [PATCH 04/10] mypy

---
 src/pytorch_lightning/tuner/lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index f70553c16edcb..a04baf1cf9b2e 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -194,7 +194,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
                 " iteration limits or the size of your dataset/dataloader."
             )
             self._optimal_idx = None
-            return
+            return None
 
         # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
         # incorrectly shifted by an offset
From 8ea26189818fc7f122979080e5d947c600e52997 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 26 Jul 2022 05:52:17 +0000
Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/lr_finder.py    |  4 ++--
 tests/tests_pytorch/tuner/test_lr_finder.py | 23 ++++++++++++---------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index a04baf1cf9b2e..186dfb5ea7416 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -174,8 +174,8 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
         return fig
 
     def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
-        """This will propose a suggestion for an initial learning rate based on the point with the steepest negative
-        gradient.
+        """This will propose a suggestion for an initial learning rate based on the point with the steepest
+        negative gradient.
 
         Args:
             skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 8bf073772db33..4e1db90777813 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -363,16 +363,19 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
     )
 
 
-@pytest.mark.parametrize("skip_begin,skip_end,losses,expected_error", [
-    (0, 0, [], True),
-    (10, 1, [], True),
-    (0, 2, [0, 1, 2], True),
-    (0, 1, [0, 1, 2], False),
-    (1, 1, [0, 1, 2], True),
-    (1, 1, [0, 1, 2, 3], False),
-    (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
-    (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
-])
+@pytest.mark.parametrize(
+    "skip_begin,skip_end,losses,expected_error",
+    [
+        (0, 0, [], True),
+        (10, 1, [], True),
+        (0, 2, [0, 1, 2], True),
+        (0, 1, [0, 1, 2], False),
+        (1, 1, [0, 1, 2], True),
+        (1, 1, [0, 1, 2, 3], False),
+        (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+        (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+    ],
+)
 def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
     """Tests the error handling when not enough finite points are available to make a suggestion."""
     caplog.clear()
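
For reviewers unfamiliar with the idiom used in the tests added in PATCH 03 (and reformatted above), pytest's `caplog` fixture captures log records so the new `log.error` path can be asserted on directly. A self-contained sketch of the same pattern — the `emit` helper is invented for the example, while the logger name and message are taken from the patches:

    import logging

    def emit(n_points: int) -> None:
        # Stand-in for the failure branch of _LRFinder.suggestion().
        if n_points < 2:
            logging.getLogger("root.tuner.lr_finder").error(
                "Failed to compute suggestion for learning rate because there are not enough points."
            )

    def test_error_is_logged(caplog):
        with caplog.at_level(logging.ERROR, logger="root.tuner.lr_finder"):
            emit(n_points=1)
        assert "Failed to compute suggestion for learning rate" in caplog.text
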
From caab74da2469eeec696b01dbc9cb97a532ca1035 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 07:53:35 +0200
Subject: [PATCH 06/10] update changelog

---
 src/pytorch_lightning/CHANGELOG.md | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 1c3a3b9d5a1be..17dee646df7dc 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -376,6 +376,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))
 
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
+
+
+- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/PyTorchLightning/pytorch-lightning/pull/13845))
+
+
+
 ## [1.6.5] - 2022-07-13
 
 ### Fixed
@@ -386,9 +393,6 @@
 
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
 
-- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
-
-
 ## [1.6.4] - 2022-06-01
 
 ### Added

From f322a4f9e467fa440342242d9b7bccb0b779d192 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 08:02:22 +0200
Subject: [PATCH 07/10] add test

---
 tests/tests_pytorch/tuner/test_lr_finder.py | 38 +++++++++++++++------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 8bf073772db33..ee1cee6ea1b4f 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -363,16 +363,19 @@ def test_multiple_lr_find_calls_gives_same_results(tmpdir):
     )
 
 
-@pytest.mark.parametrize("skip_begin,skip_end,losses,expected_error", [
-    (0, 0, [], True),
-    (10, 1, [], True),
-    (0, 2, [0, 1, 2], True),
-    (0, 1, [0, 1, 2], False),
-    (1, 1, [0, 1, 2], True),
-    (1, 1, [0, 1, 2, 3], False),
-    (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
-    (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
-])
+@pytest.mark.parametrize(
+    "skip_begin,skip_end,losses,expected_error",
+    [
+        (0, 0, [], True),
+        (10, 1, [], True),
+        (0, 2, [0, 1, 2], True),
+        (0, 1, [0, 1, 2], False),
+        (1, 1, [0, 1, 2], True),
+        (1, 1, [0, 1, 2, 3], False),
+        (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+        (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+    ],
+)
 def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
     """Tests the error handling when not enough finite points are available to make a suggestion."""
     caplog.clear()
@@ -395,3 +398,18 @@ def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expec
         assert "Failed to compute suggestion for learning rate" in caplog.text
     else:
         assert lr is not None
+
+
+def test_lr_attribute_when_suggestion_invalid(tmpdir):
+    """Tests that the learning rate attribute remains unchanged when no suggestion is possible."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.123
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+    lr_finder = trainer.tuner.lr_find(model=model, num_training=1)  # force insufficient data points
+    assert lr_finder.suggestion() is None
+    assert model.learning_rate == 0.123  # must remain unchanged because suggestion is not possible
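
Together, PATCH 01 and the test above pin down the new contract: `suggestion()` returns `None` instead of raising when too few finite loss points exist, and `lr_find` only writes the learning rate back when a value is available. A hedged sketch of how calling code might consume this — trainer and model setup are elided, and the attribute name follows the test above:

    lr_finder = trainer.tuner.lr_find(model=model)
    suggested_lr = lr_finder.suggestion()
    if suggested_lr is None:
        # Too few finite loss points were collected; keep the configured value.
        print(f"No suggestion possible, keeping lr={model.learning_rate}")
    else:
        model.learning_rate = suggested_lr
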
From 72c82fbc9ca3b529bd723a6b33575d86620754f0 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 08:04:02 +0200
Subject: [PATCH 08/10] changelog

---
 src/pytorch_lightning/CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 17dee646df7dc..1da7bf0a8a108 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -382,6 +382,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/PyTorchLightning/pytorch-lightning/pull/13845))
 
+- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/PyTorchLightning/pytorch-lightning/pull/13845))
+
+
 
 ## [1.6.5] - 2022-07-13

From c5e426be2ba3e78025a67c06d293f544bd0ea4d4 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 26 Jul 2022 08:11:14 +0200
Subject: [PATCH 09/10] update test

---
 tests/tests_pytorch/tuner/test_lr_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index ee1cee6ea1b4f..9be115d2f8fda 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -410,6 +410,6 @@ def test_lr_attribute_when_suggestion_invalid(tmpdir):
 
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir)
-    lr_finder = trainer.tuner.lr_find(model=model, num_training=1)  # force insufficient data points
+    lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1)  # force insufficient data points
     assert lr_finder.suggestion() is None
     assert model.learning_rate == 0.123  # must remain unchanged because suggestion is not possible

From a083ddef99a05c563b83c459c8499aa4f3ac34d5 Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Tue, 26 Jul 2022 10:22:25 -0400
Subject: [PATCH 10/10] Update src/pytorch_lightning/CHANGELOG.md

Co-authored-by: Rohit Gupta
---
 src/pytorch_lightning/CHANGELOG.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 1da7bf0a8a108..3171c25595fc6 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -376,13 +376,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))
 
-- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/Lightning-AI/lightning/pull/13645))
 
 
-- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/PyTorchLightning/pytorch-lightning/pull/13845))
+- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
 
 
-- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/PyTorchLightning/pytorch-lightning/pull/13845))
+- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))