Commit 4928dc5

Improve SWA docs (#8717)
1 parent 299e289 commit 4928dc5

File tree

4 files changed: +21 -15 lines changed


docs/source/advanced/training_tricks.rst

Lines changed: 8 additions & 4 deletions
@@ -1,6 +1,7 @@
 .. testsetup:: *
 
-    from pytorch_lightning.trainer.trainer import Trainer
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.callbacks import StochasticWeightAveraging
 
 .. _training_tricks:
 
@@ -57,15 +58,18 @@ This can be used with both non-trained and trained models. The SWA procedure smo
 it harder to end up in a local minimum during optimization.
 
 For a more detailed explanation of SWA and how it works,
-read `this <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ post by the PyTorch team.
+read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
 
-.. seealso:: :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` (Callback)
+.. seealso:: The :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` callback
 
 .. testcode::
 
-    # Enable Stochastic Weight Averaging
+    # Enable Stochastic Weight Averaging - uses the class defaults
     trainer = Trainer(stochastic_weight_avg=True)
 
+    # alternatively, if you need to pass custom arguments
+    trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
+
 ----------
 
 Auto scaling of batch size
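For context, the two ways of enabling SWA that the updated docs describe amount to the following in user code; a minimal sketch assuming the 1.4-era API, with ``max_epochs`` and the callback arguments chosen purely for illustration:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # the Trainer flag constructs the callback with its class defaults
    trainer = Trainer(stochastic_weight_avg=True, max_epochs=20)

    # instantiating the callback yourself lets you pass custom arguments,
    # e.g. start averaging at 75% of training with a fixed SWA learning rate
    swa = StochasticWeightAveraging(swa_epoch_start=0.75, swa_lrs=0.05)
    trainer = Trainer(callbacks=[swa], max_epochs=20)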

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 11 additions & 9 deletions
@@ -16,7 +16,7 @@
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 """
 from copy import deepcopy
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 from torch import nn
@@ -35,7 +35,7 @@ class StochasticWeightAveraging(Callback):
     def __init__(
         self,
         swa_epoch_start: Union[int, float] = 0.8,
-        swa_lrs: Optional[Union[float, list]] = None,
+        swa_lrs: Optional[Union[float, List[float]]] = None,
         annealing_epochs: int = 10,
         annealing_strategy: str = "cos",
         avg_fn: Optional[_AVG_FN] = None,
@@ -62,19 +62,19 @@ def __init__(
 
         .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
 
-        SWA can easily be activated directly from the Trainer as follow:
-
-        .. code-block:: python
-
-            Trainer(stochastic_weight_avg=True)
+        See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
 
         Arguments:
 
             swa_epoch_start: If provided as int, the procedure will start from
                 the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
                 the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
 
-            swa_lrs: the learning rate value for all param groups together or separately for each group.
+            swa_lrs: The SWA learning rate to use:
+
+                - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
+                - ``float``. Use this value for all parameter groups of the optimizer.
+                - ``List[float]``. A list values for each parameter group of the optimizer.
 
            annealing_epochs: number of epochs in the annealing phase (default: 10)
@@ -105,7 +105,9 @@ def __init__(
         wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
         wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
         if swa_lrs is not None and (wrong_type or wrong_float or wrong_list):
-            raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
+            raise MisconfigurationException(
+                "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
+            )
 
         if avg_fn is not None and not isinstance(avg_fn, Callable):
             raise MisconfigurationException("The `avg_fn` should be callable.")

pytorch_lightning/utilities/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import os
 from functools import wraps
 from platform import python_version
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def test_swa_raises():
         StochasticWeightAveraging(swa_epoch_start=1.5, swa_lrs=0.1)
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
         StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
-    with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
+    with pytest.raises(MisconfigurationException, match="positive float, or a list of positive floats"):
         StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])
 
 
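Note on the updated test: ``swa_lrs=[0.2, 1]`` trips this message because the ``wrong_list`` check requires every entry to be a positive ``float``, and ``1`` is an ``int`` (``isinstance(1, float)`` is ``False``); a list such as ``[0.2, 1.0]`` would pass that particular check.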