From 204311ead8726cf80581b32e77928f4e15c9f020 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 17 Mar 2021 22:37:40 +0900 Subject: [PATCH 1/4] Add test for lr_schedulers() --- .../optimization/test_manual_optimization.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 8ad603a7677ea..38eff914fa15a 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -1147,3 +1147,41 @@ def dis_closure(): @RunIf(min_gpus=2, special=True) def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir): train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel) + + +def test_lr_schedulers(tmpdir): + """ + Test `lr_schedulers()` return the same objects + in the correct order as defined in `configure_optimizers()`. + """ + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + scheduler_1, scheduler_2 = self.lr_schedulers() + assert scheduler_1 is self.scheduler_1 + assert scheduler_2 is self.scheduler_2 + + def configure_optimizers(self): + optimizer_1 = torch.optim.SGD(self.parameters(), lr=0.1) + optimizer_2 = torch.optim.SGD(self.parameters(), lr=0.1) + self.scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1) + self.scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1) + return [optimizer_1, optimizer_2], [self.scheduler_1, self.scheduler_2] + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + ) + + trainer.fit(model) From 576748724a0c2c187d31428d7d424e28efad3a95 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 17 Mar 2021 22:38:13 +0900 Subject: [PATCH 2/4] Add lr_schedulers to LightningModule --- pytorch_lightning/core/lightning.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7efe88515b37e..637e9159e4fc1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -119,6 +119,20 @@ def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Opt # multiple opts return opts + def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: + if not self.trainer.lr_schedulers: + return None + + # ignore other keys "interval", "frequency", etc. + lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers] + + # single scheduler + if len(lr_schedulers) == 1: + return lr_schedulers[0] + + # multiple schedulers + return lr_schedulers + @property def example_input_array(self) -> Any: return self._example_input_array From b72477b80d705984ca6381a6c6a02bba9c4df910 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 5 Apr 2021 05:04:52 +0900 Subject: [PATCH 3/4] Update test comment --- tests/trainer/optimization/test_manual_optimization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 38eff914fa15a..70db6208164aa 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -1151,8 +1151,8 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_ def test_lr_schedulers(tmpdir): """ - Test `lr_schedulers()` return the same objects - in the correct order as defined in `configure_optimizers()`. + Test `lr_schedulers()` returns the same objects + in the same order as `configure_optimizers()` returns. """ class TestModel(BoringModel): From 07d25e91122f08c2287b2037019f865d766037e0 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 5 Apr 2021 05:34:09 +0900 Subject: [PATCH 4/4] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81846809fbf85..c4672336c71b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764)) +- Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))