dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16 #1618
Conversation
The documentation is not available anymore as the PR was closed or merged.
This was already reviewed by @patil-suraj in #1567, and it looks good to me! I just suggested a minor rewording of the comment.
Co-authored-by: Pedro Cuenca <[email protected]>
Thanks @timh! @patil-suraj @pcuenca @williamberman let's not forget to remove this when accelerate is forced to be a newer version.
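For context, the guard in question is a version check around the keep_fp32_wrapper argument. A minimal sketch of that pattern, using Python's inspect module to detect whether the installed accelerate release accepts the argument; the helper name is illustrative, not necessarily what train_dreambooth.py uses:

```python
import inspect


def unwrap_model_keeping_fp32(accelerator, model):
    # Only pass keep_fp32_wrapper when the installed accelerate version
    # accepts it; older releases would raise a TypeError otherwise.
    unwrap = accelerator.unwrap_model
    if "keep_fp32_wrapper" in inspect.signature(unwrap).parameters:
        return unwrap(model, keep_fp32_wrapper=True)
    return unwrap(model)
```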
… checkpoint to avoid crash when running fp16 (huggingface#1618)
* dreambooth: fix huggingface#1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16
* dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for huggingface#1566
* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <[email protected]>
* Update examples/dreambooth/train_dreambooth.py
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
When using mixed precision and trying to save weights every N steps, I was getting this error after the first save step:

RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same

Adding keep_fp32_wrapper=True to the two unwrap_model calls in save_weights seems to fix the issue.
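A minimal sketch of what that change looks like, assuming the usual accelerate setup where unet and text_encoder were prepared with accelerator.prepare (variable names are illustrative):

```python
# Keep the fp32 forward wrapper that accelerate adds under mixed precision;
# stripping it while saving an intermediate checkpoint is what caused the
# fp16/fp32 type mismatch on the next forward pass.
unet_to_save = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
text_encoder_to_save = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
```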
Same code changes as PR #1567, but with a proper branch name now, so the merge commit is nicer :)