diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 0dbf13e4936b8..4ff709930724b 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))
 
+- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
+
+
 
 ## [1.7.6] - 2022-09-13
 
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 732c8831b26d1..5f36096fa1102 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -123,7 +123,7 @@ def __init__(
         self._avg_fn = avg_fn or self.avg_fn
         self._device = device
         self._model_contains_batch_norm: Optional[bool] = None
-        self._average_model: "pl.LightningModule"
+        self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
         self._swa_scheduler: Optional[_LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
@@ -179,6 +179,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             self._initialized = True
 
             # move average model to request device.
+            assert self._average_model is not None
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
             optimizer = trainer.optimizers[0]
@@ -232,12 +233,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             trainer.current_epoch > self._latest_update_epoch
         ):
             assert self.n_averaged is not None
+            assert self._average_model is not None
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
             # Reset BatchNorm for update
@@ -266,6 +269,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
             self.reset_momenta()
         elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
index f18fce183f4cd..e3f8a979f4353 100644
--- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
+++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -32,6 +32,22 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+def test_swa_callback_initial_state():
+    swa = StochasticWeightAveraging(
+        swa_lrs=0.01,
+        swa_epoch_start=0.1,
+        annealing_epochs=1,
+        annealing_strategy="linear",
+        avg_fn=sum,
+    )
+    assert swa._swa_lrs == 0.01
+    assert swa._swa_epoch_start == 0.1
+    assert swa._annealing_epochs == 1
+    assert swa._annealing_strategy == "linear"
+    assert swa._avg_fn == sum
+    assert swa._average_model is None
+
+
 class SwaTestModel(BoringModel):
     def __init__(
         self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
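Not part of the diff above: a minimal sketch of the scenario the changelog entry describes, i.e. running the Trainer's tuner while the `StochasticWeightAveraging` callback is attached. The `TunableModel` helper, its `learning_rate` attribute, and the choice of `auto_lr_find` are illustrative assumptions based on the 1.7-era API, not taken from this PR.

```python
# Illustrative only (not from this PR): exercise the tuner with SWA attached.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel


class TunableModel(BoringModel):
    """Hypothetical model exposing `learning_rate` so the LR finder has something to tune."""

    def __init__(self):
        super().__init__()
        self.learning_rate = 0.1

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


model = TunableModel()
trainer = pl.Trainer(
    max_epochs=2,
    auto_lr_find=True,  # trainer.tune() runs the LR finder
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
)
# Per the changelog entry, combining the tuner with SWA could previously raise an
# AttributeError, since `_average_model` was only a type annotation and was never assigned.
trainer.tune(model)
trainer.fit(model)
```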