@@ -32,8 +32,8 @@
 log = logging.getLogger(__name__)


-def multiplicative(epoch):
-    return 2
+def multiplicative(epoch: int) -> float:
+    return 2.0


 class BaseFinetuning(Callback):
@@ -79,7 +79,7 @@ class BaseFinetuning(Callback):
         ... )
     """

-    def __init__(self):
+    def __init__(self) -> None:
         self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
         self._restarting = False

@@ -94,7 +94,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
         else:
             # compatibility to load from old checkpoints before PR #11887
-            self._internal_optimizer_metadata = state_dict
+            self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]

     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
@@ -122,10 +122,11 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
             modules = modules.values()

         if isinstance(modules, Iterable):
-            _modules = []
+            _flatten_modules = []
             for m in modules:
-                _modules.extend(BaseFinetuning.flatten_modules(m))
+                _flatten_modules.extend(BaseFinetuning.flatten_modules(m))

+            _modules = iter(_flatten_modules)
         else:
             _modules = modules.modules()

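A hedged aside (not part of the commit): the added iter(_flatten_modules) line appears to give _modules the same iterator type in both branches, since torch.nn.Module.modules() already returns an iterator over submodules. A minimal sketch of that assumption:

    import torch
    from torch import nn

    layer = nn.Linear(4, 2)

    # modules() yields an iterator (a generator) over the module tree.
    print(type(layer.modules()))   # <class 'generator'>

    # Wrapping a plain list with iter() also produces an iterator,
    # so both branches can be treated uniformly as Iterator[Module].
    print(type(iter([layer])))     # <class 'list_iterator'>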