From e0f7deabb5073106e2ab150e7450b37ef514e4a2 Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 10 Nov 2022 12:51:54 +0100 Subject: [PATCH 1/4] Fix restarting attribute for lr finder --- src/pytorch_lightning/tuner/lr_finder.py | 1 + tests/tests_pytorch/tuner/test_lr_finder.py | 37 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 63d7c09abb26e..888c8a6dee515 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -265,6 +265,7 @@ def lr_find( # Restore initial state of model trainer._checkpoint_connector.restore(ckpt_path) trainer.strategy.remove_checkpoint(ckpt_path) + trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True return lr_finder diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index ed4d9d33430f0..01722dca86438 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -441,6 +441,43 @@ def test_if_lr_finder_callback_already_configured(): trainer.tune(model) +def test_lr_finder_callback_restarting(tmpdir): + """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint.""" + + class MyBoringModel(BoringModel): + def __init__(self): + super().__init__() + self.learning_rate = 0.123 + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.learning_rate) + + class CustomLearningRateFinder(LearningRateFinder): + milestones = (1,) + + def lr_find(self, trainer, pl_module) -> None: + super().lr_find(trainer, pl_module) + assert not trainer.fit_loop.restarting + + def on_train_epoch_start(self, trainer, pl_module): + if trainer.current_epoch in self.milestones or trainer.current_epoch == 0: + self.lr_find(trainer, pl_module) + + model = MyBoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + callbacks=[CustomLearningRateFinder(early_stop_threshold=None, update_attr=True)], + limit_train_batches=10, + limit_val_batches=0, + limit_test_batches=00, + num_sanity_val_steps=0, + enable_model_summary=False, + ) + + trainer.fit(model) + + @mock.patch.dict(os.environ, os.environ.copy(), clear=True) @RunIf(standalone=True) def test_lr_finder_with_ddp(tmpdir): From 2b6afeb148289fb784132e71e53d0fb727f4b85e Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 10 Nov 2022 12:56:27 +0100 Subject: [PATCH 2/4] chlog --- src/pytorch_lightning/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 2abb696ec3c41..6765a10c9762f 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253)) +- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620)) + ### Deprecated - Deprecated `pytorch_lightning.utilities.distributed.rank_zero_only` in favor of `pytorch_lightning.utilities.rank_zero_only` ([#15536](https://github.com/Lightning-AI/lightning/pull/15536)) From 4c2ff34737893dfcbde913d8ac88520f6dc0723e Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 10 Nov 2022 16:51:28 +0100 Subject: [PATCH 3/4] additional fixes --- src/pytorch_lightning/callbacks/lr_finder.py | 2 +- src/pytorch_lightning/tuner/lr_finder.py | 14 ++++++++------ tests/tests_pytorch/tuner/test_lr_finder.py | 14 ++++++++++++-- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/callbacks/lr_finder.py b/src/pytorch_lightning/callbacks/lr_finder.py index 4d235751ca791..1c950e64086b9 100644 --- a/src/pytorch_lightning/callbacks/lr_finder.py +++ b/src/pytorch_lightning/callbacks/lr_finder.py @@ -85,7 +85,7 @@ def __init__( max_lr: float = 1, num_training_steps: int = 100, mode: str = "exponential", - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, update_attr: bool = False, ) -> None: mode = mode.lower() diff --git a/src/pytorch_lightning/tuner/lr_finder.py b/src/pytorch_lightning/tuner/lr_finder.py index 888c8a6dee515..846b4abdd8e76 100644 --- a/src/pytorch_lightning/tuner/lr_finder.py +++ b/src/pytorch_lightning/tuner/lr_finder.py @@ -203,7 +203,7 @@ def lr_find( max_lr: float = 1, num_training: int = 100, mode: str = "exponential", - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, update_attr: bool = False, ) -> Optional[_LRFinder]: """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`""" @@ -220,6 +220,8 @@ def lr_find( ckpt_path = trainer.strategy.broadcast(ckpt_path) trainer.save_checkpoint(ckpt_path) + start_steps = trainer.global_step + # Arguments we adjust during the lr finder, save for restoring params = __lr_finder_dump_params(trainer) @@ -240,7 +242,7 @@ def lr_find( _try_loop_run(trainer, params) # Prompt if we stopped early - if trainer.global_step != num_training: + if trainer.global_step != num_training + start_steps: log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.") # Transfer results from callback to lr finder object @@ -285,7 +287,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: } -def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None: +def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None: from pytorch_lightning.loggers.logger import DummyLogger trainer.strategy.lr_scheduler_configs = [] @@ -296,8 +298,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)] # No logging trainer.logger = DummyLogger() if trainer.logger is not None else None - # Max step set to number of iterations - trainer.fit_loop.max_steps = num_training + # Max step set to number of iterations starting at current number of iterations + trainer.fit_loop.max_steps = num_training + trainer.global_step trainer.limit_val_batches = num_training @@ -336,7 +338,7 @@ 
class _LRCallback(Callback): def __init__( self, num_training: int, - early_stop_threshold: float = 4.0, + early_stop_threshold: Optional[float] = 4.0, progress_bar_refresh_rate: int = 0, beta: float = 0.98, ): diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index 01722dca86438..25fdcd35f31f7 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -444,11 +444,17 @@ def test_if_lr_finder_callback_already_configured(): def test_lr_finder_callback_restarting(tmpdir): """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint.""" + num_lr_steps = 100 + class MyBoringModel(BoringModel): def __init__(self): super().__init__() self.learning_rate = 0.123 + def on_train_batch_start(self, batch, batch_idx): + if getattr(self, "_expected_max_steps", None) is not None: + assert self.trainer.fit_loop.max_steps == self._expected_max_steps + def configure_optimizers(self): return torch.optim.SGD(self.parameters(), lr=self.learning_rate) @@ -456,7 +462,9 @@ class CustomLearningRateFinder(LearningRateFinder): milestones = (1,) def lr_find(self, trainer, pl_module) -> None: + pl_module._expected_max_steps = trainer.global_step + self._num_training_steps super().lr_find(trainer, pl_module) + pl_module._expected_max_steps = None assert not trainer.fit_loop.restarting def on_train_epoch_start(self, trainer, pl_module): @@ -467,10 +475,12 @@ def on_train_epoch_start(self, trainer, pl_module): trainer = Trainer( default_root_dir=tmpdir, max_epochs=3, - callbacks=[CustomLearningRateFinder(early_stop_threshold=None, update_attr=True)], + callbacks=[ + CustomLearningRateFinder(early_stop_threshold=None, update_attr=True, num_training_steps=num_lr_steps) + ], limit_train_batches=10, limit_val_batches=0, - limit_test_batches=00, + limit_test_batches=0, num_sanity_val_steps=0, enable_model_summary=False, ) From 24dfb849303de6e63bd107e3b85106e2e85dca72 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 11 Nov 2022 00:21:59 +0100 Subject: [PATCH 4/4] chlog --- src/pytorch_lightning/CHANGELOG.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6765a10c9762f..7395a61c384a5 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -15,11 +15,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) + - Added back the accidentally removed `pytorch_lightning.utilities.distributed.rank_zero_only` function ([#15536](https://github.com/Lightning-AI/lightning/pull/15536)) -- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) -- +- Added support to upgrade all checkpoints in a folder using the `pl.utilities.upgrade_checkpoint` script ([#15333](https://github.com/Lightning-AI/lightning/pull/15333)) - Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301)) @@ -29,15 +29,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237)) + - `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) + - The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300)) -### Fixed - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253)) -- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620)) ### Deprecated @@ -77,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the import of `pytorch_lightning` causing a warning 'Redirects are currently not supported in Windows or MacOs' ([#15610](https://github.com/PyTorchLightning/pytorch-lightning/issues/15610)) +- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620)) + ## [1.8.0] - 2022-11-01
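
Illustrative note (not part of the patches): the sketch below mirrors what this series fixes and tests. `lr_find` saves a temporary checkpoint before the search and restores it afterwards; restoring flips `trainer.fit_loop.restarting` to `True`, which patch 1 resets, and patch 3 additionally offsets `fit_loop.max_steps` by the current `global_step` so a search started mid-training still runs the requested number of LR steps. The names `TinyModel` and `CheckedLRFinder` are hypothetical stand-ins for the test's `BoringModel`-based setup, assuming pytorch_lightning ~1.8 with the `LearningRateFinder` callback; this is a sketch of the behaviour being verified, not part of the patched code.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateFinder


class TinyModel(LightningModule):
    # Hypothetical stand-in for the BoringModel used in the test.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.learning_rate = 0.123  # attribute the finder updates when update_attr=True

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


class CheckedLRFinder(LearningRateFinder):
    """Runs the LR search and checks the flag this series fixes."""

    def lr_find(self, trainer, pl_module):
        super().lr_find(trainer, pl_module)
        # Without the fix, restoring the temporary checkpoint left this flag True,
        # so training after the search behaved as if resuming from a checkpoint.
        assert not trainer.fit_loop.restarting


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        enable_model_summary=False,
        callbacks=[CheckedLRFinder(num_training_steps=20, update_attr=True, early_stop_threshold=None)],
    )
    trainer.fit(TinyModel())

The assertion inside `CheckedLRFinder.lr_find` is the essence of the regression test added in patch 1; patch 3's `start_steps` bookkeeping is what keeps `fit_loop.max_steps` consistent when the same callback is triggered again at a later epoch, as the test's milestone at epoch 1 exercises.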