From 103917ca6a75921c34df1de6d5d5a480544458e4 Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Tue, 26 Jul 2022 22:00:28 +0700
Subject: [PATCH 1/3] fixes typing in stochastic_weight_avg.py (follow-up of #13685)

This is a follow-up of #13685, which originated from #13445. It applies a
couple of suggestions from @carmocca that improve the typing of
`stochastic_weight_avg.py`.
---
 .../callbacks/stochastic_weight_avg.py       | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 093c8e47d07dd..1ca3bbe8358eb 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr
 
-            self._swa_scheduler: _LRScheduler = cast(
+            self._swa_scheduler: = cast(
                 _LRScheduler,
                 SWALR(
                     optimizer,
@@ -244,19 +244,23 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
         for module in pl_module.modules():
             if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                 continue
+            assert module.running_mean is not None
             module.running_mean = torch.zeros_like(
-                module.running_mean,  # type: ignore[arg-type]
+                module.running_mean,
                 device=pl_module.device,
-                dtype=module.running_mean.dtype,  # type: ignore[union-attr]
+                dtype=module.running_mean.dtype,
             )
+            assert module.running_var is not None
             module.running_var = torch.ones_like(
-                module.running_var,  # type: ignore[arg-type]
+                module.running_var,
                 device=pl_module.device,
-                dtype=module.running_var.dtype,  # type: ignore[union-attr]
+                dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            module.momentum = None  # type: ignore[assignment]
-            module.num_batches_tracked *= 0  # type: ignore[assignment, operator]
+            assert module.momentum is not None
+            module.momentum = None
+            assert module.num_batches_tracked is not None
+            module.num_batches_tracked *= 0
 
     def reset_momenta(self) -> None:
         """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""

From 9d62d8c554f5782a6b04343c26d8ed22d7f94210 Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Tue, 26 Jul 2022 22:06:46 +0700
Subject: [PATCH 2/3] Minor

---
 src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 1ca3bbe8358eb..6a6f048ccc421 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -169,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr
 
-            self._swa_scheduler: = cast(
+            self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
                     optimizer,
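For context on the pattern the two patches above rely on: on `_BatchNorm`, `running_mean` and `running_var` are declared `Optional[Tensor]` (they are `None` when `track_running_stats=False`), so mypy rejects using them directly; an `assert ... is not None` narrows `Optional[Tensor]` to `Tensor` and removes the need for the `# type: ignore` comments. A minimal standalone sketch of the same pattern — the helper `reset_running_stats_example` is illustrative only, not part of the patch or Lightning's API:

```python
import torch
from torch import nn


def reset_running_stats_example(bn: nn.modules.batchnorm._BatchNorm, device: torch.device) -> None:
    # ``running_mean``/``running_var`` are typed ``Optional[Tensor]``, so mypy
    # rejects passing them to ``zeros_like``/``ones_like`` or reading ``.dtype``
    # without narrowing. The asserts replace the old ``# type: ignore[arg-type]``
    # and ``# type: ignore[union-attr]`` suppressions.
    assert bn.running_mean is not None
    bn.running_mean = torch.zeros_like(bn.running_mean, device=device, dtype=bn.running_mean.dtype)
    assert bn.running_var is not None
    bn.running_var = torch.ones_like(bn.running_var, device=device, dtype=bn.running_var.dtype)


reset_running_stats_example(nn.BatchNorm1d(4), torch.device("cpu"))
```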
From 7438e9f4920c606d96d0673edddab1c68870feea Mon Sep 17 00:00:00 2001
From: donlapark <10988155+donlapark@users.noreply.github.com>
Date: Tue, 26 Jul 2022 22:36:23 +0700
Subject: [PATCH 3/3] Minor

---
 src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 6a6f048ccc421..8c141b28bc93d 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -257,8 +257,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No
                 dtype=module.running_var.dtype,
             )
             self.momenta[module] = module.momentum
-            assert module.momentum is not None
-            module.momentum = None
+            module.momentum = float()
             assert module.num_batches_tracked is not None
             module.num_batches_tracked *= 0
 
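A note on the semantics behind the last hunk: `_BatchNorm.forward` treats `momentum` specially at runtime, so `float()` (which evaluates to `0.0`) and `None` are interchangeable only from the type checker's point of view. With `momentum=None`, PyTorch uses a cumulative moving average whose factor is `1 / num_batches_tracked`; with a float, that value is used directly as the exponential-average factor, and `0.0` leaves the running statistics untouched. A small self-contained sketch of the difference in plain PyTorch — illustrative only, not part of the patch:

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# momentum=None: cumulative moving average, factor = 1 / num_batches_tracked,
# so after one batch the running mean equals that batch's mean.
bn.momentum = None
bn.train()
bn(torch.randn(8, 4))
print(bn.running_mean)  # moved to the batch mean

# momentum=0.0 (what ``float()`` evaluates to): the update factor is 0,
# so the running statistics never change.
bn.reset_running_stats()
bn.momentum = 0.0
bn(torch.randn(8, 4))
print(bn.running_mean)  # still all zeros
```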