From 316b84d49f706a502e126053a84701c6348f318b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 4 Jun 2021 19:28:15 +0100 Subject: [PATCH 01/11] add test + resolve bug --- pytorch_lightning/callbacks/lr_monitor.py | 63 ++++++++++++++++------- tests/callbacks/test_lr_monitor.py | 54 +++++++++++++++++++ 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 410f8b319c239..5d2d5eae046aa 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -19,7 +19,7 @@ Monitor and logs learning rate for lr schedulers during training. """ - +from collections import defaultdict from typing import Dict, List, Optional from pytorch_lightning.callbacks.base import Callback @@ -53,7 +53,7 @@ class LearningRateMonitor(Callback): In case of multiple optimizers of same type, they will be named ``Adam``, ``Adam-1`` etc. If a optimizer has multiple parameter groups they will be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a - ``name`` keyword in the construction of the learning rate schdulers + ``name`` keyword in the construction of the learning rate schedulers Example:: @@ -146,7 +146,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: for i, pg in enumerate(param_groups): suffix = f'/pg{i + 1}' if len(param_groups) > 1 else '' - lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}') + lr = self._extract_lr(trainer, param_group=pg, name=f'{name}{suffix}') latest_stat.update(lr) momentum = self._extract_momentum( param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas @@ -155,11 +155,26 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: return latest_stat - def _extract_lr(self, param_group, name: str) -> Dict[str, float]: + def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: lr = param_group.get('lr') - self.lrs[name].append(lr) + try: + self.lrs[name].append(lr) + except KeyError: + names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False) + self._remap_keys(names) + self.lrs[name].append(lr) return {name: lr} + def _remap_keys(self, names: List[str]) -> None: + token = '/pg1' + for new_name in names: + if token in new_name: + old_n = new_name.replace(token, '') + self.lrs[new_name] = self.lrs[old_n] + del self.lrs[old_n] + else: + self.lrs[new_name] = [] + def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]: if not self.log_momentum: return {} @@ -168,35 +183,45 @@ def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str self.last_momentum_values[name] = momentum return {name: momentum} - def _find_names(self, lr_schedulers) -> List[str]: - # Create uniqe names in the case we have multiple of the same learning - # rate schduler + multiple parameter groups + def _add_prefix(self, name, optimizer_cls, seen_optimizer_types) -> str: + count = seen_optimizer_types[optimizer_cls] + return name + f'-{count}' if count else name + + def _find_names(self, lr_schedulers, add_lr_sch_names: bool = True) -> List[str]: + # Create unique names in the case we have multiple of the same learning + # rate scheduler + multiple parameter groups names = [] + seen_optimizers = [] + seen_optimizer_types = defaultdict(int) for scheduler in lr_schedulers: sch = scheduler['scheduler'] if scheduler['name'] is not None: name = scheduler['name'] else: - opt_name = 'lr-' + sch.optimizer.__class__.__name__ - i, name 
= 1, opt_name + name = 'lr-' + sch.optimizer.__class__.__name__ - # Multiple schduler of the same type - while True: - if name not in names: - break - i, name = i + 1, f'{opt_name}-{i}' + seen_optimizers.append(sch.optimizer) + optimizer_cls = type(sch.optimizer) + if scheduler['name'] is None: + if optimizer_cls not in seen_optimizer_types: + seen_optimizer_types[optimizer_cls] = 0 + else: + seen_optimizer_types[optimizer_cls] += 1 - # Multiple param groups for the same schduler + # Multiple param groups for the same scheduler param_groups = sch.optimizer.param_groups if len(param_groups) != 1: - for i, pg in enumerate(param_groups): - temp = f'{name}/pg{i + 1}' + for i in range(len(param_groups)): + temp = self._add_prefix(name, optimizer_cls, seen_optimizer_types) + temp = f'{temp}/pg{i + 1}' names.append(temp) else: + name = self._add_prefix(name, optimizer_cls, seen_optimizer_types) names.append(name) - self.lr_sch_names.append(name) + if add_lr_sch_names: + self.lr_sch_names.append(name) return names diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index bea6c45e95ced..ac3939fc362ce 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest +import torch from torch import optim import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.finetuning import BackboneFinetuning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -278,3 +281,54 @@ def configure_optimizers(self): ) trainer.fit(TestModel()) assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name'] + + +def test_multiple_optimizers_basefinetuning(tmpdir): + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(True)) + self.layer = torch.nn.Linear(32, 2) + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def forward(self, x): + return self.layer(self.backbone(x)) + + def configure_optimizers(self): + opt = optim.Adam(self.layer.parameters(), lr=0.1) + opt_2 = optim.Adam(self.layer.parameters(), lr=0.1) + opt_3 = optim.Adam(self.layer.parameters(), lr=0.1) + return [opt, opt_2, opt_3 + ], [optim.lr_scheduler.StepLR(opt, step_size=1), + optim.lr_scheduler.StepLR(opt_2, step_size=1)] + + class Check(Callback): + + def on_epoch_end(self, trainer, pl_module) -> None: + assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1'] + if trainer.current_epoch > 2: + assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2'] + else: + assert list(lr_monitor.lrs.keys()) == ['lr-Adam', 'lr-Adam-1'] + + lr_monitor = LearningRateMonitor() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=4, + limit_val_batches=0, + limit_train_batches=2, + callbacks=[lr_monitor, BackboneFinetuning(unfreeze_backbone_at_epoch=2), + Check()], + progress_bar_refresh_rate=0, + weights_summary=None, + ) + model = TestModel() + model.training_epoch_end = None + trainer.fit(model) + + # 3 epoch difference + assert 
len(lr_monitor.lrs["lr-Adam/pg1"]) == len(lr_monitor.lrs["lr-Adam/pg2"]) + 3 From 485120e2e0c672e26e15e943ec876b7b8a89efb4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 4 Jun 2021 19:30:42 +0100 Subject: [PATCH 02/11] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f6d4692ec076..5fc69fb34a46e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825)) +- Fixed `LearningRateMonitor` keys not properly setup when running with `BackboneFinetuning` Callback ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835)) + + ## [1.3.2] - 2021-05-18 ### Changed From a7e70910a111300f80edd5f4f07c6816234350ab Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 4 Jun 2021 20:44:01 +0100 Subject: [PATCH 03/11] resolve bug --- pytorch_lightning/callbacks/lr_monitor.py | 24 +++++-- tests/callbacks/test_lr_monitor.py | 78 ++++++++++++++++++----- 2 files changed, 80 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 5d2d5eae046aa..f3f0903279ca2 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -19,7 +19,6 @@ Monitor and logs learning rate for lr schedulers during training. """ -from collections import defaultdict from typing import Dict, List, Optional from pytorch_lightning.callbacks.base import Callback @@ -138,6 +137,9 @@ def on_train_epoch_start(self, trainer, *args, **kwargs): def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: latest_stat = {} + names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False) + self._remap_keys(names) + for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers): if scheduler['interval'] == interval or interval == 'any': opt = scheduler['scheduler'].optimizer @@ -153,13 +155,19 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: ) latest_stat.update(momentum) + print() + print(self.lrs) + print() + return latest_stat def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: lr = param_group.get('lr') - try: + print(trainer.current_epoch, name, lr) + if name in self.lrs: self.lrs[name].append(lr) - except KeyError: + else: + # new params groups have been added and we need to refresh the names. 
names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False) self._remap_keys(names) self.lrs[name].append(lr) @@ -170,10 +178,12 @@ def _remap_keys(self, names: List[str]) -> None: for new_name in names: if token in new_name: old_n = new_name.replace(token, '') - self.lrs[new_name] = self.lrs[old_n] - del self.lrs[old_n] + if old_n in self.lrs: + self.lrs[new_name] = self.lrs[old_n] + del self.lrs[old_n] else: - self.lrs[new_name] = [] + if new_name not in self.lrs: + self.lrs[new_name] = [] def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]: if not self.log_momentum: @@ -192,7 +202,7 @@ def _find_names(self, lr_schedulers, add_lr_sch_names: bool = True) -> List[str] # rate scheduler + multiple parameter groups names = [] seen_optimizers = [] - seen_optimizer_types = defaultdict(int) + seen_optimizer_types = {} for scheduler in lr_schedulers: sch = scheduler['scheduler'] if scheduler['name'] is not None: diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index ac3939fc362ce..66504c254158c 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -289,7 +289,12 @@ class TestModel(BoringModel): def __init__(self): super().__init__() - self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(True)) + self.backbone = torch.nn.Sequential( + torch.nn.Linear(32, 32), + torch.nn.Linear(32, 32), + torch.nn.Linear(32, 32), + torch.nn.ReLU(True), + ) self.layer = torch.nn.Linear(32, 2) def training_step(self, batch, batch_idx, optimizer_idx): @@ -299,36 +304,79 @@ def forward(self, x): return self.layer(self.backbone(x)) def configure_optimizers(self): - opt = optim.Adam(self.layer.parameters(), lr=0.1) - opt_2 = optim.Adam(self.layer.parameters(), lr=0.1) - opt_3 = optim.Adam(self.layer.parameters(), lr=0.1) - return [opt, opt_2, opt_3 - ], [optim.lr_scheduler.StepLR(opt, step_size=1), - optim.lr_scheduler.StepLR(opt_2, step_size=1)] + parameters = list(filter(lambda p: p.requires_grad, self.parameters())) + opt = optim.Adam(parameters, lr=0.1) + opt_2 = optim.Adam(parameters, lr=0.1) + opt_3 = optim.Adam(parameters, lr=0.1) + optimizers = [opt, opt_2, opt_3] + schedulers = [ + optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.9), + optim.lr_scheduler.StepLR(opt_2, step_size=1, gamma=0.9), + ] + return optimizers, schedulers class Check(Callback): - def on_epoch_end(self, trainer, pl_module) -> None: + def on_train_epoch_start(self, trainer, pl_module) -> None: + num_param_groups = sum([len(opt.param_groups) for opt in trainer.optimizers]) assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1'] - if trainer.current_epoch > 2: + if trainer.current_epoch == 0: + assert num_param_groups == 3 + elif trainer.current_epoch == 1: + assert num_param_groups == 4 + assert list(lr_monitor.lrs.keys()) == ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2'] + elif trainer.current_epoch == 2: + assert num_param_groups == 5 assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2'] else: - assert list(lr_monitor.lrs.keys()) == ['lr-Adam', 'lr-Adam-1'] + expected = ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2', 'lr-Adam-1/pg3'] + assert list(lr_monitor.lrs.keys()) == expected + + class TestFinetuning(BackboneFinetuning): + + def freeze_before_training(self, pl_module): + self.freeze(pl_module.backbone[0]) + self.freeze(pl_module.backbone[1]) + self.freeze(pl_module.layer) + + def finetune_function(self, pl_module, epoch: 
int, optimizer, opt_idx: int): + """Called when the epoch begins.""" + if epoch == 1 and opt_idx == 0: + self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1) + if epoch == 2 and opt_idx == 1: + self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1) + + if epoch == 3 and opt_idx == 1: + assert len(optimizer.param_groups) == 2 + self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1) + assert len(optimizer.param_groups) == 3 lr_monitor = LearningRateMonitor() trainer = Trainer( default_root_dir=tmpdir, - max_epochs=4, + max_epochs=5, limit_val_batches=0, limit_train_batches=2, - callbacks=[lr_monitor, BackboneFinetuning(unfreeze_backbone_at_epoch=2), - Check()], + callbacks=[TestFinetuning(), lr_monitor, Check()], progress_bar_refresh_rate=0, weights_summary=None, + checkpoint_callback=False ) model = TestModel() model.training_epoch_end = None trainer.fit(model) - # 3 epoch difference - assert len(lr_monitor.lrs["lr-Adam/pg1"]) == len(lr_monitor.lrs["lr-Adam/pg2"]) + 3 + expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002, 0.06561000000000002] + assert lr_monitor.lrs['lr-Adam/pg1'] == expected + + expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002] + assert lr_monitor.lrs['lr-Adam/pg2'] == expected + + expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002, 0.06561000000000002] + assert lr_monitor.lrs['lr-Adam-1/pg1'] == expected + + expected = [0.1, 0.09000000000000001, 0.08100000000000002] + assert lr_monitor.lrs['lr-Adam-1/pg2'] == expected + + expected = [0.1, 0.09000000000000001] + assert lr_monitor.lrs['lr-Adam-1/pg3'] == expected From 37b11c2ff2c52550954e76f9b028a101355e9fe4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 4 Jun 2021 20:47:23 +0100 Subject: [PATCH 04/11] resolve bug --- pytorch_lightning/callbacks/lr_monitor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index f3f0903279ca2..72b67d954f51a 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -194,6 +194,8 @@ def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str return {name: momentum} def _add_prefix(self, name, optimizer_cls, seen_optimizer_types) -> str: + if optimizer_cls not in seen_optimizer_types: + return name count = seen_optimizer_types[optimizer_cls] return name + f'-{count}' if count else name From c8d56d56a7a3909ee15b4cd2873debf4b96ff3b3 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 4 Jun 2021 22:02:18 +0100 Subject: [PATCH 05/11] Update pytorch_lightning/callbacks/lr_monitor.py Co-authored-by: Jirka Borovec --- pytorch_lightning/callbacks/lr_monitor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 72b67d954f51a..1522164220361 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -163,7 +163,6 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: lr = param_group.get('lr') - print(trainer.current_epoch, name, lr) if name in self.lrs: self.lrs[name].append(lr) else: From 7b527f6d23ce6da142444ec2204dd0a99aec3c61 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 4 Jun 2021 22:02:30 +0100 Subject: [PATCH 06/11] Update 
pytorch_lightning/callbacks/lr_monitor.py Co-authored-by: Jirka Borovec --- pytorch_lightning/callbacks/lr_monitor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 1522164220361..08abf8cf3b945 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -155,10 +155,6 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: ) latest_stat.update(momentum) - print() - print(self.lrs) - print() - return latest_stat def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: From 0d2cfb54cc73ade5fb3fd633bdbf968912616c9f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 4 Jun 2021 22:05:32 +0100 Subject: [PATCH 07/11] update on comments --- pytorch_lightning/callbacks/lr_monitor.py | 8 +------- tests/callbacks/test_lr_monitor.py | 14 +++++++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 08abf8cf3b945..f0d5ba319c8f3 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -159,13 +159,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: lr = param_group.get('lr') - if name in self.lrs: - self.lrs[name].append(lr) - else: - # new params groups have been added and we need to refresh the names. - names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False) - self._remap_keys(names) - self.lrs[name].append(lr) + self.lrs[name].append(lr) return {name: lr} def _remap_keys(self, names: List[str]) -> None: diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 66504c254158c..385309e5c733d 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -310,8 +310,8 @@ def configure_optimizers(self): opt_3 = optim.Adam(parameters, lr=0.1) optimizers = [opt, opt_2, opt_3] schedulers = [ - optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.9), - optim.lr_scheduler.StepLR(opt_2, step_size=1, gamma=0.9), + optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5), + optim.lr_scheduler.StepLR(opt_2, step_size=1, gamma=0.5), ] return optimizers, schedulers @@ -366,17 +366,17 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): model.training_epoch_end = None trainer.fit(model) - expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002, 0.06561000000000002] + expected = [0.1, 0.05, 0.025, 0.0125, 0.00625] assert lr_monitor.lrs['lr-Adam/pg1'] == expected - expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002] + expected = [0.1, 0.05, 0.025, 0.0125] assert lr_monitor.lrs['lr-Adam/pg2'] == expected - expected = [0.1, 0.09000000000000001, 0.08100000000000002, 0.07290000000000002, 0.06561000000000002] + expected = [0.1, 0.05, 0.025, 0.0125, 0.00625] assert lr_monitor.lrs['lr-Adam-1/pg1'] == expected - expected = [0.1, 0.09000000000000001, 0.08100000000000002] + expected = [0.1, 0.05, 0.025] assert lr_monitor.lrs['lr-Adam-1/pg2'] == expected - expected = [0.1, 0.09000000000000001] + expected = [0.1, 0.05] assert lr_monitor.lrs['lr-Adam-1/pg3'] == expected From b31024ba3b45581660b472805dd3f63bc4325646 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 08:22:54 +0100 Subject: [PATCH 08/11] resolve comments --- 
pytorch_lightning/callbacks/lr_monitor.py | 54 ++++++++++++----------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index f0d5ba319c8f3..5e0cd7d56f7bf 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -19,7 +19,10 @@ Monitor and logs learning rate for lr schedulers during training. """ -from typing import Dict, List, Optional +from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Optional, Type + +from torch.optim.optimizer import Optimizer from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn @@ -148,7 +151,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: for i, pg in enumerate(param_groups): suffix = f'/pg{i + 1}' if len(param_groups) > 1 else '' - lr = self._extract_lr(trainer, param_group=pg, name=f'{name}{suffix}') + lr = self._extract_lr(pg, name=f'{name}{suffix}') latest_stat.update(lr) momentum = self._extract_momentum( param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas @@ -157,24 +160,24 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: return latest_stat - def _extract_lr(self, trainer, param_group, name: str) -> Dict[str, float]: + def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: lr = param_group.get('lr') self.lrs[name].append(lr) return {name: lr} - def _remap_keys(self, names: List[str]) -> None: - token = '/pg1' + def _remap_keys(self, names: List[str], token: str = '/pg1') -> None: + """ + This function is used the remap the keys if param groups for a given optimizer increased. + """ for new_name in names: - if token in new_name: - old_n = new_name.replace(token, '') - if old_n in self.lrs: - self.lrs[new_name] = self.lrs[old_n] - del self.lrs[old_n] - else: - if new_name not in self.lrs: - self.lrs[new_name] = [] - - def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]: + old_n = new_name.replace(token, '') + if token in new_name and old_n in self.lrs: + self.lrs[new_name] = self.lrs[old_n] + del self.lrs[old_n] + elif new_name not in self.lrs: + self.lrs[new_name] = [] + + def _extract_momentum(self, param_group: Dict[str, Any], name: str, use_betas: bool) -> Dict[str, float]: if not self.log_momentum: return {} @@ -182,18 +185,20 @@ def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str self.last_momentum_values[name] = momentum return {name: momentum} - def _add_prefix(self, name, optimizer_cls, seen_optimizer_types) -> str: + def _add_prefix( + self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] + ) -> str: if optimizer_cls not in seen_optimizer_types: return name count = seen_optimizer_types[optimizer_cls] - return name + f'-{count}' if count else name + return name + f'-{count - 1}' if count > 1 else name - def _find_names(self, lr_schedulers, add_lr_sch_names: bool = True) -> List[str]: + def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] seen_optimizers = [] - seen_optimizer_types = {} + seen_optimizer_types = defaultdict(int) for scheduler in lr_schedulers: sch = scheduler['scheduler'] if scheduler['name'] is not None: @@ -204,21 +209,18 @@ def 
_find_names(self, lr_schedulers, add_lr_sch_names: bool = True) -> List[str] seen_optimizers.append(sch.optimizer) optimizer_cls = type(sch.optimizer) if scheduler['name'] is None: - if optimizer_cls not in seen_optimizer_types: - seen_optimizer_types[optimizer_cls] = 0 - else: - seen_optimizer_types[optimizer_cls] += 1 + seen_optimizer_types[optimizer_cls] += 1 # Multiple param groups for the same scheduler param_groups = sch.optimizer.param_groups + name = self._add_prefix(name, optimizer_cls, seen_optimizer_types) + if len(param_groups) != 1: for i in range(len(param_groups)): - temp = self._add_prefix(name, optimizer_cls, seen_optimizer_types) - temp = f'{temp}/pg{i + 1}' + temp = f'{name}/pg{i + 1}' names.append(temp) else: - name = self._add_prefix(name, optimizer_cls, seen_optimizer_types) names.append(name) if add_lr_sch_names: From cbcf86fd59ecb03724c991a5be64944cf8a4c93a Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 7 Jun 2021 08:24:37 +0100 Subject: [PATCH 09/11] update --- pytorch_lightning/callbacks/lr_monitor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 5e0cd7d56f7bf..62b3df1b07118 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -170,10 +170,9 @@ def _remap_keys(self, names: List[str], token: str = '/pg1') -> None: This function is used the remap the keys if param groups for a given optimizer increased. """ for new_name in names: - old_n = new_name.replace(token, '') - if token in new_name and old_n in self.lrs: - self.lrs[new_name] = self.lrs[old_n] - del self.lrs[old_n] + old_name = new_name.replace(token, '') + if token in new_name and old_name in self.lrs: + self.lrs[new_name] = self.lrs.pop(old_name) elif new_name not in self.lrs: self.lrs[new_name] = [] From af8e171e4b630a32d24d8e4aea9e0b5cfe15a60a Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 7 Jun 2021 10:54:38 +0100 Subject: [PATCH 10/11] Update tests/callbacks/test_lr_monitor.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- tests/callbacks/test_lr_monitor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 385309e5c733d..808165d61b053 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -324,13 +324,13 @@ def on_train_epoch_start(self, trainer, pl_module) -> None: assert num_param_groups == 3 elif trainer.current_epoch == 1: assert num_param_groups == 4 - assert list(lr_monitor.lrs.keys()) == ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2'] + assert list(lr_monitor.lrs) == ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2'] elif trainer.current_epoch == 2: assert num_param_groups == 5 - assert list(lr_monitor.lrs.keys()) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2'] + assert list(lr_monitor.lrs) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2'] else: expected = ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2', 'lr-Adam-1/pg3'] - assert list(lr_monitor.lrs.keys()) == expected + assert list(lr_monitor.lrs) == expected class TestFinetuning(BackboneFinetuning): From 472ba3b1862070076c023077134697ccfbb28d70 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 7 Jun 2021 10:54:46 +0100 Subject: [PATCH 11/11] Update pytorch_lightning/callbacks/lr_monitor.py MIME-Version: 
1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/lr_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index 62b3df1b07118..5a8e8be5138dd 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -151,7 +151,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]: for i, pg in enumerate(param_groups): suffix = f'/pg{i + 1}' if len(param_groups) > 1 else '' - lr = self._extract_lr(pg, name=f'{name}{suffix}') + lr = self._extract_lr(pg, f'{name}{suffix}') latest_stat.update(lr) momentum = self._extract_momentum( param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
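
To summarize the behaviour this series settles on: `_find_names` counts how many optimizers of each class it has seen and disambiguates repeated classes with a `-<n>` suffix, while `_remap_keys` migrates an existing history such as `lr-Adam` to `lr-Adam/pg1` once a finetuning callback adds a second parameter group to that optimizer. The standalone sketch below is plain Python, not the Lightning code itself: `remap_keys` mirrors the patched `_remap_keys`, while `names_for` is a simplified, hypothetical stand-in for `_find_names` (a single optimizer class, no scheduler `name` override). It reproduces the key evolution that `test_multiple_optimizers_basefinetuning` asserts for the first two epochs.

    from collections import defaultdict
    from typing import Dict, List

    def remap_keys(lrs: Dict[str, List[float]], names: List[str], token: str = '/pg1') -> None:
        # Mirrors the patched ``_remap_keys``: when an optimizer gains a second param
        # group, the old single-key history (e.g. 'lr-Adam') is moved under the new
        # per-group key (e.g. 'lr-Adam/pg1'); brand-new keys start with an empty list.
        for new_name in names:
            old_name = new_name.replace(token, '')
            if token in new_name and old_name in lrs:
                lrs[new_name] = lrs.pop(old_name)
            elif new_name not in lrs:
                lrs[new_name] = []

    def names_for(param_group_counts: List[int], base: str = 'lr-Adam') -> List[str]:
        # Simplified stand-in for ``_find_names``: every optimizer here is assumed to be
        # the same class, so the second one gets a '-1' suffix; optimizers with more than
        # one param group get '/pgN' suffixes.
        names = []
        for i, n_groups in enumerate(param_group_counts):
            name = base if i == 0 else f'{base}-{i}'
            if n_groups == 1:
                names.append(name)
            else:
                names.extend(f'{name}/pg{g + 1}' for g in range(n_groups))
        return names

    lrs: Dict[str, List[float]] = defaultdict(list)

    # Epoch 0: two scheduled Adam optimizers, one param group each.
    remap_keys(lrs, names_for([1, 1]))
    for key in names_for([1, 1]):
        lrs[key].append(0.1)
    print(list(lrs))           # ['lr-Adam', 'lr-Adam-1']

    # Epoch 1: finetuning unfreezes a layer, adding a param group to the first optimizer.
    remap_keys(lrs, names_for([2, 1]))
    print(list(lrs))           # ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2']
    print(lrs['lr-Adam/pg1'])  # [0.1]  -- history carried over from 'lr-Adam'

Because a parameter group only starts being logged from the epoch in which it is created, later groups accumulate shorter histories, which is why the test expects five values for `lr-Adam/pg1` but only two for `lr-Adam-1/pg3`.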