Merged
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `strategy` argument being case insensitive ([#12528](https://github.com/PyTorchLightning/pytorch-lightning/pull/12528))


-
- Marked `swa_lrs` argument in `StochasticWeightAveraging` callback as required ([#12556](https://github.com/PyTorchLightning/pytorch-lightning/pull/12556))


-
2 changes: 1 addition & 1 deletion docs/source/advanced/training_tricks.rst
@@ -67,7 +67,7 @@ read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-we
.. testcode::

# Enable Stochastic Weight Averaging using the callback
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])

----------

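For readers skimming the diff, here is a rough end-to-end sketch of how the callback is wired up after this change. The model and data below are placeholders invented for illustration; only the `StochasticWeightAveraging(swa_lrs=...)` call reflects the PR itself.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging


class TinyModel(LightningModule):
    """Placeholder model used only to demonstrate the callback wiring."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)
    # `swa_lrs` is now required: a single float applies to every parameter group.
    trainer = Trainer(
        max_epochs=5,
        callbacks=[StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8)],
    )
    trainer.fit(TinyModel(), train_dataloaders=data)
```

Passing a list instead (e.g. `swa_lrs=[1e-2, 5e-3]`) assigns one SWA learning rate per optimizer parameter group, as described in the docstring changes below.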
19 changes: 7 additions & 12 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -34,8 +34,8 @@
class StochasticWeightAveraging(Callback):
def __init__(
self,
swa_lrs: Union[float, List[float]],
swa_epoch_start: Union[int, float] = 0.8,
swa_lrs: Optional[Union[float, List[float]]] = None,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[_AVG_FN] = None,
@@ -66,16 +66,15 @@ def __init__(

Arguments:

swa_epoch_start: If provided as int, the procedure will start from
the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch

swa_lrs: The SWA learning rate to use:

- ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
- ``float``. Use this value for all parameter groups of the optimizer.
- ``List[float]``. A list of values for each parameter group of the optimizer.

swa_epoch_start: If provided as int, the procedure will start from
the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch

annealing_epochs: number of epochs in the annealing phase (default: 10)

annealing_strategy: Specifies the annealing strategy (default: "cos"):
@@ -104,10 +103,8 @@ def __init__(
wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
raise MisconfigurationException(
"The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
)
if wrong_type or wrong_float or wrong_list:
raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats")

if avg_fn is not None and not isinstance(avg_fn, Callable):
raise MisconfigurationException("The `avg_fn` should be callable.")
@@ -164,8 +161,6 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._average_model = self._average_model.to(self._device or pl_module.device)

optimizer = trainer.optimizers[0]
if self._swa_lrs is None:
self._swa_lrs = [param_group["lr"] for param_group in optimizer.param_groups]
if isinstance(self._swa_lrs, float):
self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

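A small sketch of the construction-time validation introduced above: with `swa_lrs` now required, anything other than a positive float or a list of positive floats (including the previously allowed `None`) should raise `MisconfigurationException`. The import path below is the one Lightning used at this point; treat the exact message wording as paraphrased from the diff rather than authoritative.

```python
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Accepted: one float applied to every optimizer parameter group.
StochasticWeightAveraging(swa_lrs=1e-2)

# Accepted: one positive float per parameter group.
StochasticWeightAveraging(swa_lrs=[1e-2, 5e-3])

# Rejected: `None` (the old default), non-positive values, or lists containing
# non-positive entries now fail fast in `__init__` instead of being resolved
# later in `on_train_epoch_start`.
for bad in (None, -1.0, [1e-2, 0.0]):
    try:
        StochasticWeightAveraging(swa_lrs=bad)
    except MisconfigurationException as err:
        print(f"rejected {bad!r}: {err}")
```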
4 changes: 2 additions & 2 deletions tests/callbacks/test_stochastic_weight_avg.py
@@ -178,7 +178,7 @@ def test_swa_callback_scheduler_step(tmpdir, interval: str):

def test_swa_warns(tmpdir, caplog):
model = SwaTestModel(interval="step")
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=StochasticWeightAveraging(swa_lrs=1e-2))
with caplog.at_level(level=logging.INFO), pytest.warns(UserWarning, match="SWA is currently only supported"):
trainer.fit(model)
assert "Swapping scheduler `StepLR` for `SWALR`" in caplog.text
@@ -211,7 +211,7 @@ def setup(self, trainer, pl_module, stage) -> None:
self.setup_called = True

model = BoringModel()
swa = TestSWA()
swa = TestSWA(swa_lrs=1e-2)
trainer = Trainer(default_root_dir=tmpdir, callbacks=swa, fast_dev_run=True)
trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 2)))
assert swa.setup_called
5 changes: 4 additions & 1 deletion tests/profiler/test_profiler.py
@@ -338,7 +338,10 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir):
"""Checks for pickle issue reported in #6522."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, fast_dev_run=True, profiler="advanced", callbacks=StochasticWeightAveraging()
default_root_dir=tmpdir,
fast_dev_run=True,
profiler="advanced",
callbacks=StochasticWeightAveraging(swa_lrs=1e-2),
)
trainer.fit(model)
