diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06a91bf973790..b3edb21a7104e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
 
+- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
+
+
 ### Deprecated
 
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index f3787c1cb2f7f..fb61ad81aee28 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -22,7 +22,7 @@
 from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
 from pytorch_lightning.callbacks.pruning import ModelPruning
 from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
-from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
+from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
 
 __all__ = [
     'BackboneFinetuning',
diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
similarity index 97%
rename from pytorch_lightning/callbacks/swa.py
rename to pytorch_lightning/callbacks/stochastic_weight_avg.py
index c8cf367cb4d5e..bece2ffe9f1b2 100644
--- a/pytorch_lightning/callbacks/swa.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -102,12 +102,10 @@ def __init__(
         if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
             raise MisconfigurationException(err_msg)
 
-        if (
-            swa_lrs is not None and (
-                not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
-                or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
-            )
-        ):
+        wrong_type = not isinstance(swa_lrs, (float, list))
+        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
+        wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
+        if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):
             raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
 
         if avg_fn is not None and not isinstance(avg_fn, Callable):
diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 40ac8f3e69870..8a5289e608c94 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -76,7 +76,7 @@ def _configure_swa_callbacks(self):
         if not self.trainer._stochastic_weight_avg:
             return
 
-        from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
+        from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
         existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
         if not existing_swa:
             self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_stochastic_weight_avg.py
similarity index 100%
rename from tests/callbacks/test_swa.py
rename to tests/callbacks/test_stochastic_weight_avg.py