3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))


- Fixed `LearningRateMonitor` keys not properly setup when running with `BackboneFinetuning` Callback ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835))


## [1.3.2] - 2021-05-18

### Changed
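For context on the `LearningRateMonitor` fix recorded above: `BackboneFinetuning` adds new parameter groups to an already-running optimizer when it unfreezes layers, which changes the names the monitor uses for its logged keys (`'lr-Adam'` becomes `'lr-Adam/pg1'`, `'lr-Adam/pg2'`, ...). A minimal sketch of the underlying PyTorch mechanism, with made-up layers and learning rates:

```python
import torch
from torch import optim

backbone = torch.nn.Linear(4, 4)
head = torch.nn.Linear(4, 2)

# Training starts with a single param group, so the monitor logs one key, e.g. 'lr-Adam'.
opt = optim.Adam(head.parameters(), lr=0.1)
print(len(opt.param_groups))  # 1

# When BackboneFinetuning unfreezes the backbone it adds a new param group,
# so the monitor's keys must switch to the 'lr-Adam/pg1', 'lr-Adam/pg2' form.
opt.add_param_group({'params': backbone.parameters(), 'lr': 0.01})
print(len(opt.param_groups))  # 2
```

Before this PR the monitor's keys were set up only once at the start of training and did not account for the extra groups; the `_remap_keys` logic introduced below reconciles them at each logging step.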
63 changes: 45 additions & 18 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -19,8 +19,10 @@
Monitor and logs learning rate for lr schedulers during training.

"""
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Type

from typing import Dict, List, Optional
from torch.optim.optimizer import Optimizer

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
@@ -53,7 +55,7 @@ class LearningRateMonitor(Callback):
In case of multiple optimizers of same type, they will be named ``Adam``,
``Adam-1`` etc. If a optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schdulers
``name`` keyword in the construction of the learning rate schedulers

Example::

@@ -138,6 +140,9 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}

names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
self._remap_keys(names)

for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler['interval'] == interval or interval == 'any':
opt = scheduler['scheduler'].optimizer
@@ -146,7 +151,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

for i, pg in enumerate(param_groups):
suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
lr = self._extract_lr(pg, f'{name}{suffix}')
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
@@ -155,48 +160,70 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

return latest_stat

def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
lr = param_group.get('lr')
self.lrs[name].append(lr)
return {name: lr}

def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
def _remap_keys(self, names: List[str], token: str = '/pg1') -> None:
"""
This function is used to remap the keys if the number of parameter groups for a given optimizer increased.
"""
for new_name in names:
old_name = new_name.replace(token, '')
if token in new_name and old_name in self.lrs:
self.lrs[new_name] = self.lrs.pop(old_name)
elif new_name not in self.lrs:
self.lrs[new_name] = []

def _extract_momentum(self, param_group: Dict[str, Any], name: str, use_betas: bool) -> Dict[str, float]:
if not self.log_momentum:
return {}

momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
self.last_momentum_values[name] = momentum
return {name: momentum}

def _find_names(self, lr_schedulers) -> List[str]:
# Create uniqe names in the case we have multiple of the same learning
# rate schduler + multiple parameter groups
def _add_prefix(
self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
) -> str:
if optimizer_cls not in seen_optimizer_types:
return name
count = seen_optimizer_types[optimizer_cls]
return name + f'-{count - 1}' if count > 1 else name

def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
# Create unique names in the case we have multiple of the same learning
# rate scheduler + multiple parameter groups
names = []
seen_optimizers = []
seen_optimizer_types = defaultdict(int)
for scheduler in lr_schedulers:
sch = scheduler['scheduler']
if scheduler['name'] is not None:
name = scheduler['name']
else:
opt_name = 'lr-' + sch.optimizer.__class__.__name__
i, name = 1, opt_name
name = 'lr-' + sch.optimizer.__class__.__name__

# Multiple schduler of the same type
while True:
if name not in names:
break
i, name = i + 1, f'{opt_name}-{i}'
seen_optimizers.append(sch.optimizer)
optimizer_cls = type(sch.optimizer)
if scheduler['name'] is None:
seen_optimizer_types[optimizer_cls] += 1

# Multiple param groups for the same schduler
# Multiple param groups for the same scheduler
param_groups = sch.optimizer.param_groups

name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)

if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
for i in range(len(param_groups)):
temp = f'{name}/pg{i + 1}'
names.append(temp)
else:
names.append(name)

self.lr_sch_names.append(name)
if add_lr_sch_names:
self.lr_sch_names.append(name)

return names

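As a standalone illustration of the `_remap_keys` behaviour added above (values are illustrative): once an optimizer gains a second parameter group, the history previously logged under the plain key is moved to the `/pg1` key and an empty history is created for the new group.

```python
from collections import defaultdict
from typing import DefaultDict, List

# lrs maps a monitor key to the history of logged learning rates,
# mirroring LearningRateMonitor.lrs.
lrs: DefaultDict[str, List[float]] = defaultdict(list)
lrs['lr-Adam'].extend([0.1, 0.05])  # logged while the optimizer had one param group


def remap_keys(lrs: DefaultDict[str, List[float]], names: List[str], token: str = '/pg1') -> None:
    # Move history logged under 'lr-Adam' to 'lr-Adam/pg1' once a second
    # param group appears, and create empty histories for brand-new keys.
    for new_name in names:
        old_name = new_name.replace(token, '')
        if token in new_name and old_name in lrs:
            lrs[new_name] = lrs.pop(old_name)
        elif new_name not in lrs:
            lrs[new_name] = []


remap_keys(lrs, ['lr-Adam/pg1', 'lr-Adam/pg2'])
print(dict(lrs))  # {'lr-Adam/pg1': [0.1, 0.05], 'lr-Adam/pg2': []}
```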
102 changes: 102 additions & 0 deletions tests/callbacks/test_lr_monitor.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from torch import optim

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -278,3 +281,102 @@ def configure_optimizers(self):
)
trainer.fit(TestModel())
assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name']


def test_multiple_optimizers_basefinetuning(tmpdir):

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.backbone = torch.nn.Sequential(
torch.nn.Linear(32, 32),
torch.nn.Linear(32, 32),
torch.nn.Linear(32, 32),
torch.nn.ReLU(True),
)
self.layer = torch.nn.Linear(32, 2)

def training_step(self, batch, batch_idx, optimizer_idx):
return super().training_step(batch, batch_idx)

def forward(self, x):
return self.layer(self.backbone(x))

def configure_optimizers(self):
parameters = list(filter(lambda p: p.requires_grad, self.parameters()))
opt = optim.Adam(parameters, lr=0.1)
opt_2 = optim.Adam(parameters, lr=0.1)
opt_3 = optim.Adam(parameters, lr=0.1)
optimizers = [opt, opt_2, opt_3]
schedulers = [
optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5),
optim.lr_scheduler.StepLR(opt_2, step_size=1, gamma=0.5),
]
return optimizers, schedulers

class Check(Callback):

def on_train_epoch_start(self, trainer, pl_module) -> None:
num_param_groups = sum([len(opt.param_groups) for opt in trainer.optimizers])
assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1']
if trainer.current_epoch == 0:
assert num_param_groups == 3
elif trainer.current_epoch == 1:
assert num_param_groups == 4
assert list(lr_monitor.lrs) == ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2']
elif trainer.current_epoch == 2:
assert num_param_groups == 5
assert list(lr_monitor.lrs) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2']
else:
expected = ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2', 'lr-Adam-1/pg3']
assert list(lr_monitor.lrs) == expected

class TestFinetuning(BackboneFinetuning):

def freeze_before_training(self, pl_module):
self.freeze(pl_module.backbone[0])
self.freeze(pl_module.backbone[1])
self.freeze(pl_module.layer)

def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
"""Called when the epoch begins."""
if epoch == 1 and opt_idx == 0:
self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
if epoch == 2 and opt_idx == 1:
self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)

if epoch == 3 and opt_idx == 1:
assert len(optimizer.param_groups) == 2
self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
assert len(optimizer.param_groups) == 3

lr_monitor = LearningRateMonitor()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
limit_val_batches=0,
limit_train_batches=2,
callbacks=[TestFinetuning(), lr_monitor, Check()],
progress_bar_refresh_rate=0,
weights_summary=None,
checkpoint_callback=False
)
model = TestModel()
model.training_epoch_end = None
trainer.fit(model)

expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
assert lr_monitor.lrs['lr-Adam/pg1'] == expected

expected = [0.1, 0.05, 0.025, 0.0125]
assert lr_monitor.lrs['lr-Adam/pg2'] == expected

expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
assert lr_monitor.lrs['lr-Adam-1/pg1'] == expected

expected = [0.1, 0.05, 0.025]
assert lr_monitor.lrs['lr-Adam-1/pg2'] == expected

expected = [0.1, 0.05]
assert lr_monitor.lrs['lr-Adam-1/pg3'] == expected