Merged

src/pytorch_lightning/CHANGELOG.md (3 additions, 0 deletions)

@@ -216,6 +216,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))
 
 
+- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))
+
+
 
 ## [1.7.6] - 2022-09-13
 
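For context, the scenario behind this changelog entry is running the tuner while the SWA callback is registered. A minimal sketch, assuming the PL 1.7-era tuner API (`auto_lr_find` / `trainer.tune`) and using `BoringModel` as a stand-in module:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(
    max_epochs=2,
    auto_lr_find=True,  # trainer.tune() will run the learning-rate finder
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
)
# Before this fix, tuning could hit an AttributeError: `_average_model` was
# only annotated in SWA's __init__ (see the diff below), so touching the
# attribute before it was assigned failed. It is now initialized to None.
trainer.tune(model)
trainer.fit(model)
```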

src/pytorch_lightning/callbacks/stochastic_weight_avg.py (5 additions, 1 deletion)

@@ -123,7 +123,7 @@ def __init__(
         self._avg_fn = avg_fn or self.avg_fn
         self._device = device
         self._model_contains_batch_norm: Optional[bool] = None
-        self._average_model: "pl.LightningModule"
+        self._average_model: Optional["pl.LightningModule"] = None
         self._initialized = False
         self._swa_scheduler: Optional[_LRScheduler] = None
         self._scheduler_state: Optional[Dict] = None
@@ -179,6 +179,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
             self._initialized = True
 
             # move average model to request device.
+            assert self._average_model is not None
             self._average_model = self._average_model.to(self._device or pl_module.device)
 
             optimizer = trainer.optimizers[0]
@@ -232,12 +233,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
             trainer.current_epoch > self._latest_update_epoch
         ):
             assert self.n_averaged is not None
+            assert self._average_model is not None
             self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
             self._latest_update_epoch = trainer.current_epoch
 
         # Note: No > here in case the callback is saved with the model and training continues
         if trainer.current_epoch == self.swa_end + 1:
             # Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
             # Reset BatchNorm for update
@@ -266,6 +269,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             self.reset_momenta()
         elif trainer.current_epoch - 1 == self.swa_end:
             # Last SWA epoch. Transfer weights from average model to pl_module
+            assert self._average_model is not None
             self.transfer_weights(self._average_model, pl_module)
 
     @staticmethod
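A note on the pattern above: the attribute is initialized to `None`, and each use site asserts `is not None`. This keeps the instance well-formed before `setup()` assigns the real value, while still letting mypy narrow the `Optional` type. A minimal, self-contained sketch with hypothetical names:

```python
from typing import Optional


class LateInit:
    """Hypothetical example of an attribute that is only assigned in setup()."""

    def __init__(self) -> None:
        # A bare annotation (`self.model: str`) leaves the attribute unset,
        # so any access before setup() raises AttributeError; `= None` keeps
        # the instance well-formed from construction onward.
        self.model: Optional[str] = None

    def setup(self) -> None:
        self.model = "averaged-model"

    def run(self) -> None:
        # Narrows Optional[str] to str for the type checker, and fails with
        # a clear AssertionError if setup() was skipped.
        assert self.model is not None
        print(self.model.upper())
```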

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py (16 additions, 0 deletions)

@@ -32,6 +32,22 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+def test_swa_callback_initial_state():
+    swa = StochasticWeightAveraging(
+        swa_lrs=0.01,
+        swa_epoch_start=0.1,
+        annealing_epochs=1,
+        annealing_strategy="linear",
+        avg_fn=sum,
+    )
+    assert swa._swa_lrs == 0.01
+    assert swa._swa_epoch_start == 0.1
+    assert swa._annealing_epochs == 1
+    assert swa._annealing_strategy == "linear"
+    assert swa._avg_fn == sum
+    assert swa._average_model is None
+
+
 class SwaTestModel(BoringModel):
     def __init__(
         self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None