@@ -33,11 +33,11 @@ class LearningRateMonitor(Callback):
     Automatically monitors and logs learning rate for learning rate schedulers during training.

     Args:
-        logging_interval: set to `epoch` or `step` to log `lr` of all optimizers
-            at the same interval, set to `None` to log at individual interval
-            according to the `interval` key of each scheduler. Defaults to ``None``.
+        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
+            at the same interval, set to ``None`` to log at individual interval
+            according to the ``interval`` key of each scheduler. Defaults to ``None``.
         log_momentum: option to also log the momentum values of the optimizer, if the optimizer
-            has the `momentum` attribute. Defaults to ``False``.
+            has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.

     Example::

@@ -47,17 +47,19 @@ class LearningRateMonitor(Callback):
         >>> trainer = Trainer(callbacks=[lr_monitor])

     Logging names are automatically determined based on optimizer class name.
-    In case of multiple optimizers of same type, they will be named `Adam`,
-    `Adam-1` etc. If a optimizer has multiple parameter groups they will
-    be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
-    `name` keyword in the construction of the learning rate schdulers
+    In case of multiple optimizers of the same type, they will be named ``Adam``,
+    ``Adam-1`` etc. If an optimizer has multiple parameter groups, they will
+    be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
+    ``name`` keyword in the construction of the learning rate schedulers.

     Example::

         def configure_optimizers(self):
             optimizer = torch.optim.Adam(...)
-            lr_scheduler = {'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
-                            'name': 'my_logging_name'}
+            lr_scheduler = {
+                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
+                'name': 'my_logging_name'
+            }
             return [optimizer], [lr_scheduler]

     """
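As an aside, here is a minimal usage sketch of the behaviour the updated docstring describes; ``MyModel`` is a placeholder for any ``LightningModule`` subclass defined elsewhere. With ``log_momentum=True`` and an Adam optimizer, the callback logs the first ``betas`` coefficient as the momentum value.

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor

    # Log both the learning rate and the momentum term on every training step.
    lr_monitor = LearningRateMonitor(logging_interval='step', log_momentum=True)
    trainer = Trainer(callbacks=[lr_monitor])
    trainer.fit(MyModel())  # `MyModel` is a hypothetical LightningModule
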
@@ -80,16 +82,28 @@ def on_train_start(self, trainer, *args, **kwargs):
8082 """
8183 if not trainer .logger :
8284 raise MisconfigurationException (
83- 'Cannot use LearningRateMonitor callback with Trainer that has no logger.'
85+ 'Cannot use ` LearningRateMonitor` callback with ` Trainer` that has no logger.'
8486 )
8587
8688 if not trainer .lr_schedulers :
8789 rank_zero_warn (
88- 'You are using LearningRateMonitor callback with models that'
90+ 'You are using ` LearningRateMonitor` callback with models that'
8991 ' have no learning rate schedulers. Please see documentation'
9092 ' for `configure_optimizers` method.' , RuntimeWarning
9193 )
9294
95+ if self .log_momentum :
96+ def _check_no_key (key ):
97+ return any (
98+ key not in sch ['scheduler' ].optimizer .defaults for sch in trainer .lr_schedulers
99+ )
100+
101+ if _check_no_key ('momentum' ) and _check_no_key ('betas' ):
102+ rank_zero_warn (
103+ "You have set log_momentum=True, but some optimizers do not"
104+ " have momentum. This will log a value 0 for the momentum." , RuntimeWarning
105+ )
106+
93107 # Find names for schedulers
94108 names = self ._find_names (trainer .lr_schedulers )
95109
@@ -105,35 +119,33 @@ def on_train_batch_start(self, trainer, *args, **kwargs):
             interval = 'step' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)

-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

     def on_train_epoch_start(self, trainer, *args, **kwargs):
         if self.logging_interval != 'step':
             interval = 'epoch' if self.logging_interval is None else 'any'
             latest_stat = self._extract_stats(trainer, interval)

-            if trainer.logger is not None and latest_stat:
+            if latest_stat:
                 trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
         latest_stat = {}

         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
             if scheduler['interval'] == interval or interval == 'any':
-                param_groups = scheduler['scheduler'].optimizer.param_groups
-                if len(param_groups) != 1:
-                    for i, pg in enumerate(param_groups):
-                        lr = self._extract_lr(param_group=pg, name=f'{name}/pg{i + 1}')
-                        latest_stat.update(lr)
-                        momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum/pg{i + 1}')
-                        latest_stat.update(momentum)
-
-                else:
-                    pg = param_groups[0]
-                    lr = self._extract_lr(param_group=pg, name=name)
+                opt = scheduler['scheduler'].optimizer
+                param_groups = opt.param_groups
+                use_betas = 'betas' in opt.defaults
+
+                for i, pg in enumerate(param_groups):
+                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
+                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
                     latest_stat.update(lr)
-                    momentum = self._extract_momentum(param_group=pg, name=f'{name}-momentum')
+                    momentum = self._extract_momentum(
+                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
+                    )
                     latest_stat.update(momentum)

         return latest_stat
@@ -143,11 +155,11 @@ def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
         self.lrs[name].append(lr)
         return {name: lr}

-    def _extract_momentum(self, param_group, name: str) -> Dict[str, float]:
+    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
             return {}

-        momentum = param_group.get('momentum')
+        momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
         self.last_momentum_values[name] = momentum
         return {name: momentum}

@@ -190,5 +202,4 @@ def _should_log(trainer) -> bool:
             or trainer.should_stop
         )

-        should_log = should_log and not trainer.fast_dev_run
         return should_log
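For reference, a small standalone sketch (plain PyTorch, no Lightning required) of the lookup the revised ``_extract_momentum`` relies on: Adam-style optimizers expose their momentum term as the first element of ``betas``, SGD exposes a ``momentum`` key, and optimizers with neither fall back to 0. This is an illustrative snippet, not the callback itself.

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    adam = torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))
    sgd = torch.optim.SGD(params, lr=1e-2, momentum=0.9)

    for opt in (adam, sgd):
        # Same check the callback performs: Adam's defaults contain 'betas'.
        use_betas = 'betas' in opt.defaults
        for pg in opt.param_groups:
            momentum = pg.get('betas')[0] if use_betas else pg.get('momentum', 0)
            print(type(opt).__name__, momentum)  # -> Adam 0.9, SGD 0.9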