diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 1c3a3b9d5a1be..3171c25595fc6 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -376,6 +376,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed main progress bar counter when `val_check_interval=int` and `check_val_every_n_epoch=None` ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)
 
 
+- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/Lightning-AI/lightning/pull/13645))
+
+
+- Fixed error handling in learning rate finder when not enough data points are available to give a good suggestion ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
+
+
+- Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
+
+
+
 ## [1.6.5] - 2022-07-13
 
 ### Fixed
@@ -386,9 +396,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed the restoration of log step during restart ([#13467](https://github.com/PyTorchLightning/pytorch-lightning/pull/13467))
 
 
-- Used `global_step` while restoring logging step for old checkpoints ([#13645](https://github.com/PyTorchLightning/pytorch-lightning/pull/13645))
-
-
 ## [1.6.4] - 2022-06-01
 
 ### Added
diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py
index 71d96ef428f35..186dfb5ea7416 100644
--- a/src/pytorch_lightning/tuner/lr_finder.py
+++ b/src/pytorch_lightning/tuner/lr_finder.py
@@ -174,24 +174,33 @@ def plot(self, suggest: bool = False, show: bool = False) -> Optional["plt.Figur
         return fig
 
     def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
-        """This will propose a suggestion for choice of initial learning rate as the point with the steepest
+        """This will propose a suggestion for an initial learning rate based on the point with the steepest
         negative gradient.
 
+        Args:
+            skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
+            skip_end: how many samples to skip in the end; helps to avoid too optimistic estimates
+
         Returns:
-            lr: suggested initial learning rate to use
-            skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
-            skip_end: how many samples to skip in the end. Prevent too optimistic estimates
+            The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
+            loss samples.
         """
-        try:
-            loss = np.array(self.results["loss"][skip_begin:-skip_end])
-            loss = loss[np.isfinite(loss)]
-            min_grad = np.gradient(loss).argmin()
-            self._optimal_idx = min_grad + skip_begin
-            return self.results["lr"][self._optimal_idx]
-        # todo: specify the possible exception
-        except Exception:
-            log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")
+        losses = np.array(self.results["loss"][skip_begin:-skip_end])
+        losses = losses[np.isfinite(losses)]
+        if len(losses) < 2:
+            # computing np.gradient requires at least 2 points
+            log.error(
+                "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
+                " iteration limits or the size of your dataset/dataloader."
+            )
             self._optimal_idx = None
+            return None
+
+        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
+        # incorrectly shifted by an offset
+        min_grad = np.gradient(losses).argmin()
+        self._optimal_idx = min_grad + skip_begin
+        return self.results["lr"][self._optimal_idx]
 
 
 def lr_find(
@@ -252,8 +261,9 @@ def lr_find(
         lr = lr_finder.suggestion()
 
         # TODO: log lr.results to self.logger
-        lightning_setattr(model, lr_attr_name, lr)
-        log.info(f"Learning rate set to {lr}")
+        if lr is not None:
+            lightning_setattr(model, lr_attr_name, lr)
+            log.info(f"Learning rate set to {lr}")
 
     # Restore initial state of model
     trainer._checkpoint_connector.restore(ckpt_path)
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 529ef1c4c08c1..9be115d2f8fda 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from copy import deepcopy
 
@@ -19,6 +20,7 @@
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -359,3 +361,55 @@
         for curr_lr_finder in all_res[1:]
         for k in all_res[0].keys()
     )
+
+
+@pytest.mark.parametrize(
+    "skip_begin,skip_end,losses,expected_error",
+    [
+        (0, 0, [], True),
+        (10, 1, [], True),
+        (0, 2, [0, 1, 2], True),
+        (0, 1, [0, 1, 2], False),
+        (1, 1, [0, 1, 2], True),
+        (1, 1, [0, 1, 2, 3], False),
+        (0, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+        (4, 1, [float("nan"), float("nan"), 0, float("inf"), 1, 2, 3, float("inf"), 2, float("nan"), 1], False),
+    ],
+)
+def test_suggestion_not_enough_finite_points(losses, skip_begin, skip_end, expected_error, caplog):
+    """Tests the error handling when not enough finite points are available to make a suggestion."""
+    caplog.clear()
+    lr_finder = _LRFinder(
+        mode="exponential",
+        lr_min=1e-8,
+        lr_max=1,
+        num_training=100,
+    )
+    lrs = list(torch.arange(len(losses)))
+    lr_finder.results = {
+        "lr": lrs,
+        "loss": losses,
+    }
+    with caplog.at_level(logging.ERROR, logger="root.tuner.lr_finder"):
+        lr = lr_finder.suggestion(skip_begin=skip_begin, skip_end=skip_end)
+
+    if expected_error:
+        assert lr is None
+        assert "Failed to compute suggestion for learning rate" in caplog.text
+    else:
+        assert lr is not None
+
+
+def test_lr_attribute_when_suggestion_invalid(tmpdir):
+    """Tests that the model's learning rate attribute is left unchanged when no suggestion is possible."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.123
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir)
+    lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1)  # force insufficient data points
+    assert lr_finder.suggestion() is None
+    assert model.learning_rate == 0.123  # must remain unchanged because suggestion is not possible
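
For context, a minimal usage sketch (not part of the patch) of the behaviour this diff introduces: `suggestion()` now logs an error and returns `None` instead of raising when too few finite loss points remain, and `lr_find()` then leaves the model's learning rate untouched. The sketch mirrors `test_suggestion_not_enough_finite_points` above; it assumes the patched `pytorch_lightning` is importable, and the learning-rate/loss values are made up.

    import logging

    from pytorch_lightning.tuner.lr_finder import _LRFinder

    logging.basicConfig(level=logging.ERROR)

    # Build a finder directly, as the new test does, and hand it a results dict
    # in which only one finite loss value survives the skipping and filtering.
    lr_finder = _LRFinder(mode="exponential", lr_min=1e-8, lr_max=1, num_training=100)
    lr_finder.results = {"lr": [1e-4, 1e-3, 1e-2], "loss": [float("nan"), 0.5, float("inf")]}

    # np.gradient needs at least two finite points, so the patched suggestion()
    # logs an error, sets _optimal_idx to None, and returns None instead of raising.
    assert lr_finder.suggestion(skip_begin=0, skip_end=1) is None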