
Conversation

@akihironitta (Contributor) commented Feb 23, 2021

What does this PR do?

Fixes #4083
Fixes #5545
To-check #6134

Description of the changes

Makes sure to call zero_grad inside the closure function (TrainerLoop.training_step_and_backward()).

Note that this positions the zero_grad call before backward, as generally suggested throughout PyTorch's docs.

After LBFGS was reported not to work in #4083, we found that the number of times zero_grad is actually called differs between Lightning and pure PyTorch:

  • Lightning calls closure 20 times but zero_grad only 1 time, while
  • PyTorch calls closure 20 times and zero_grad 20 times, where 20 is the value of torch.optim.LBFGS(..., max_iter=20) (because the closure itself calls zero_grad; see the sample scripts below).

As mentioned in the PyTorch docs, the closure should call zero_grad, but Lightning currently calls it outside the closure rather than inside, so optimizers that need to re-evaluate the loss in optimizer.step(closure) don't work properly.

The closure should clear the gradients, compute the loss, and return it.
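For reference, the canonical closure described in the PyTorch docs can be sketched as below. This is a standalone toy example, not part of this PR's diff; the linear model, random batch, and MSE loss are placeholders:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)
x = torch.randn(8, 32)  # dummy batch

def closure():
    optimizer.zero_grad()   # clear stale gradients first
    prediction = model(x)   # re-evaluate the model
    loss = F.mse_loss(prediction, torch.ones_like(prediction))
    loss.backward()         # compute fresh gradients
    return loss             # step() uses the returned loss

optimizer.step(closure)     # LBFGS calls the closure up to max_iter times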

TODO

  • Call zero_grad in closure (and remove zero_grad calls outside the closure)
  • Update docs to reflect the new zero_grad position
  • Update docs to recommend manual optimization when using an optimizer similar to torch.optim.LBFGS, which needs re-evaluation of the loss via a closure (a rough sketch of this pattern follows below this list).
  • Ensure that scheduler.step is called the same number of times as optimizer.step in manual optimization. I'll disable scheduler.step in manual optimization in another PR. cc: @carmocca

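As a rough sketch of that recommendation (not part of this PR's changes; the module name, layer, and loss are placeholders modeled on BoringModel, assuming the current manual-optimization API), an LBFGS user could opt out of automatic optimization and drive the closure themselves:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class ManualLBFGSModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt out of Lightning's automatic loop
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            opt.zero_grad()  # clear gradients inside the closure
            prediction = self.layer(batch)
            loss = F.mse_loss(prediction, torch.ones_like(prediction))
            self.manual_backward(loss)  # let Lightning run the backward pass
            return loss

        opt.step(closure=closure)  # LBFGS re-evaluates the closure as needed

    def configure_optimizers(self):
        return torch.optim.LBFGS(self.parameters(), lr=0.01, max_iter=20)
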
Here are the minimal code examples using BoringModel.

Lightning code
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
pl.seed_everything(42)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
    def forward(self, x):
        return self.layer(x)
    def loss(self, batch, prediction):
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}
    def training_step_end(self, training_step_outputs):
        loss = training_step_outputs["loss"]
        print("loss:", loss.item())
        return training_step_outputs
    def configure_optimizers(self):
        # optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        optimizer = torch.optim.LBFGS(self.parameters(), lr=0.01, max_iter=20)  # re-evaluates the loss via a closure in step()
        return optimizer

def main():
    ds = RandomDataset(32, 100000)
    dl = DataLoader(ds, batch_size=1024)
    model = BoringModel()
    trainer = pl.Trainer(
        progress_bar_refresh_rate=0,
        fast_dev_run=1,
    )
    trainer.fit(model, dl)

if __name__ == "__main__":
    main()
Pure PyTorch code
import torch
import torch.nn as nn
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader, Dataset
seed_everything(42)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)
    def __getitem__(self, index):
        return self.data[index]
    def __len__(self):
        return self.len

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = torch.nn.Linear(32, 2)
    def forward(self, x):
        return self.layer(x)

def main():
    ds = RandomDataset(32, 100000)
    dl = DataLoader(ds, batch_size=1024)
    model = Model()
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)
    for epoch in range(3):
        for i, x in enumerate(dl):
            def closure():
                prediction = model(x)
                loss = torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
                optimizer.zero_grad()  # removing this line causes the same bug as in Lightning script
                loss.backward()
                print("loss:", loss.item())
                return loss
            loss_out = optimizer.step(closure=closure)

if __name__ == '__main__':
    main()

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

cc: @carmocca @tchaton

codecov bot commented Feb 23, 2021

Codecov Report

Merging #6147 (9836bb7) into master (40d5a9d) will decrease coverage by 3%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #6147    +/-   ##
=======================================
- Coverage      93%     90%    -3%     
=======================================
  Files         159     159            
  Lines       11380   11543   +163     
=======================================
- Hits        10624   10415   -209     
- Misses        756    1128   +372     

@akihironitta akihironitta changed the title Call optimizer.zero_grad() inside closure Call optimizer.zero_grad() before backward inside closure Feb 25, 2021
@akihironitta akihironitta changed the title Call optimizer.zero_grad() before backward inside closure Call optimizer.zero_grad() before backward inside closure in AutoOpt Feb 25, 2021
@akihironitta akihironitta marked this pull request as ready for review February 27, 2021 13:20
@carmocca carmocca merged commit 925f082 into Lightning-AI:master Mar 1, 2021
@akihironitta akihironitta deleted the bugfix/4083_lbfgs branch March 1, 2021 14:16
@tchaton tchaton added the bug Something isn't working label Mar 2, 2021
@tchaton tchaton added this to the 1.2.x milestone Mar 2, 2021
kaushikb11 pushed a commit to kaushikb11/pytorch-lightning that referenced this pull request Mar 2, 2021
kaushikb11 pushed a commit to kaushikb11/pytorch-lightning that referenced this pull request Mar 2, 2021
lexierule pushed a commit that referenced this pull request Mar 5, 2021