Description
PR #7357, which was recently introduced in v1.4, breaks cases where the model needs to be updated before running validation or testing, such as the increasingly popular SWA (stochastic weight averaging) method.
From https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)
```

The snippet above shows a test step but no validation step. The key point is that the BatchNorm statistics of `swa_model` have to be recomputed by running one forward pass over the data loader before testing or validating.
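For reference, `update_bn` resets every BatchNorm layer's running statistics, temporarily sets its `momentum` to `None`, and runs the model in training mode over the loader; with `momentum=None`, BatchNorm keeps a cumulative average instead of an exponential one. The plain-Python sketch below mimics that update for the running mean of a single feature. It illustrates the averaging rule only and is not PyTorch's implementation; `cumulative_running_mean` is a hypothetical helper.

```python
from statistics import fmean


def cumulative_running_mean(batches):
    """Running mean of one feature after a momentum=None BatchNorm pass.

    Mirrors what happens inside update_bn: the running stats are reset
    (running_mean = 0, num_batches_tracked = 0), then every batch updates
    running_mean with factor 1 / num_batches_tracked, i.e. a cumulative average.
    """
    running_mean = 0.0
    num_batches_tracked = 0
    for batch in batches:
        num_batches_tracked += 1
        factor = 1.0 / num_batches_tracked  # used in place of momentum when momentum is None
        batch_mean = fmean(batch)
        running_mean = (1.0 - factor) * running_mean + factor * batch_mean
    return running_mean


# Three equally sized batches of a single feature; batch means are 2, 6 and 10.
batches = [[1.0, 3.0], [5.0, 7.0], [9.0, 11.0]]
print(round(cumulative_running_mean(batches), 6))  # 6.0
```

Because the statistics are rebuilt from scratch in a single pass, they describe only the activations produced by the current SWA weights, which is why this pass has to happen before validation or testing.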
To use SWA in my Lightning code, I have something like the following:
```python
import torch

def training_epoch_end(self, training_step_outputs=None):
    if self.current_epoch >= self.swa_start:
        self.swa_model.update_parameters(self.model)
        # update BN stats; `loader` is the training DataLoader, defined elsewhere
        torch.optim.swa_utils.update_bn(loader, self.swa_model)
```

This hook used to run before the validation step, but in v1.4 it runs after it. In practice, this means the validation metrics lag by one epoch: the validation run in epoch N evaluates the SWA model as it stood at the end of epoch N-1.
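The one-epoch lag can be traced with a toy loop that records which refresh of the SWA weights each validation run sees. This is a plain-Python sketch of the hook order described above, not Lightning code; `run_epochs` and `swa_version` are hypothetical names, and it assumes SWA is active from the first epoch (`swa_start = 0`).

```python
def run_epochs(num_epochs, validate_before_epoch_end):
    """Toy model of the hook order; returns the SWA 'version' each validation saw."""
    swa_version = 0  # number of times training_epoch_end refreshed the SWA weights + BN stats
    seen = []
    for epoch in range(num_epochs):
        # ... the training steps for this epoch would run here ...
        if validate_before_epoch_end:
            # v1.4 order: validation first, then training_epoch_end
            seen.append(swa_version)
            swa_version += 1
        else:
            # pre-1.4 order: training_epoch_end first, then validation
            swa_version += 1
            seen.append(swa_version)
    return seen


print(run_epochs(3, validate_before_epoch_end=False))  # [1, 2, 3]
print(run_epochs(3, validate_before_epoch_end=True))   # [0, 1, 2]
```

With the v1.4 order, every validation run sees the SWA state from the previous epoch, which is exactly the one-epoch delay in the logged metrics.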
I'm aware there is ongoing work on incorporating SWA into Lightning (https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.StochasticWeightAveraging.html), but it is currently in beta and lacks documentation.
More generally, is there any way in v1.4 to run code after the end of a training epoch but before the validation step?