diff --git a/CHANGELOG.md b/CHANGELOG.md
index 400fe58b4eaee..bf85d06213ccd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))
 
--
+- Marked `swa_lrs` argument in `StochasticWeightAveraging` callback as required ([#12556](https://github.com/PyTorchLightning/pytorch-lightning/pull/12556))
 
 -
 
diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst
index 103eb34c3fe8e..3c46567da9326 100644
--- a/docs/source/advanced/training_tricks.rst
+++ b/docs/source/advanced/training_tricks.rst
@@ -67,7 +67,7 @@ read `this post
 
 .. testcode::
 
     # Enable Stochastic Weight Averaging using the callback
-    trainer = Trainer(callbacks=[StochasticWeightAveraging()])
+    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])
 
diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ ... @@ class StochasticWeightAveraging(Callback):
     def __init__(
         self,
+        swa_lrs: Union[float, List[float]],
         swa_epoch_start: Union[int, float] = 0.8,
-        swa_lrs: Optional[Union[float, List[float]]] = None,
         annealing_epochs: int = 10,
@@ ... @@ def __init__(
         wrong_type = not isinstance(swa_lrs, (float, list))
         wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
         wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
-        if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
-            raise MisconfigurationException(
-                "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
-            )
+        if wrong_type or wrong_float or wrong_list:
+            raise MisconfigurationException("The `swa_lrs` should be a positive float, or a list of positive floats")
 
         if avg_fn is not None and not isinstance(avg_fn, Callable):
             raise MisconfigurationException("The `avg_fn` should be callable.")
@@ -164,8 +161,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
             optimizer = trainer.optimizers[0]
-            if self._swa_lrs is None:
-                self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
 
             if isinstance(self._swa_lrs, float):
                 self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)
diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 4a79c462f2780..f62bcb19432b5 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -178,7 +178,7 @@ def test_swa_callback_scheduler_step(tmpdir, interval: str):
 
 def test_swa_warns(tmpdir, caplog):
     model = SwaTestModel(interval="step")
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging())
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging(swa_lrs=1e-2))
     with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
         trainer.fit(model)
     assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
@@ -211,7 +211,7 @@ def setup(self, trainer, pl_module, stage) -> None:
             self.setup_called = True
 
     model = BoringModel()
-    swa = TestSWA()
+    swa = TestSWA(swa_lrs=1e-2)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
     trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 2)))
     assert swa.setup_called
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 6cb71be2fc125..0dfa90ce42147 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -338,7 +338,10 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir):
     """Checks for pickle issue reported in #6522."""
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, fast_dev_run=True, profiler="advanced", callbacks=StochasticWeightAveraging()
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        profiler="advanced",
+        callbacks=StochasticWeightAveraging(swa_lrs=1e-2),
     )
     trainer.fit(model)
 
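Taken together, the diff removes the old `swa_lrs=None` fallback (which silently reused each optimizer param group's learning rate) and makes the argument required. Below is a minimal usage sketch of the callback after this change; it assumes a pytorch-lightning version where this PR has landed, and the learning-rate values shown are arbitrary examples taken from the updated tests:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# `swa_lrs` is now required: omitting it raises a TypeError at construction
# time instead of silently falling back to the optimizer's learning rates.
# A single positive float applies the same SWA learning rate to every
# optimizer parameter group:
swa = StochasticWeightAveraging(swa_lrs=1e-2)

# Alternatively, pass one positive float per parameter group:
# swa = StochasticWeightAveraging(swa_lrs=[1e-2, 1e-3])

# Any other value (e.g. a non-positive float) raises a
# MisconfigurationException, per the validation in __init__ above.
trainer = Trainer(callbacks=[swa], fast_dev_run=True)
```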