1616^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1717"""
1818from copy import deepcopy
19- from typing import Callable , Optional , Union
19+ from typing import Callable , List , Optional , Union
2020
2121import torch
2222from torch import nn
@@ -35,7 +35,7 @@ class StochasticWeightAveraging(Callback):
3535 def __init__ (
3636 self ,
3737 swa_epoch_start : Union [int , float ] = 0.8 ,
38- swa_lrs : Optional [Union [float , list ]] = None ,
38+ swa_lrs : Optional [Union [float , List [ float ] ]] = None ,
3939 annealing_epochs : int = 10 ,
4040 annealing_strategy : str = "cos" ,
4141 avg_fn : Optional [_AVG_FN ] = None ,
@@ -62,19 +62,19 @@ def __init__(
6262
6363 .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.
6464
65- SWA can easily be activated directly from the Trainer as follow:
66-
67- .. code-block:: python
68-
69- Trainer(stochastic_weight_avg=True)
65+ See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
7066
7167 Arguments:
7268
7369 swa_epoch_start: If provided as int, the procedure will start from
7470 the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
7571 the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
7672
77- swa_lrs: the learning rate value for all param groups together or separately for each group.
73+ swa_lrs: The SWA learning rate to use:
74+
75+ - ``None``. Use the current learning rate of the optimizer at the time the SWA procedure starts.
76+ - ``float``. Use this value for all parameter groups of the optimizer.
77+ - ``List[float]``. A list of values, one for each parameter group of the optimizer.
7878
7979 annealing_epochs: number of epochs in the annealing phase (default: 10)
8080
@@ -105,7 +105,9 @@ def __init__(
105105 wrong_float = isinstance (swa_lrs , float ) and swa_lrs <= 0
106106 wrong_list = isinstance (swa_lrs , list ) and not all (lr > 0 and isinstance (lr , float ) for lr in swa_lrs )
107107 if swa_lrs is not None and (wrong_type or wrong_float or wrong_list ):
108- raise MisconfigurationException ("The `swa_lrs` should be a positive float or a list of positive float." )
108+ raise MisconfigurationException (
109+ "The `swa_lrs` should be `None`, a positive float, or a list of positive floats"
110+ )
109111
110112 if avg_fn is not None and not isinstance (avg_fn , Callable ):
111113 raise MisconfigurationException ("The `avg_fn` should be callable." )
0 commit comments