From b4dd8d1a060aa80c6939b9069d1a06c0d1305825 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 22 Sep 2022 01:16:53 +0200 Subject: [PATCH 1/4] Fix attribute error in SWA when running with Tuner --- .../callbacks/stochastic_weight_avg.py | 6 +++++- tests/tests_pytorch/tuner/test_lr_finder.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py index 732c8831b26d1..5f36096fa1102 100644 --- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -123,7 +123,7 @@ def __init__( self._avg_fn = avg_fn or self.avg_fn self._device = device self._model_contains_batch_norm: Optional[bool] = None - self._average_model: "pl.LightningModule" + self._average_model: Optional["pl.LightningModule"] = None self._initialized = False self._swa_scheduler: Optional[_LRScheduler] = None self._scheduler_state: Optional[Dict] = None @@ -179,6 +179,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo self._initialized = True # move average model to request device. + assert self._average_model is not None self._average_model = self._average_model.to(self._device or pl_module.device) optimizer = trainer.optimizers[0] @@ -232,12 +233,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo trainer.current_epoch > self._latest_update_epoch ): assert self.n_averaged is not None + assert self._average_model is not None self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn) self._latest_update_epoch = trainer.current_epoch # Note: No > here in case the callback is saved with the model and training continues if trainer.current_epoch == self.swa_end + 1: # Transfer weights from average model to pl_module + assert self._average_model is not None self.transfer_weights(self._average_model, pl_module) # Reset BatchNorm for update @@ -266,6 +269,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self.reset_momenta() elif trainer.current_epoch - 1 == self.swa_end: # Last SWA epoch. Transfer weights from average model to pl_module + assert self._average_model is not None self.transfer_weights(self._average_model, pl_module) @staticmethod diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index cf0776d755083..b81c1bea27570 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -19,6 +19,7 @@ import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -415,3 +416,15 @@ def __init__(self): lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1) # force insufficient data points assert lr_finder.suggestion() is None assert model.learning_rate == 0.123 # must remain unchanged because suggestion is not possible + + +def test_lr_finder_with_stochastic_weight_averaging(tmpdir): + """Regression test for issue https://github.com/Lightning-AI/lightning/issues/14755""" + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.learning_rate = 0.123 + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, callbacks=[StochasticWeightAveraging(swa_lrs=0.01)], auto_lr_find=True) + trainer.tune(model) From 2f6082e330439e6b0f60ce86c6af234f6063ad80 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Sep 2022 23:20:20 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/tuner/test_lr_finder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index b81c1bea27570..70cc0652024e2 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -419,7 +419,8 @@ def __init__(self): def test_lr_finder_with_stochastic_weight_averaging(tmpdir): - """Regression test for issue https://github.com/Lightning-AI/lightning/issues/14755""" + """Regression test for issue https://github.com/Lightning-AI/lightning/issues/14755.""" + class TestModel(BoringModel): def __init__(self): super().__init__() From b2e8d8cafa74d067d77f6f1c193b012bc4aba23e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 22 Sep 2022 01:21:19 +0200 Subject: [PATCH 3/4] changelog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0dbf13e4936b8..4ff709930724b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822)) +- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836)) + + ## [1.7.6] - 2022-09-13 From ca71db3df7959ba973071f23aebb6faefc876824 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 22 Sep 2022 02:05:11 +0200 Subject: [PATCH 4/4] add better test --- .../callbacks/test_stochastic_weight_avg.py | 16 ++++++++++++++++ tests/tests_pytorch/tuner/test_lr_finder.py | 14 -------------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index f18fce183f4cd..e3f8a979f4353 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -32,6 +32,22 @@ from tests_pytorch.helpers.runif import RunIf +def test_swa_callback_initial_state(): + swa = StochasticWeightAveraging( + swa_lrs=0.01, + swa_epoch_start=0.1, + annealing_epochs=1, + annealing_strategy="linear", + avg_fn=sum, + ) + assert swa._swa_lrs == 0.01 + assert swa._swa_epoch_start == 0.1 + assert swa._annealing_epochs == 1 + assert swa._annealing_strategy == "linear" + assert swa._avg_fn == sum + assert swa._average_model is None + + class SwaTestModel(BoringModel): def __init__( self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index 70cc0652024e2..cf0776d755083 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -19,7 +19,6 @@ import torch from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import StochasticWeightAveraging from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -416,16 +415,3 @@ def __init__(self): lr_finder = trainer.tuner.lr_find(model=model, update_attr=True, num_training=1) # force insufficient data points assert lr_finder.suggestion() is None assert model.learning_rate == 0.123 # must remain unchanged because suggestion is not possible - - -def test_lr_finder_with_stochastic_weight_averaging(tmpdir): - """Regression test for issue https://github.com/Lightning-AI/lightning/issues/14755.""" - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.learning_rate = 0.123 - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, callbacks=[StochasticWeightAveraging(swa_lrs=0.01)], auto_lr_find=True) - trainer.tune(model)