Optimize VRAM use in textual inversion training #687
Conversation
```diff
     text_encoder.train()
     for step, batch in enumerate(train_dataloader):
-        with accelerator.accumulate(text_encoder):
+        with accelerator.autocast(), accelerator.accumulate(text_encoder):
```
What would be necessary in order for this to work without autocast?
There's some concern about its use: #511
Speed is about 10% better without autocast and VRAM use decreases to 5500 MB. It does need some additional casts, and I don't think it's quite identical functionally, since some operations that would be autocast to fp32 are computed in fp16. Results still look fine though, so it might be okay. The smaller VRAM use also enables increasing the batch size or disabling gradient checkpointing for an even bigger speed-up.
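(For illustration, a generic PyTorch sketch of the kind of extra cast this needs, not the training script itself: a layer whose weights were cast to fp16 rejects fp32 inputs, so activations coming from the fp32 text encoder have to be cast by hand.)

```python
import torch

# Generic sketch (requires CUDA): once autocast is removed, an fp16 module no
# longer accepts fp32 activations, so the caller inserts the cast explicitly.
frozen_fp16_layer = torch.nn.Linear(16, 16).to("cuda", dtype=torch.float16)
fp32_activations = torch.randn(2, 16, device="cuda")  # e.g. output of an fp32 text encoder

# frozen_fp16_layer(fp32_activations) would raise a dtype-mismatch error;
# the explicit cast replaces what autocast used to do implicitly:
out = frozen_fp16_layer(fp32_activations.to(torch.float16))
```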
I checked a bit. The main problem is `accelerator.backward(loss)`. If that can be done in fp16, it should work without autocast.
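For context, an fp16 backward pass is usually paired with loss scaling so small gradients don't underflow to zero; in plain PyTorch that is the `GradScaler` pattern below. This is a generic sketch, not accelerate's internals, which also keep the trainable parameters in fp32.

```python
import torch

# Standard mixed-precision pattern (requires CUDA): fp32 parameters, a forward
# pass under autocast, and a GradScaler around backward/step.
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()   # forward runs in fp16 where safe

scaler.scale(loss).backward()        # backward on the scaled loss
scaler.step(optimizer)               # unscales grads, skips the step on inf/nan
scaler.update()
```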
@isamu-isozaki is this the approach you've been using?
@keturn sorry, on second thought this is way different from my approach, but it's way better too!
If you add
@patil-suraj can you take a look here?
I'm using locally saved weights and adding
I tried revision fp16 and got an OOM for some reason. Will double-check later today.
Thanks a lot for the PR! As explained in the comments, we shouldn't cast all weights to half precision unless the user asks for it, and autocast should not be used by default, as it won't allow full-precision training.
Also, please note that the example scripts are just examples: they show how to do a certain task in a simple and easy way. For more customization, it is recommended that users modify the script on their own as they need. This helps keep the script simple, so anyone can understand and adapt it if needed.
If you just pass `mixed_precision="fp16"`, accelerate should enable mixed precision without any code changes.
I'm not in favor of this. Hope you understand, thanks a lot!
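(For reference, a minimal sketch of that path, assuming a CUDA device and assuming the script forwards its `--mixed_precision` argument to `Accelerator` as usual.)

```python
from accelerate import Accelerator

# Minimal sketch (needs a CUDA device for fp16): with mixed_precision="fp16",
# accelerate wraps the forward pass of prepared models in autocast and applies
# loss scaling inside accelerator.backward(), so the training loop itself
# needs no explicit autocast.
accelerator = Accelerator(mixed_precision="fp16")
```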
```python
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
```
The unet is not trained in textual inversion, so gradient checkpointing here is not necessary, as no grads are computed for it.
```python
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
```
This should be enabled by a flag; we can't always assume the user wants to cast weights to half precision. Also, in mixed-precision training the weights are usually not cast to half precision; only the forward pass runs in half precision.
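A sketch of the flag-gated version being suggested; the `--cast_frozen_weights` name is hypothetical, purely to illustrate the idea:

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--mixed_precision", choices=["no", "fp16", "bf16"], default="no")
# Hypothetical flag, off by default, so plain fp32 training stays the default.
parser.add_argument("--cast_frozen_weights", action="store_true")
args = parser.parse_args()

weight_dtype = torch.float32
if args.cast_frozen_weights:
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
```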
```diff
     text_encoder.train()
     for step, batch in enumerate(train_dataloader):
-        with accelerator.accumulate(text_encoder):
+        with accelerator.autocast(), accelerator.accumulate(text_encoder):
```
This will always do the training in half precision; what if the user wants to do fp32 training? We should not put autocast directly here.
```python
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
```
This should not be removed; gradient checkpointing is only required during training.
The idea with casting the weights of the non-trained nets is that, without it, fp32 weights are transferred to VRAM even when training in fp16. Since they are not trained, we don't need to keep an fp32 copy of them in VRAM.
I can maintain a copy of the script that casts non-trained weights to fp16 locally, but it would be nice if the gradient checkpointing changes were merged. Would you be fine with that?
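A minimal sketch of that idea (the checkpoint name and loading boilerplate are only for illustration; the script builds these models from its own arguments):

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

# Example checkpoint; the script takes this from --pretrained_model_name_or_path.
model_id = "CompVis/stable-diffusion-v1-4"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

device = "cuda"
weight_dtype = torch.float16  # e.g. from --mixed_precision=fp16

# The frozen, untrained modules get no optimizer state and no grads, so there
# is no need to keep an fp32 copy of their weights in VRAM.
vae.requires_grad_(False)
unet.requires_grad_(False)
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)

# The text encoder holds the trainable token embedding, so it stays in fp32.
text_encoder.to(device)
```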
That's a really good observation! Sorry, I rushed the review a bit. In this case keeping the gradient checkpointing changes makes sense; let me try it quickly and get back to you. Thanks a lot!
Also pinging @patrickvonplaten and @anton-l. Are the activations stored even when grads are disabled for the model?
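(A quick generic check of that point, independent of the training script: autograd keeps the graph, and the saved activations, as soon as the input requires grad, even if every parameter of the module is frozen. That appears to be the situation here, since the unet consumes embeddings coming from the trainable token embedding.)

```python
import torch

frozen = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
frozen.requires_grad_(False)

x_plain = torch.randn(1, 8)
x_needs_grad = torch.randn(1, 8, requires_grad=True)

print(frozen(x_plain).grad_fn)       # None: no graph, no activations kept
print(frozen(x_needs_grad).grad_fn)  # <AddmmBackward0>: graph (and activations) are kept
```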
Cast frozen modules to fp16/bf16 when using mixed precision. Add gradient checkpoint command line option.
I added a commit that removes the autocast. It should work with fp32 and bf16 too, but I can't test it on my GPU. This PR does have a side effect: it saves fp16-quantized weights for the unet and vae, since the fp32 weights for those are discarded when training in fp16. If you prefer, I can remove the training changes and only keep the gradient checkpointing change.
I think this PR is currently blocked by:
cc @patil-suraj
Any progress on the blockers?
@patil-suraj could you have another look here?
patil-suraj left a comment
Sorry to only come back to this now!
- The output of a `gradient_checkpointing`-enabled model in `train` mode is deterministic when dropout is zero (which is the default); I ran a few tests to confirm this.
- We have dropout set to 0 by default and need to allow passing it to the model. But this could be added later.
So there is no blocker for this PR now. @Ttl, could you please adapt the PR to put the unet in train mode and not modify the gradient checkpointing parts? Then I'll open a follow-up PR to allow passing dropout to the model, which we can keep at 0 to always make sure we have deterministic output from the unet.
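A short sketch of that suggestion (checkpoint name just for illustration): the unet stays frozen, but putting it in train mode keeps the existing `if self.training and self.gradient_checkpointing:` branch active.

```python
from diffusers import UNet2DConditionModel

# Illustrative checkpoint; the script takes this from --pretrained_model_name_or_path.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# No grads are computed for the unet's own weights, but with the module in
# train mode its blocks take the gradient-checkpointing path, so activations
# are recomputed during backward instead of being stored.
unet.requires_grad_(False)
unet.enable_gradient_checkpointing()
unet.train()
```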
cc @patil-suraj here - could you maybe pick up the PR if the author doesn't reply anymore, so it doesn't get forgotten?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
cc @patil-suraj - could you maybe post some instructions here on how to proceed? Then someone else could pick it up.
Sorry for being late again; I've posted instructions in this comment: #687 (review). @Ttl, let me know if you are busy, then I'll make the necessary changes and merge :)
It's been quite a while since I last looked at this code, and I haven't used textual inversion much anymore. Feel free to make the necessary changes to get it merged if you want to.
Thanks, will open a PR then :)
Cast frozen modules to fp16/bf16 when using mixed precision. Add gradient checkpoint command line option.
This OOMed before on my 8 GB VRAM GPU. With these changes and `--mixed_precision=fp16 --gradient_checkpointing`, VRAM use is 6341 MB and the results look good.