Commit 9af1dd7

Deprecate lr_sch_names from LearningRateMonitor (#10066)
1 parent: b8ac176

4 files changed: +44 −20 lines

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -445,10 +445,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ClusterEnvironment.creates_children()` in favor of `ClusterEnvironment.creates_processes_externally` (property) ([#10106](https://github.com/PyTorchLightning/pytorch-lightning/pull/10106))
 
 
-
 - Deprecated `PrecisionPlugin.master_params()` in favor of `PrecisionPlugin.main_params()` ([#10105](https://github.com/PyTorchLightning/pytorch-lightning/pull/10105))
 
 
+- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
+
+
 ### Removed
 
 - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 13 additions & 2 deletions
@@ -28,6 +28,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -93,7 +94,7 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool =
         self.logging_interval = logging_interval
         self.log_momentum = log_momentum
         self.lrs: Dict[str, List[float]] = {}
-        self.lr_sch_names: List[str] = []
+        self._lr_sch_names: List[str] = []
 
     def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         """Called before training, determines unique names for all lr schedulers in the case of multiple of the
@@ -334,6 +335,16 @@ def _check_duplicates_and_update_name(
         name_list = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
 
         if add_lr_sch_names:
-            self.lr_sch_names.append(name)
+            self._lr_sch_names.append(name)
 
         return name_list
+
+    @property
+    def lr_sch_names(self) -> List[str]:
+        # TODO remove `lr_sch_names` and `add_lr_sch_names` argument in v1.7.0
+        rank_zero_deprecation(
+            "`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5 and will be removed in 1.7."
+            " Consider accessing them using `LearningRateMonitor.lrs.keys()` which will return"
+            " the names of all the optimizers, even those without a scheduler."
+        )
+        return self._lr_sch_names
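
The property above keeps the old attribute readable but warns on every access; downstream code should switch to the keys of `lrs`, which name every monitored optimizer whether or not it has a scheduler. A minimal migration sketch (the model and the `logging_interval` value are illustrative, not part of this commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

# Assumes `model` is any LightningModule (the tests in this commit use BoringModel).
lr_monitor = LearningRateMonitor(logging_interval="epoch")
trainer = Trainer(callbacks=[lr_monitor], fast_dev_run=True)
trainer.fit(model)

# Deprecated as of this commit (warns, scheduled for removal in v1.7):
names = lr_monitor.lr_sch_names

# Replacement: the keys of `lrs` cover all optimizers, with or without a scheduler.
names = list(lr_monitor.lrs)
```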

tests/callbacks/test_lr_monitor.py

Lines changed: 17 additions & 17 deletions
@@ -41,7 +41,7 @@ def test_lr_monitor_single_lr(tmpdir):
     assert lr_monitor.lrs, "No learning rates logged"
     assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]
+    assert list(lr_monitor.lrs) == ["lr-SGD"]
 
 
 @pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -77,7 +77,7 @@ def configure_optimizers(self):
 
     assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
-    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)
 
 
 def test_log_momentum_no_momentum_optimizer(tmpdir):
@@ -104,7 +104,7 @@ def configure_optimizers(self):
 
     assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
-    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)
 
 
 def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
@@ -127,7 +127,7 @@ def configure_optimizers(self):
 
     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.optimizers)
-    assert lr_monitor.lr_sch_names == ["lr-SGD"]
+    assert list(lr_monitor.lrs) == ["lr-SGD"]
 
 
 @pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -162,7 +162,7 @@ def configure_optimizers(self):
 
     assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
-    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)
 
 
 def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
@@ -188,7 +188,7 @@ def configure_optimizers(self):
 
     assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
-    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)
 
 
 def test_lr_monitor_no_logger(tmpdir):
@@ -238,7 +238,7 @@ def configure_optimizers(self):
 
     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
 
     if logging_interval == "step":
         expected_number_logged = trainer.global_step // log_every_n_steps
@@ -281,7 +281,7 @@ def configure_optimizers(self):
 
     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.optimizers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
 
     if logging_interval == "step":
         expected_number_logged = trainer.global_step // log_every_n_steps
@@ -317,8 +317,7 @@ def configure_optimizers(self):
 
     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam"]
-    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
 
 
 def test_lr_monitor_custom_name(tmpdir):
@@ -339,7 +338,7 @@ def configure_optimizers(self):
         enable_model_summary=False,
     )
     trainer.fit(TestModel())
-    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
+    assert list(lr_monitor.lrs) == ["my_logging_name"]
 
 
 def test_lr_monitor_custom_pg_name(tmpdir):
@@ -360,7 +359,6 @@ def configure_optimizers(self):
         enable_model_summary=False,
     )
     trainer.fit(TestModel())
-    assert lr_monitor.lr_sch_names == ["lr-SGD"]
     assert list(lr_monitor.lrs) == ["lr-SGD/linear"]
 
 
@@ -434,7 +432,7 @@ def configure_optimizers(self):
     class Check(Callback):
         def on_train_epoch_start(self, trainer, pl_module) -> None:
             num_param_groups = sum(len(opt.param_groups) for opt in trainer.optimizers)
-            assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"]
+
             if trainer.current_epoch == 0:
                 assert num_param_groups == 3
             elif trainer.current_epoch == 1:
@@ -512,7 +510,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
     assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected
 
 
-def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir):
+def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
+    """Test that the `LearningRateMonitor` is able to log correct keys with multiple param groups and no
+    lr_scheduler."""
+
     class TestModel(BoringModel):
         def __init__(self, lr, momentum):
             super().__init__()
@@ -550,8 +551,7 @@ def configure_optimizers(self):
     trainer.fit(model)
 
     assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups)
-    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
-    assert lr_monitor.lr_sch_names == ["lr-Adam"]
-    assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
+    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"]
+    assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
     assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
     assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
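
Besides dropping `lr_sch_names`, the test updates above also remove redundant `.keys()` calls: iterating a dict, or passing it to `list()`, already yields its keys. A tiny standalone illustration (the sample values are made up):

```python
lrs = {"lr-Adam/pg1": [0.1], "lr-Adam/pg2": [0.01]}

# list(d) and list(d.keys()) are equivalent, and iteration walks the keys directly.
assert list(lrs) == list(lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
assert all(k.startswith("lr-Adam") for k in lrs)  # no .keys() needed
```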

tests/deprecated_api/test_remove_1-7.py

Lines changed: 11 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 from pytorch_lightning import Callback, LightningDataModule, Trainer
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
 from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
 from tests.callbacks.test_callbacks import OldStatefulCallback
@@ -438,3 +439,13 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
     with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
         trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
+
+
+def test_v1_7_0_deprecate_lr_sch_names(tmpdir):
+    model = BoringModel()
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[lr_monitor])
+    trainer.fit(model)
+
+    with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
+        assert lr_monitor.lr_sch_names == ["lr-SGD"]
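
The new test asserts that reading `lr_sch_names` now emits a deprecation warning that `pytest.deprecated_call` can catch. Code that cannot migrate right away could silence it temporarily; a hedged sketch, assuming the warning category is a `DeprecationWarning` subclass (which is what `pytest.deprecated_call` matches):

```python
import warnings

# Temporary shim only: suppress the deprecation warning while still reading the
# old attribute; prefer migrating to `list(lr_monitor.lrs)` instead.
# `lr_monitor` is a fitted LearningRateMonitor, e.g. from the sketch earlier on this page.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    names = lr_monitor.lr_sch_names
```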
