From ae5dd981bec130016b046bebffa174197a959026 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Wed, 30 Mar 2022 20:49:31 +0530
Subject: [PATCH 1/5] Make swa_lrs required in StochasticWeightAveraging callback

---
 .../callbacks/stochastic_weight_avg.py | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
index ad9e8b8fc396b..425484567d521 100644
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -34,8 +34,8 @@ class StochasticWeightAveraging(Callback):
     def __init__(
         self,
+        swa_lrs: Union[float, List[float]],
         swa_epoch_start: Union[int, float] = 0.8,
-        swa_lrs: Optional[Union[float, List[float]]] = None,
         annealing_epochs: int = 10,
         annealing_strategy: str = "cos",
         avg_fn: Optional[_AVG_FN] = None,
@@ -66,16 +66,15 @@ def __init__(
 
         Arguments:
 
-            swa_epoch_start: If provided as int, the procedure will start from
-                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
-                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
-
             swa_lrs: The SWA learning rate to use:
 
-                - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
                 - ``float``. Use this value for all parameter groups of the optimizer.
                 - ``List[float]``. A list values for each parameter group of the optimizer.
 
+            swa_epoch_start: If provided as int, the procedure will start from
+                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
+                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
+
             annealing_epochs: number of epochs in the annealing phase (default: 10)
 
             annealing_strategy: Specifies the annealing strategy (default: "cos"):
@@ -104,10 +103,8 @@ def __init__(
         wrong_type = not isinstance(swa_lrs, (float, list))
         wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
         wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
-        if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
-            raise MisconfigurationException(
-                "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
-            )
+        if wrong_type or wrong_float or wrong_list:
+            raise MisconfigurationException("The `swa_lrs` should be a positive float, or a list of positive floats")
 
         if avg_fn is not None and not isinstance(avg_fn, Callable):
             raise MisconfigurationException("The `avg_fn` should be callable.")
@@ -164,8 +161,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
             optimizer = trainer.optimizers[0]
-            if self._swa_lrs is None:
-                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
             if isinstance(self._swa_lrs, float):
                 self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
 

From fba45076fcb477581392b899c1cccd11c409dde1 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Thu, 31 Mar 2022 23:23:43 +0530
Subject: [PATCH 2/5] update test

---
 tests/callbacks/test_stochastic_weight_avg.py | 4 ++--
 tests/profiler/test_profiler.py               | 5 ++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 4a79c462f2780..f62bcb19432b5 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -178,7 +178,7 @@ def test_swa_callback_scheduler_step(tmpdir, interval: str):
 
 def test_swa_warns(tmpdir, caplog):
     model = SwaTestModel(interval="step")
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging())
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging(swa_lrs=1e-2))
     with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
         trainer.fit(model)
     assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
@@ -211,7 +211,7 @@ def setup(self, trainer, pl_module, stage) -> None:
             self.setup_called = True
 
     model = BoringModel()
-    swa = TestSWA()
+    swa = TestSWA(swa_lrs=1e-2)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
     trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 2)))
     assert swa.setup_called
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 6cb71be2fc125..0dfa90ce42147 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -338,7 +338,10 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir):
     """Checks for pickle issue reported in #6522."""
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, fast_dev_run=True, profiler="advanced", callbacks=StochasticWeightAveraging()
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        profiler="advanced",
+        callbacks=StochasticWeightAveraging(swa_lrs=1e-2),
     )
     trainer.fit(model)
 

From 0aef7dd1983d8bc4d1055946fda9d5025aa14327 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Thu, 31 Mar 2022 23:28:12 +0530
Subject: [PATCH 3/5] chlog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3e1ada8db632c..980f8e253c0e2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))
 
 
--
+- Marked `swa_lrs` argument in `StochasticWeightAveraging` callback as required ([#12556](https://github.com/PyTorchLightning/pytorch-lightning/pull/12556))
 
 
 -

From 6ec287c327ba55ea605e05d476d3246333adcf23 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Fri, 1 Apr 2022 13:21:48 +0530
Subject: [PATCH 4/5] doc

---
 docs/source/advanced/training_tricks.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst
index 103eb34c3fe8e..14530af4d5828 100644
--- a/docs/source/advanced/training_tricks.rst
+++ b/docs/source/advanced/training_tricks.rst
@@ -67,7 +67,7 @@ read `this post

Date: Fri, 1 Apr 2022 13:56:31 +0530
Subject: [PATCH 5/5] Apply suggestions from code review

---
 docs/source/advanced/training_tricks.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst
index 14530af4d5828..3c46567da9326 100644
--- a/docs/source/advanced/training_tricks.rst
+++ b/docs/source/advanced/training_tricks.rst
@@ -67,7 +67,7 @@ read `this post
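
For illustration, a minimal usage sketch of the callback once this series is applied. It assumes only `pytorch_lightning` with these patches installed; the learning-rate values, `max_epochs`, and the commented-out `model` are placeholders for this sketch, not part of the patches:

# `swa_lrs` is now required: a positive float, or a list with one positive
# float per parameter group of the optimizer.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# A single SWA learning rate applied to every parameter group ...
swa = StochasticWeightAveraging(swa_lrs=1e-2)

# ... or one value per parameter group:
# swa = StochasticWeightAveraging(swa_lrs=[1e-2, 1e-3])

trainer = Trainer(max_epochs=10, callbacks=[swa])
# trainer.fit(model)  # `model` is any LightningModule (placeholder here)

Before this series, `swa_lrs=None` made the callback fall back to the optimizer's current learning rate when the SWA procedure started; with these patches the value must be passed explicitly.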