diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 093c8e47d07dd..8c141b28bc93d 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr

-            self._swa_scheduler: _LRScheduler = cast(
+            self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
                     optimizer,
@@ -244,19 +244,22 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            assert module.running_mean is not None
             module.running_mean = torch.zeros_like(
-                module.running_mean,  # type: ignore[arg-type]
+                module.running_mean,
                 device=pl_module.device,
-                dtype=module.running_mean.dtype,  # type: ignore[union-attr]
+                dtype=module.running_mean.dtype,
             )
+            assert module.running_var is not None
             module.running_var = torch.ones_like(
-                module.running_var,  # type: ignore[arg-type]
+                module.running_var,
                 device=pl_module.device,
-                dtype=module.running_var.dtype,  # type: ignore[union-attr]
+                dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
-            module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
+            module.momentum = float()
+            assert module.num_batches_tracked is not None
+            module.num_batches_tracked *= 0

     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
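
The pattern the second hunk settles on — adding `assert attr is not None` before touching `running_mean`, `running_var`, and `num_batches_tracked` — lets mypy narrow those `Optional[Tensor]` attributes instead of silencing the errors with `# type: ignore` pragmas. A minimal, standalone sketch of that narrowing, assuming a hypothetical helper named `_reset_batch_norm_stats` (illustrative only, not part of Lightning):

```python
import torch
from torch import nn


def _reset_batch_norm_stats(module: nn.Module, device: torch.device) -> None:
    """Illustrative helper: zero a BatchNorm layer's running statistics.

    running_mean, running_var, and num_batches_tracked are typed as
    Optional[Tensor], so the asserts below narrow them to Tensor before use.
    """
    if not isinstance(module, nn.modules.batchnorm._BatchNorm):
        return
    assert module.running_mean is not None  # narrows Optional[Tensor] -> Tensor
    assert module.running_var is not None
    assert module.num_batches_tracked is not None
    module.running_mean = torch.zeros_like(module.running_mean, device=device)
    module.running_var = torch.ones_like(module.running_var, device=device)
    module.num_batches_tracked *= 0
```

Note that `float()` in the last hunk evaluates to `0.0`; the original momentum is stashed in `self.momenta` beforehand and restored by `reset_momenta`, so the substitution only applies while the BatchNorm statistics are being recomputed.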