 Monitors and logs the learning rate for lr schedulers during training.

 """
+import itertools
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type

@@ -123,7 +124,7 @@ def _check_no_key(key: str) -> bool:
         )

         # Find names for schedulers
-        names: List[str] = []
+        names: List[List[str]] = []
         (
             sched_hparam_keys,
             optimizers_with_scheduler,
@@ -140,8 +141,9 @@ def _check_no_key(key: str) -> bool:
         names.extend(optimizer_hparam_keys)

         # Initialize for storing values
-        self.lrs = {name: [] for name in names}
-        self.last_momentum_values = {name + "-momentum": None for name in names}
+        names_flatten = list(itertools.chain.from_iterable(names))
+        self.lrs = {name: [] for name in names_flatten}
+        self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if not trainer.logger_connector.should_update_logs:
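To make the change above concrete: `names` is now grouped per optimizer, and the groups are flattened only when building the flat stat dictionaries. A minimal sketch, assuming one optimizer with two param groups and one with a single group (all key names here are hypothetical):

    import itertools

    # One inner list of per-param-group keys for each optimizer.
    names = [["lr-Adam/pg1", "lr-Adam/pg2"], ["lr-SGD"]]

    # chain.from_iterable flattens the groups for the flat stat dicts.
    names_flatten = list(itertools.chain.from_iterable(names))
    assert names_flatten == ["lr-Adam/pg1", "lr-Adam/pg2", "lr-SGD"]

    lrs = {name: [] for name in names_flatten}
    last_momentum_values = {name + "-momentum": None for name in names_flatten}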
@@ -172,7 +174,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa
         ) = self._find_names_from_schedulers(trainer.lr_schedulers, add_lr_sch_names=False)
         self._remap_keys(scheduler_hparam_keys)

-        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
+        for name, scheduler in zip(scheduler_hparam_keys, trainer.lr_schedulers):
             if interval in [scheduler["interval"], "any"]:
                 opt = scheduler["scheduler"].optimizer
                 current_stat = self._get_lr_momentum_stat(opt, name)
@@ -186,23 +188,22 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa
         )
         self._remap_keys(optimizer_hparam_keys)

-        for opt, name in zip(optimizers_without_scheduler, optimizer_hparam_keys):
-            current_stat = self._get_lr_momentum_stat(opt, name)
+        for opt, names in zip(optimizers_without_scheduler, optimizer_hparam_keys):
+            current_stat = self._get_lr_momentum_stat(opt, names)
             latest_stat.update(current_stat)

         return latest_stat

-    def _get_lr_momentum_stat(self, optimizer: Optimizer, name: str) -> Dict[str, float]:
+    def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]:
         lr_momentum_stat = {}
         param_groups = optimizer.param_groups
         use_betas = "betas" in optimizer.defaults

-        for i, pg in enumerate(param_groups):
-            name_and_suffix = self._add_suffix(name, param_groups, i)
-            lr = self._extract_lr(pg, name_and_suffix)
+        for pg, name in zip(param_groups, names):
+            lr = self._extract_lr(pg, name)
             lr_momentum_stat.update(lr)
             momentum = self._extract_momentum(
-                param_group=pg, name=name_and_suffix.replace(name, f"{name}-momentum"), use_betas=use_betas
+                param_group=pg, name=name.replace(name, f"{name}-momentum"), use_betas=use_betas
             )
             lr_momentum_stat.update(momentum)

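`_get_lr_momentum_stat` now receives the precomputed per-param-group keys and zips them with `optimizer.param_groups`, instead of rebuilding suffixes via `_add_suffix`. Note that `name.replace(name, f"{name}-momentum")` reduces to `f"{name}-momentum"` once the suffix is already baked into `name`. A rough illustration with hypothetical values:

    # Two param groups paired with their precomputed keys.
    param_groups = [{"lr": 0.1}, {"lr": 0.01}]
    names = ["lr-Adam/pg1", "lr-Adam/pg2"]

    stats = {name: pg["lr"] for pg, name in zip(param_groups, names)}
    assert stats == {"lr-Adam/pg1": 0.1, "lr-Adam/pg2": 0.01}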
@@ -213,14 +214,15 @@ def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         self.lrs[name].append(lr)
         return {name: lr}

-    def _remap_keys(self, names: List[str], token: str = "/pg1") -> None:
+    def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None:
217218 """This function is used the remap the keys if param groups for a given optimizer increased."""
-        for new_name in names:
-            old_name = new_name.replace(token, "")
-            if token in new_name and old_name in self.lrs:
-                self.lrs[new_name] = self.lrs.pop(old_name)
-            elif new_name not in self.lrs:
-                self.lrs[new_name] = []
+        for group_new_names in names:
+            for new_name in group_new_names:
+                old_name = new_name.replace(token, "")
+                if token in new_name and old_name in self.lrs:
+                    self.lrs[new_name] = self.lrs.pop(old_name)
+                elif new_name not in self.lrs:
+                    self.lrs[new_name] = []

     def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
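A small worked example of the nested remapping above, assuming a single optimizer whose history was keyed `lr-Adam` before a second param group appeared (values hypothetical):

    lrs = {"lr-Adam": [0.1, 0.09]}            # history recorded with one param group
    names = [["lr-Adam/pg1", "lr-Adam/pg2"]]  # keys after a second group was added
    token = "/pg1"

    for group_new_names in names:
        for new_name in group_new_names:
            old_name = new_name.replace(token, "")
            if token in new_name and old_name in lrs:
                lrs[new_name] = lrs.pop(old_name)  # move old history to the suffixed key
            elif new_name not in lrs:
                lrs[new_name] = []                 # start a fresh list for the new group

    assert lrs == {"lr-Adam/pg1": [0.1, 0.09], "lr-Adam/pg2": []}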
@@ -258,7 +260,7 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]:

     def _find_names_from_schedulers(
         self, lr_schedulers: List, add_lr_sch_names: bool = True
-    ) -> Tuple[List[str], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
+    ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]:
         # Create unique names in the case we have multiple of the same learning
         # rate scheduler + multiple parameter groups
         names = []
@@ -271,10 +273,11 @@ def _find_names_from_schedulers(
             else:
                 name = "lr-" + sch.optimizer.__class__.__name__

-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 sch.optimizer, name, seen_optimizers, seen_optimizer_types, scheduler, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
+
         return names, seen_optimizers, seen_optimizer_types

     def _find_names_from_optimizers(
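The `extend` → `append` switch above is what makes `names` nested: `_check_duplicates_and_update_name` returns a list of per-param-group keys, and appending keeps the optimizer boundaries intact. Roughly, with hypothetical values:

    updated_names = ["lr-Adam/pg1", "lr-Adam/pg2"]

    flat, grouped = [], []
    flat.extend(updated_names)     # old: ["lr-Adam/pg1", "lr-Adam/pg2"]
    grouped.append(updated_names)  # new: [["lr-Adam/pg1", "lr-Adam/pg2"]]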
@@ -283,7 +286,7 @@ def _find_names_from_optimizers(
         seen_optimizers: List[Optimizer],
         seen_optimizer_types: DefaultDict[Type[Optimizer], int],
         add_lr_sch_names: bool = True,
-    ) -> Tuple[List[str], List[Optimizer]]:
+    ) -> Tuple[List[List[str]], List[Optimizer]]:
         names = []
         optimizers_without_scheduler = []

@@ -294,11 +297,12 @@ def _find_names_from_optimizers(
                 continue

             name = "lr-" + optimizer.__class__.__name__
-            updated_name = self._check_duplicates_and_update_name(
+            updated_names = self._check_duplicates_and_update_name(
                 optimizer, name, seen_optimizers, seen_optimizer_types, None, add_lr_sch_names
             )
-            names.extend(updated_name)
+            names.append(updated_names)
             optimizers_without_scheduler.append(optimizer)
+
         return names, optimizers_without_scheduler

     def _check_duplicates_and_update_name(
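For context, the callback modified here is attached through the `Trainer` callback list; a typical setup, following the library's documented usage (not part of this diff), looks like:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor

    # Log the lr (and momentum) of every configured optimizer at each step.
    lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)
    trainer = Trainer(callbacks=[lr_monitor])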