From ee8b1a298ced142718139f0681ce9a8a20b3b7a3 Mon Sep 17 00:00:00 2001
From: Michael Baumgartner
Date: Thu, 18 Mar 2021 22:06:35 +0100
Subject: [PATCH 1/5] Update stochastic_weight_avg.py

exchange whole scheduler dict in SWA
---
 pytorch_lightning/callbacks/stochastic_weight_avg.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py
index bece2ffe9f1b2..4f0c819432601 100644
--- a/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -187,14 +187,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
             )
+            _scheduler_config = _get_default_scheduler_config()
+            assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
+            _scheduler_config["scheduler"] = self._swa_scheduler
 
             if trainer.lr_schedulers:
                 lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                 rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-                trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+                trainer.lr_schedulers[0] = _scheduler_config
             else:
-                _scheduler_config = _get_default_scheduler_config()
-                _scheduler_config["scheduler"] = self._swa_scheduler
                 trainer.lr_schedulers.append(_scheduler_config)
 
             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

From 73ee6468420cacec4f7defef3606df7ddb0ddd41 Mon Sep 17 00:00:00 2001
From: Michael Baumgartner
Date: Wed, 24 Mar 2021 00:00:43 +0100
Subject: [PATCH 2/5] add test

---
 tests/callbacks/test_stochastic_weight_avg.py | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 12121b1f38530..2403cfaf685f0 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -27,6 +27,8 @@
 
 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
+    from torch.optim.swa_utils import SWALR
+
 
 
 class SwaTestModel(BoringModel):
@@ -46,6 +48,17 @@ def training_step(self, batch, batch_idx):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=2)
 
+
+class SWATestModelStep(SwaTestModel):
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return {
+            "optimizer": optimizer,
+            "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+            "interval": "step",
+        }
+
+
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
@@ -61,6 +74,10 @@ def transfer_weights(self, *args, **kwargs):
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
         assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+        if self.swa_start <= trainer.current_epoch:
+            assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+            assert trainer.lr_schedulers[0]["interval"] == "epoch"
+            assert trainer.lr_schedulers[0]["frequency"] == 1
 
     def on_train_epoch_end(self, trainer, *args):
         super().on_train_epoch_end(trainer, *args)
@@ -89,8 +106,11 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
-    model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, step=False):
+    if step:
+        model = SWATestModelStep(batchnorm=batchnorm)
+    else:
+        model = SwaTestModel(batchnorm=batchnorm)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -140,6 +160,12 @@ def test_swa_callback(tmpdir, batchnorm: bool):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("step", (True, False))
+def test_swa_callback_scheduler_step(tmpdir, step: bool):
+    train_with_swa(tmpdir, step=step)
+
+
 @RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):

From 2194981a5907926237884b13d1b4586e4a1cb76b Mon Sep 17 00:00:00 2001
From: Michael Baumgartner
Date: Wed, 24 Mar 2021 00:05:29 +0100
Subject: [PATCH 3/5] remove blank lines

---
 tests/callbacks/test_stochastic_weight_avg.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 2403cfaf685f0..78c0776b75a5e 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -29,7 +29,6 @@
 
     from pytorch_lightning.callbacks import StochasticWeightAveraging
     from torch.optim.swa_utils import SWALR
-
 
 
 class SwaTestModel(BoringModel):
@@ -48,7 +47,6 @@ def training_step(self, batch, batch_idx):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=2)
 
-
 class SWATestModelStep(SwaTestModel):
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -58,7 +56,6 @@ def configure_optimizers(self):
             "interval": "step",
         }
 
-
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0

From 6ddf3aaf1fb5ab25957d867203856590bc7f9f7a Mon Sep 17 00:00:00 2001
From: Michael Baumgartner
Date: Wed, 24 Mar 2021 00:26:51 +0100
Subject: [PATCH 4/5] remove class and add interval parameter

---
 tests/callbacks/test_stochastic_weight_avg.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py
index 78c0776b75a5e..b856e9991dde2 100644
--- a/tests/callbacks/test_stochastic_weight_avg.py
+++ b/tests/callbacks/test_stochastic_weight_avg.py
@@ -31,13 +31,14 @@
 
 
 class SwaTestModel(BoringModel):
-    def __init__(self, batchnorm: bool = True):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
+        self.interval = interval
 
     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
@@ -47,13 +48,12 @@ def training_step(self, batch, batch_idx):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=2)
 
-class SWATestModelStep(SwaTestModel):
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         return {
             "optimizer": optimizer,
             "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
-            "interval": "step",
+            "interval": self.interval,
         }
 
 class SwaTestCallback(StochasticWeightAveraging):
@@ -103,11 +103,8 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, step=False):
-    if step:
-        model = SWATestModelStep(batchnorm=batchnorm)
-    else:
-        model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -158,9 +155,9 @@ def test_swa_callback(tmpdir, batchnorm: bool):
 
 
 @RunIf(min_torch="1.6.0")
-@pytest.mark.parametrize("step", (True, False))
-def test_swa_callback_scheduler_step(tmpdir, step: bool):
-    train_with_swa(tmpdir, step=step)
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: bool):
+    train_with_swa(tmpdir, interval=interval)
 
 
 @RunIf(min_torch="1.6.0")

From 51f2aba35cc49ba88edbd132f8b22f93934e6c61 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 24 Mar 2021 00:38:23 +0100
Subject: [PATCH 5/5] CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6a1e85d4add8d..264a95162e71f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -191,6 +191,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
 
 
+- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
+
+
 - Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
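
A minimal, self-contained sketch of the behaviour these patches introduce. It is an illustration under assumptions rather than the actual Lightning source: the real callback builds the dict with `_get_default_scheduler_config()` and mutates `trainer.lr_schedulers`, while `model`, `optimizer`, `lr_schedulers`, `user_config` and `swa_config` below are stand-in names.

    import torch
    from torch.optim.swa_utils import SWALR

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # A user-provided scheduler that steps every batch.
    user_config = {
        "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
        "interval": "step",
        "frequency": 1,
    }
    lr_schedulers = [user_config]  # stand-in for trainer.lr_schedulers

    # What the patched callback does when SWA starts: build a fresh,
    # epoch-based config around the SWALR scheduler (simplified here; the
    # real default config carries more keys) and replace the whole dict,
    # so the old "step" interval cannot leak into the SWA phase.
    swa_config = {
        "scheduler": SWALR(optimizer, swa_lr=0.1),
        "interval": "epoch",
        "frequency": 1,
    }
    lr_schedulers[0] = swa_config

    assert lr_schedulers[0]["interval"] == "epoch"

Swapping the whole config, rather than only the "scheduler" entry, is what the interval == "epoch" and frequency == 1 assertions in test_swa_callback_scheduler_step exercise.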