-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Call optimizer.zero_grad() before backward inside closure in AutoOpt
#6147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Codecov Report
@@ Coverage Diff @@
## master #6147 +/- ##
=======================================
- Coverage 93% 90% -3%
=======================================
Files 159 159
Lines 11380 11543 +163
=======================================
- Hits 10624 10415 -209
- Misses 756 1128 +372 |
Closed
akihironitta
commented
Feb 24, 2021
…ightning into bugfix/4083_lbfgs
justusschock
approved these changes
Feb 25, 2021
optimizer.zero_grad() inside closureoptimizer.zero_grad() before backward inside closure
optimizer.zero_grad() before backward inside closureoptimizer.zero_grad() before backward inside closure in AutoOpt
carmocca
approved these changes
Feb 27, 2021
awaelchli
reviewed
Feb 27, 2021
kaushikb11
pushed a commit
to kaushikb11/pytorch-lightning
that referenced
this pull request
Mar 2, 2021
Lightning-AI#6147) Co-authored-by: Carlos Mocholi <[email protected]>
kaushikb11
pushed a commit
to kaushikb11/pytorch-lightning
that referenced
this pull request
Mar 2, 2021
Lightning-AI#6147) Co-authored-by: Carlos Mocholi <[email protected]>
lexierule
pushed a commit
that referenced
this pull request
Mar 5, 2021
#6147) Co-authored-by: Carlos Mocholi <[email protected]>
11 tasks
11 tasks
This was referenced Mar 12, 2021
Closed
Closed
[tune](deps): Bump pytorch-lightning from 1.0.3 to 1.2.3 in /python/requirements
sumanthratna/ray#12
Closed
Closed
Closed
Closed
Closed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes #4083
Fixes #5545
To-check #6134
Description of the changes
Makes sure to call
zero_gradinside the closure function (TrainerLoop.training_step_and_backward()).Note that this positions the
zero_gradcall beforebackward, as generally suggested throughout PyTorch's docs.Reported that LBFGS doesn't work In #4083, we then found that the number of times
zero_gradis actually called is different between Lightning and pure PyTorch:closure20 times andzero_gradonly 1 time whileclosure20 times andzero_grad20 times where 20 is the value oftorch.optim.LBFGS(..., max_iter=20). (because obviouslyclosurecallszero_gradinside. See the sample scripts below)As mentioned in the PyTorch docs, the closure should call
zero_grad, but the current Lightning calls it outside the closure not inside, and thus it's not working properly when using optimizers which need re-evaluation of the loss inoptimizer.step(closure).TODO
zero_gradin closure (and removezero_gradcalls outside the closure)zero_gradpositionUpdate docs to recommend using manual optimization when using a similar optimizer totorch.optim.LBFGSwhich needs reevaluation of the loss via closure.Ensure that scheduler.step is called the same number of times as optimizer.step in manual optimizationI'll disablescheduler.stepin manual optimization in another PR. cc: @carmoccaHere are the minimal code examples using BoringModel.
Lightning code
Pure PyTorch code
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃
cc: @carmocca @tchaton