[dreambooth] low precision guard #1916
Conversation
The documentation is not available anymore as the PR was closed or merged.
@patil-suraj I added docs to the relevant CLI flags. Are there any other locations where we should update the docs around loaded model precision?
patrickvonplaten left a comment
Thanks a lot!
Co-authored-by: Patrick von Platen <[email protected]>
patil-suraj left a comment
LGTM!
No, I think this covers it :)
Am I missing something from the readme, @patil-suraj @patrickvonplaten, or is this something related to the readme? I'll try out the solution from #1817.
re: #1246
Training/fine-tuning shouldn't be done with fp16 weights[^1]. fp16 inputs are fine with AMP + gradient scaling, but fp16 weights throw an error when used with AMP + gradient scaling. We should check the dtype of the loaded model and raise an informative error before training begins.
We add a guard checking the datatype of the unet, and a second guard checking the datatype of the text encoder when the text encoder is being trained (see the sketch below).
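As a rough illustration of the kind of check described above (not the exact code in this PR; the helper name, error message, and the use of `next(model.parameters()).dtype` are assumptions), such a guard could look like this:

```python
import torch

# Illustrative error message; the actual wording in the training script may differ.
LOW_PRECISION_ERROR = (
    "Model weights should be in full float32 precision when starting training, "
    "even for mixed precision training. Half-precision copies for the forward "
    "and backward passes are handled automatically."
)

def check_full_precision(model: torch.nn.Module, name: str) -> None:
    """Raise an informative error if the model was loaded in fp16/bf16."""
    dtype = next(model.parameters()).dtype
    if dtype != torch.float32:
        raise ValueError(f"{name} was loaded as {dtype}. {LOW_PRECISION_ERROR}")

# Guard the unet unconditionally, and the text encoder only when it is trained.
check_full_precision(unet, "unet")
if args.train_text_encoder:
    check_full_precision(text_encoder, "text_encoder")
```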
Footnotes
[^1]: Precision issues arise when adding small gradient updates to fp16 weights. This is why training with AMP keeps the weights in fp32 for the gradient updates and makes a half-precision copy for the forward and backward passes.
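To make the footnote concrete, here is a minimal sketch of the standard PyTorch mixed-precision recipe (the model, shapes, and optimizer here are illustrative, not taken from this PR): the parameters stay in fp32, autocast runs the forward pass in fp16, and the gradient scaler handles loss scaling and the fp32 optimizer step.

```python
import torch

# Minimal AMP sketch: weights stay fp32; fp16 is only used for forward/backward compute.
model = torch.nn.Linear(16, 16).cuda()                      # parameters are float32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randn(8, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()   # gradients are accumulated on the fp32 parameters
scaler.step(optimizer)          # unscales, skips the step on inf/nan, updates in fp32
scaler.update()
```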