@@ -41,7 +41,7 @@ def test_lr_monitor_single_lr(tmpdir):
     assert lr_monitor.lrs, "No learning rates logged"
     assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["lr-SGD"]
+    assert list(lr_monitor.lrs) == ["lr-SGD"]


 @pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -77,7 +77,7 @@ def configure_optimizers(self):

     assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
-    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


 def test_log_momentum_no_momentum_optimizer(tmpdir):
@@ -104,7 +104,7 @@ def configure_optimizers(self):

     assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.lr_schedulers)
-    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


 def test_lr_monitor_no_lr_scheduler_single_lr(tmpdir):
@@ -127,7 +127,7 @@ def configure_optimizers(self):

     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.optimizers)
-    assert lr_monitor.lr_sch_names == ["lr-SGD"]
+    assert list(lr_monitor.lrs) == ["lr-SGD"]


 @pytest.mark.parametrize("opt", ["SGD", "Adam"])
@@ -162,7 +162,7 @@ def configure_optimizers(self):

     assert all(v is not None for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
-    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values)


 def test_log_momentum_no_momentum_optimizer_no_lr_scheduler(tmpdir):
@@ -188,7 +188,7 @@ def configure_optimizers(self):

     assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), "Expected momentum to be logged"
     assert len(lr_monitor.last_momentum_values) == len(trainer.optimizers)
-    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys())
+    assert all(k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values)


 def test_lr_monitor_no_logger(tmpdir):
@@ -238,7 +238,7 @@ def configure_optimizers(self):

     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

     if logging_interval == "step":
         expected_number_logged = trainer.global_step // log_every_n_steps
@@ -281,7 +281,7 @@ def configure_optimizers(self):

     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == len(trainer.optimizers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

     if logging_interval == "step":
         expected_number_logged = trainer.global_step // log_every_n_steps
@@ -317,8 +317,7 @@ def configure_optimizers(self):

     assert lr_monitor.lrs, "No learning rates logged"
     assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers)
-    assert lr_monitor.lr_sch_names == ["lr-Adam"]
-    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
+    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"


 def test_lr_monitor_custom_name(tmpdir):
@@ -339,7 +338,7 @@ def configure_optimizers(self):
         enable_model_summary=False,
     )
     trainer.fit(TestModel())
-    assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ["my_logging_name"]
+    assert list(lr_monitor.lrs) == ["my_logging_name"]


 def test_lr_monitor_custom_pg_name(tmpdir):
@@ -360,7 +359,6 @@ def configure_optimizers(self):
         enable_model_summary=False,
     )
     trainer.fit(TestModel())
-    assert lr_monitor.lr_sch_names == ["lr-SGD"]
     assert list(lr_monitor.lrs) == ["lr-SGD/linear"]


@@ -434,7 +432,7 @@ def configure_optimizers(self):
     class Check(Callback):
         def on_train_epoch_start(self, trainer, pl_module) -> None:
             num_param_groups = sum(len(opt.param_groups) for opt in trainer.optimizers)
-            assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"]
+
             if trainer.current_epoch == 0:
                 assert num_param_groups == 3
             elif trainer.current_epoch == 1:
@@ -512,7 +510,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
     assert lr_monitor.lrs["lr-Adam-1/pg3"] == expected


-def test_lr_monitor_multiple_param_groups_no_scheduler(tmpdir):
+def test_lr_monitor_multiple_param_groups_no_lr_scheduler(tmpdir):
+    """Test that the `LearningRateMonitor` is able to log correct keys with multiple param groups and no
+    lr_scheduler."""
+
     class TestModel(BoringModel):
         def __init__(self, lr, momentum):
             super().__init__()
@@ -550,8 +551,7 @@ def configure_optimizers(self):
     trainer.fit(model)

     assert len(lr_monitor.lrs) == len(trainer.optimizers[0].param_groups)
-    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"]
-    assert lr_monitor.lr_sch_names == ["lr-Adam"]
-    assert list(lr_monitor.last_momentum_values.keys()) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
+    assert list(lr_monitor.lrs) == ["lr-Adam/pg1", "lr-Adam/pg2"]
+    assert list(lr_monitor.last_momentum_values) == ["lr-Adam/pg1-momentum", "lr-Adam/pg2-momentum"]
     assert all(val == momentum for val in lr_monitor.last_momentum_values.values())
     assert all(all(val == lr for val in lr_monitor.lrs[lr_key]) for lr_key in lr_monitor.lrs)
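For context, the updated assertions throughout this diff rely on `LearningRateMonitor.lrs` being a dict keyed by the generated optimizer/scheduler names, so those names can be read directly from its keys rather than from the removed `lr_sch_names` attribute. The following is a minimal standalone sketch of that pattern, not part of the change itself; the `BoringModel` import path and the temporary directory are assumptions and may differ between Lightning versions.

import tempfile

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from tests.helpers.boring_model import BoringModel  # assumed helper path; varies by version

# BoringModel's default configure_optimizers returns a single SGD optimizer
# with a StepLR scheduler, so one learning-rate series named "lr-SGD" is tracked.
lr_monitor = LearningRateMonitor()
trainer = Trainer(
    default_root_dir=tempfile.mkdtemp(),
    limit_train_batches=2,
    limit_val_batches=2,
    max_epochs=2,
    callbacks=[lr_monitor],
    enable_model_summary=False,
)
trainer.fit(BoringModel())

# The logged scheduler names are simply the keys of the `lrs` dict,
# which is what the assertions in the updated tests check.
assert list(lr_monitor.lrs) == ["lr-SGD"]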