@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr

-            self._swa_scheduler: _LRScheduler = cast(
+            self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
                     optimizer,
@@ -244,19 +244,22 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            assert module.running_mean is not None
             module.running_mean = torch.zeros_like(
-                module.running_mean,  # type: ignore[arg-type]
+                module.running_mean,
                 device=pl_module.device,
-                dtype=module.running_mean.dtype,  # type: ignore[union-attr]
+                dtype=module.running_mean.dtype,
             )
+            assert module.running_var is not None
             module.running_var = torch.ones_like(
-                module.running_var,  # type: ignore[arg-type]
+                module.running_var,
                 device=pl_module.device,
-                dtype=module.running_var.dtype,  # type: ignore[union-attr]
+                dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
-            module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
+            module.momentum = float()
+            assert module.num_batches_tracked is not None
+            module.num_batches_tracked *= 0

     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""