3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))


- Added `LightningModule.lr_schedulers()` for manual optimization ([#6567](https://github.com/PyTorchLightning/pytorch-lightning/pull/6567))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
14 changes: 14 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -119,6 +119,20 @@ def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Opt
        # multiple opts
        return opts

    def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]:
        if not self.trainer.lr_schedulers:
            return None

        # ignore other keys "interval", "frequency", etc.
        lr_schedulers = [s["scheduler"] for s in self.trainer.lr_schedulers]
Comment on lines +126 to +127

@akihironitta (Contributor, Author) commented on Mar 17, 2021:

self.lr_schedulers() is intended to be used in manual optimization, so even when dict keys like "interval" and "monitor" are defined in configure_optimizers(), this line ignores all of the keys except "scheduler" (a usage sketch follows this file's diff). Related docs: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#learning-rate-scheduling


        # single scheduler
        if len(lr_schedulers) == 1:
            return lr_schedulers[0]

        # multiple schedulers
        return lr_schedulers

    @property
    def example_input_array(self) -> Any:
        return self._example_input_array
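A minimal usage sketch of the accessor described in the review comment above, assuming a single optimizer/scheduler pair and Lightning's existing manual-optimization API (self.optimizers(), self.manual_backward()); SketchModel and its toy linear layer are illustrative only, not part of this PR:

import torch
import pytorch_lightning as pl


class SketchModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # extra keys such as "interval" are ignored by lr_schedulers(),
        # which hands back only the "scheduler" entry
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()  # the bare StepLR instance, since only one is configured

        loss = self(batch).sum()  # toy loss for illustration
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        sch.step()  # in manual optimization, the user decides when to step the scheduler

With a single scheduler the accessor returns the scheduler object itself; with several, it returns them as a list in the order defined by configure_optimizers().
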
38 changes: 38 additions & 0 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -1147,3 +1147,41 @@ def dis_closure():
@RunIf(min_gpus=2, special=True)
def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)


def test_lr_schedulers(tmpdir):
    """
    Test that `lr_schedulers()` returns the same scheduler objects,
    in the same order, as returned by `configure_optimizers()`.
    """

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            scheduler_1, scheduler_2 = self.lr_schedulers()
            assert scheduler_1 is self.scheduler_1
            assert scheduler_2 is self.scheduler_2

        def configure_optimizers(self):
            optimizer_1 = torch.optim.SGD(self.parameters(), lr=0.1)
            optimizer_2 = torch.optim.SGD(self.parameters(), lr=0.1)
            self.scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
            self.scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
            return [optimizer_1, optimizer_2], [self.scheduler_1, self.scheduler_2]

    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
    )

    trainer.fit(model)
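
The test above only verifies object identity and ordering. As a complement, here is a hedged sketch of a hypothetical variant that also performs the manual updates; it reuses the test file's existing imports (torch, BoringModel) and assumes BoringModel's forward returns a tensor whose sum can serve as a toy loss:

class SteppingModel(BoringModel):
    """Hypothetical variant of TestModel that also steps its optimizers and schedulers."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def configure_optimizers(self):
        optimizer_1 = torch.optim.SGD(self.parameters(), lr=0.1)
        optimizer_2 = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizer_1, step_size=1)
        scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
        return [optimizer_1, optimizer_2], [scheduler_1, scheduler_2]

    def training_step(self, batch, batch_idx):
        opt_1, opt_2 = self.optimizers()
        sch_1, sch_2 = self.lr_schedulers()  # same order as configure_optimizers()

        loss = self(batch).sum()
        opt_1.zero_grad()
        opt_2.zero_grad()
        self.manual_backward(loss)
        opt_1.step()
        opt_2.step()

        # Lightning does not step schedulers automatically in manual optimization
        sch_1.step()
        sch_2.step()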