From 31aa03f3a08a2f5956ee47cb9a41464edeca5506 Mon Sep 17 00:00:00 2001 From: donlapark Date: Sat, 16 Jul 2022 15:19:16 +0700 Subject: [PATCH 1/6] fixes typing in stochastic_weight_avg.py --- pyproject.toml | 1 - .../callbacks/stochastic_weight_avg.py | 56 ++++++++++--------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ddadd2b29bfc..f761c2808bbbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ module = [ "pytorch_lightning.callbacks.model_checkpoint", "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.callbacks.quantization", - "pytorch_lightning.callbacks.stochastic_weight_avg", "pytorch_lightning.core.datamodule", "pytorch_lightning.core.decorators", "pytorch_lightning.core.mixins.device_dtype_mixin", diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 83fb1cf169794..93055d6d92c93 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -16,7 +16,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy -from typing import Callable, List, Optional, Union +from typing import Any, Callable, cast, List, Optional, Union import torch from torch import FloatTensor, nn, Tensor @@ -26,9 +26,9 @@ from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.types import LRSchedulerConfig +from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig -_AVG_FN = Callable[[Tensor, Tensor, torch.LongTensor], FloatTensor] +_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor] class StochasticWeightAveraging(Callback): @@ -106,7 +106,7 @@ def __init__( if wrong_type or wrong_float or wrong_list: raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats") - if avg_fn is not None and not isinstance(avg_fn, Callable): + if avg_fn is not None and not isinstance(avg_fn, Callable): # type: ignore[arg-type] raise MisconfigurationException("The `avg_fn` should be callable.") if device is not None and not isinstance(device, (torch.device, str)): @@ -118,19 +118,20 @@ def __init__( self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._model_contains_batch_norm = None - self._average_model = None + self._max_epochs: int + self._model_contains_batch_norm: bool + self._average_model = pl.LightningModule() @property def swa_start(self) -> int: - return max(self._swa_epoch_start - 1, 0) # 0-based + return max(self._swa_epoch_start - 1, 0) # type: ignore[return-value] @property def swa_end(self) -> int: return self._max_epochs - 1 # 0-based @staticmethod - def pl_module_contains_batch_norm(pl_module: "pl.LightningModule"): + def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: @@ -138,7 +139,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O with pl_module._prevent_trainer_and_dataloaders_deepcopy(): self._average_model = deepcopy(pl_module) - def on_fit_start(self, trainer: "pl.Trainer", 
pl_module: "pl.LightningModule"): + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if len(trainer.optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") @@ -155,7 +156,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): # virtually increase max_epochs to perform batch norm update on latest epoch. trainer.fit_loop.max_epochs += 1 - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.current_epoch == self.swa_start: # move average model to request device. self._average_model = self._average_model.to(self._device or pl_module.device) @@ -167,13 +168,13 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo for lr, group in zip(self._swa_lrs, optimizer.param_groups): group["initial_lr"] = lr - self._swa_scheduler = SWALR( + self._swa_scheduler: _LRScheduler = cast(_LRScheduler, SWALR( optimizer, - swa_lr=self._swa_lrs, + swa_lr=self._swa_lrs, # type: ignore[arg-type] anneal_epochs=self._annealing_epochs, anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, - ) + )) # We assert that there is only one optimizer on fit start, so know opt_idx is always 0 default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 @@ -213,10 +214,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.accumulate_grad_batches = trainer.num_training_batches - def on_train_epoch_end(self, trainer: "pl.Trainer", *args): + def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None: trainer.fit_loop._skip_backward = False - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: # the trainer increases the current epoch before this hook is called if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: # BatchNorm epoch update. 
Reset state @@ -229,35 +230,39 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): self.transfer_weights(self._average_model, pl_module) @staticmethod - def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"): + def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None: for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) - def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): + def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" self.momenta = {} for module in pl_module.modules(): if not isinstance(module, nn.modules.batchnorm._BatchNorm): continue module.running_mean = torch.zeros_like( - module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype + module.running_mean, # type: ignore[arg-type] + device=pl_module.device, + dtype=module.running_mean.dtype # type: ignore[union-attr] ) module.running_var = torch.ones_like( - module.running_var, device=pl_module.device, dtype=module.running_var.dtype + module.running_var, # type: ignore[arg-type] + device=pl_module.device, + dtype=module.running_var.dtype # type: ignore[union-attr] ) self.momenta[module] = module.momentum - module.momentum = None - module.num_batches_tracked *= 0 + module.momentum = None # type: ignore[assignment] + module.num_batches_tracked *= 0 # type: ignore[assignment, operator] - def reset_momenta(self): + def reset_momenta(self) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" for bn_module in self.momenta: bn_module.momentum = self.momenta[bn_module] @staticmethod def update_parameters( - average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN - ): + average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN + ) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" for p_swa, p_model in zip(average_model.parameters(), model.parameters()): device = p_swa.device @@ -269,7 +274,6 @@ def update_parameters( @staticmethod def avg_fn( - averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: torch.LongTensor - ) -> FloatTensor: + averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) From 39319e19444e273a3c148ca6e59a36d7fc2076db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 16 Jul 2022 09:03:50 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../callbacks/stochastic_weight_avg.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 93055d6d92c93..6d27d3e6c8831 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ 
b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -168,13 +168,16 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo for lr, group in zip(self._swa_lrs, optimizer.param_groups): group["initial_lr"] = lr - self._swa_scheduler: _LRScheduler = cast(_LRScheduler, SWALR( - optimizer, - swa_lr=self._swa_lrs, # type: ignore[arg-type] - anneal_epochs=self._annealing_epochs, - anneal_strategy=self._annealing_strategy, - last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, - )) + self._swa_scheduler: _LRScheduler = cast( + _LRScheduler, + SWALR( + optimizer, + swa_lr=self._swa_lrs, # type: ignore[arg-type] + anneal_epochs=self._annealing_epochs, + anneal_strategy=self._annealing_strategy, + last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1, + ), + ) # We assert that there is only one optimizer on fit start, so know opt_idx is always 0 default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0) assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1 @@ -243,12 +246,12 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No module.running_mean = torch.zeros_like( module.running_mean, # type: ignore[arg-type] device=pl_module.device, - dtype=module.running_mean.dtype # type: ignore[union-attr] + dtype=module.running_mean.dtype, # type: ignore[union-attr] ) module.running_var = torch.ones_like( module.running_var, # type: ignore[arg-type] device=pl_module.device, - dtype=module.running_var.dtype # type: ignore[union-attr] + dtype=module.running_var.dtype, # type: ignore[union-attr] ) self.momenta[module] = module.momentum module.momentum = None # type: ignore[assignment] @@ -273,7 +276,6 @@ def update_parameters( n_averaged += 1 @staticmethod - def avg_fn( - averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: + def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) From 8d66cbcf31c93234bf7dddb07ce0046024cd2ff9 Mon Sep 17 00:00:00 2001 From: donlapark Date: Mon, 18 Jul 2022 16:13:30 +0700 Subject: [PATCH 3/6] fixes callable and comment --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 93055d6d92c93..07f8d9bcdcbfc 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -106,7 +106,7 @@ def __init__( if wrong_type or wrong_float or wrong_list: raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats") - if avg_fn is not None and not isinstance(avg_fn, Callable): # type: ignore[arg-type] + if avg_fn is not None and not callable(avg_fn): raise MisconfigurationException("The `avg_fn` should be callable.") if device is not None and not isinstance(device, (torch.device, str)): @@ -124,7 +124,7 @@ def __init__( @property def swa_start(self) -> int: - return max(self._swa_epoch_start - 1, 0) # type: ignore[return-value] + return max(self._swa_epoch_start - 1, 0) # type: ignore[return-value] # 0-based @property def swa_end(self) -> int: From 
26361adce107234a858de7bd3a163bef43832ffb Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Tue, 19 Jul 2022 10:38:51 +0700 Subject: [PATCH 4/6] Remove FloatTensor import --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 8490e3d887e0a..ce1d8a3659ebf 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -19,7 +19,7 @@ from typing import Any, Callable, cast, List, Optional, Union import torch -from torch import FloatTensor, nn, Tensor +from torch import nn, Tensor from torch.optim.swa_utils import SWALR import pytorch_lightning as pl From e0b7489a117d6f8f56f3d3e2de7d8daf1c51ce75 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Wed, 20 Jul 2022 14:53:26 +0700 Subject: [PATCH 5/6] Update src/pytorch_lightning/callbacks/stochastic_weight_avg.py --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index ce1d8a3659ebf..77bb0d7cdc934 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -120,7 +120,7 @@ def __init__( self._device = device self._max_epochs: int self._model_contains_batch_norm: bool - self._average_model = pl.LightningModule() + self._average_model: "pl.LightningModule" @property def swa_start(self) -> int: From 676b768e5f70823783868cc9ba7dfaf614b6566e Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Sat, 23 Jul 2022 20:00:38 +0700 Subject: [PATCH 6/6] Add `assert` that `_swa_epoch_start` is `int` --- src/pytorch_lightning/callbacks/stochastic_weight_avg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 77bb0d7cdc934..093c8e47d07dd 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -124,7 +124,8 @@ def __init__( @property def swa_start(self) -> int: - return max(self._swa_epoch_start - 1, 0) # type: ignore[return-value] # 0-based + assert isinstance(self._swa_epoch_start, int) + return max(self._swa_epoch_start - 1, 0) # 0-based @property def swa_end(self) -> int:
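
Note (not part of the patch series): a minimal, self-contained sketch of the running-average update that the new `_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]` alias describes. Widening the signature to plain `Tensor` is what allows the `FloatTensor`/`LongTensor` imports to be dropped, since torch operations are generally typed as returning `Tensor` and the dtype-specific subclasses are not tracked statically. The demo values below are made up for illustration only.

    from typing import Callable

    import torch
    from torch import Tensor

    # The alias introduced in PATCH 1/6: all three arguments and the result are plain Tensors.
    _AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]


    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
        # Cumulative moving average: avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1)
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)


    # Toy usage: averaging three "parameter" tensors one at a time.
    avg = torch.zeros(3)
    for n, p in enumerate([torch.ones(3), 2 * torch.ones(3), 3 * torch.ones(3)]):
        avg = avg_fn(avg, p, torch.tensor(n))
    print(avg)  # tensor([2., 2., 2.]) -- the mean of 1, 2 and 3

In the callback itself this update is applied parameter-by-parameter in `update_parameters`, with `n_averaged` kept as a tensor and incremented after each pass, which is why a plain `Tensor` annotation for the count is sufficient.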