2424from pytorch_lightning .utilities import _TORCH_GREATER_EQUAL_1_6
2525from pytorch_lightning .utilities .exceptions import MisconfigurationException
2626from tests .helpers import BoringModel , RandomDataset
27+ from tests .helpers .runif import RunIf
2728
2829if _TORCH_GREATER_EQUAL_1_6 :
2930 from pytorch_lightning .callbacks import StochasticWeightAveraging
31+ from torch .optim .swa_utils import SWALR
3032
3133 class SwaTestModel (BoringModel ):
3234
33- def __init__ (self , batchnorm : bool = True ):
35+ def __init__ (self , batchnorm : bool = True , interval : str = "epoch" ):
3436 super ().__init__ ()
3537 layers = [nn .Linear (32 , 32 )]
3638 if batchnorm :
3739 layers .append (nn .BatchNorm1d (32 ))
3840 layers += [nn .ReLU (), nn .Linear (32 , 2 )]
3941 self .layer = nn .Sequential (* layers )
42+ self .interval = interval
4043
4144 def training_step (self , batch , batch_idx ):
4245 output = self .forward (batch )
@@ -46,6 +49,14 @@ def training_step(self, batch, batch_idx):
4649 def train_dataloader (self ):
4750 return DataLoader (RandomDataset (32 , 64 ), batch_size = 2 )
4851
52+ def configure_optimizers (self ):
53+ optimizer = torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
54+ return {
55+ "optimizer" : optimizer ,
56+ "scheduler" : torch .optim .lr_scheduler .StepLR (optimizer , step_size = 1 ),
57+ "interval" : self .interval ,
58+ }
59+
4960 class SwaTestCallback (StochasticWeightAveraging ):
5061 update_parameters_calls : int = 0
5162 transfer_weights_calls : int = 0
@@ -61,6 +72,10 @@ def transfer_weights(self, *args, **kwargs):
6172 def on_train_epoch_start (self , trainer , * args ):
6273 super ().on_train_epoch_start (trainer , * args )
6374 assert trainer .train_loop ._skip_backward == (trainer .current_epoch > self .swa_end )
75+ if self .swa_start <= trainer .current_epoch :
76+ assert isinstance (trainer .lr_schedulers [0 ]["scheduler" ], SWALR )
77+ assert trainer .lr_schedulers [0 ]["interval" ] == "epoch"
78+ assert trainer .lr_schedulers [0 ]["frequency" ] == 1
6479
6580 def on_train_epoch_end (self , trainer , * args ):
6681 super ().on_train_epoch_end (trainer , * args )
@@ -89,8 +104,8 @@ def on_train_end(self, trainer, pl_module):
89104
90105
91106@mock .patch .dict (os .environ , {"PL_DEV_DEBUG" : "1" })
92- def train_with_swa (tmpdir , batchnorm = True , accelerator = None , gpus = None , num_processes = 1 ):
93- model = SwaTestModel (batchnorm = batchnorm )
107+ def train_with_swa (tmpdir , batchnorm = True , accelerator = None , gpus = None , num_processes = 1 , interval = "epoch" ):
108+ model = SwaTestModel (batchnorm = batchnorm , interval = interval )
94109 swa_start = 2
95110 max_epochs = 5
96111 swa_callback = SwaTestCallback (swa_epoch_start = swa_start , swa_lrs = 0.1 )
@@ -147,7 +162,13 @@ def test_swa_callback(tmpdir, batchnorm):
147162 train_with_swa (tmpdir , batchnorm = batchnorm )
148163
149164
150- @pytest .mark .skipif (not _TORCH_GREATER_EQUAL_1_6 , reason = "SWA available from PyTorch 1.6.0" )
165+ @RunIf (min_torch = "1.6.0" )
166+ @pytest .mark .parametrize ("interval" , ("epoch" , "step" ))
167+ def test_swa_callback_scheduler_step (tmpdir , interval : bool ):
168+ train_with_swa (tmpdir , interval = interval )
169+
170+
171+ @RunIf (min_torch = "1.6.0" )
151172def test_swa_raises ():
152173 with pytest .raises (MisconfigurationException , match = ">0 integer or a float between 0 and 1" ):
153174 StochasticWeightAveraging (swa_epoch_start = 0 , swa_lrs = 0.1 )
0 commit comments