
Commit a3acbb7

awaelchli and pre-commit-ci[bot] authored and committed
Fix attribute error in SWA when running with Tuner (#14836)
* Fix attribute error in SWA when running with Tuner

* changelog

* add better test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent be8ab04 commit a3acbb7
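The underlying bug is plain Python rather than anything Lightning-specific: the old `__init__` declared `self._average_model: "pl.LightningModule"` as a bare annotation, which informs type checkers but binds nothing at runtime, so any access before the first assignment (as happens on the tuner code path) raises `AttributeError`. A minimal standalone sketch of that behavior, using hypothetical `Before`/`After` classes rather than the Lightning callback:

```python
from typing import Optional


class Before:
    def __init__(self) -> None:
        # Bare annotation: visible to type checkers, binds nothing at runtime.
        self._average_model: object


class After:
    def __init__(self) -> None:
        # Bound to None up front, so the attribute always exists.
        self._average_model: Optional[object] = None


try:
    Before()._average_model
except AttributeError as err:
    print(f"before the fix: {err}")  # 'Before' object has no attribute '_average_model'

print(f"after the fix: {After()._average_model!r}")  # after the fix: None
```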

File tree

3 files changed: +24 -1 lines changed

src/pytorch_lightning/CHANGELOG.md
src/pytorch_lightning/callbacks/stochastic_weight_avg.py
tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -219,6 +219,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))
 
 
+- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
+
+
 
 ## [1.7.6] - 2022-09-13
 

src/pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 5 additions & 1 deletion
@@ -123,7 +123,7 @@ def __init__(
         self._avg_fn = avg_fn or self.avg_fn
         self._device = device
         self._model_contains_batch_norm: Optional[bool] = None
-        self._average_model: "pl.LightningModule"
+        self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
         self._swa_scheduler: Optional[_LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
@@ -179,6 +179,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self._initialized = True
 
             # move average model to request device.
+            assert self._average_model is not None
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
             optimizer = trainer.optimizers[0]
@@ -232,12 +233,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             trainer.current_epoch > self._latest_update_epoch
         ):
             assert self.n_averaged is not None
+            assert self._average_model is not None
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
             # Reset BatchNorm for update
@@ -266,6 +269,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self.reset_momenta()
         elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
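With `_average_model` now typed as `Optional`, each call site gains an `assert self._average_model is not None` guard. The assert does double duty: it narrows the `Optional` type for checkers such as mypy, and it fails fast with a clear traceback if a hook ever runs before the attribute is populated. A small sketch of the same pattern in isolation (the `Averager` class below is hypothetical, not the Lightning API):

```python
from typing import Dict, Optional


class Averager:
    def __init__(self) -> None:
        # Populated later (e.g. once training starts), but always bound.
        self._average_model: Optional[Dict[str, float]] = None

    def setup(self, weights: Dict[str, float]) -> None:
        self._average_model = dict(weights)

    def transfer_weights(self) -> Dict[str, float]:
        # Narrows Optional[Dict] -> Dict for type checkers and fails fast
        # if called before setup(), instead of a confusing downstream error.
        assert self._average_model is not None
        return self._average_model


averager = Averager()
averager.setup({"weight": 1.0})
print(averager.transfer_weights())  # {'weight': 1.0}
```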

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 16 additions & 0 deletions
@@ -32,6 +32,22 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+def test_swa_callback_initial_state():
+    swa = StochasticWeightAveraging(
+        swa_lrs=0.01,
+        swa_epoch_start=0.1,
+        annealing_epochs=1,
+        annealing_strategy="linear",
+        avg_fn=sum,
+    )
+    assert swa._swa_lrs == 0.01
+    assert swa._swa_epoch_start == 0.1
+    assert swa._annealing_epochs == 1
+    assert swa._annealing_strategy == "linear"
+    assert swa._avg_fn == sum
+    assert swa._average_model is None
+
+
 class SwaTestModel(BoringModel):
     def __init__(
         self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
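The new unit test checks the constructor in isolation, including that `_average_model` starts out as `None`. For the end-to-end scenario the commit fixes, here is a sketch under the PyTorch Lightning 1.7.x API, where `auto_lr_find=True` plus `trainer.tune()` drives the tuner (exact tuner entry points vary across versions):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    auto_lr_find=True,  # enables the tuner code path that used to crash
)
trainer.tune(model)  # before this fix: AttributeError on _average_model
trainer.fit(model)
```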
