
Conversation

@Ttl (Contributor) commented Sep 30, 2022

Cast frozen modules to fp16/bf16 when using mixed precision. Add a gradient checkpointing command line option.

Training previously OOMed on my GPU with 8 GB of VRAM. With these changes and --mixed_precision=fp16 --gradient_checkpointing, VRAM use is 6341 MB and the results look good.
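
For reference, a minimal sketch of what the two changes amount to in the example script (variable names follow the training script; this is the idea, not the exact diff):

import torch

# Pick the reduced-precision dtype matching the accelerate setting.
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# vae and unet are frozen, so no fp32 master copies are needed in VRAM;
# move them to the device directly in the reduced-precision dtype.
vae.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=weight_dtype)

# Trade compute for memory by recomputing unet activations in the backward pass.
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()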

@HuggingFaceDocBuilderDev commented Sep 30, 2022

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

  text_encoder.train()
  for step, batch in enumerate(train_dataloader):
-     with accelerator.accumulate(text_encoder):
+     with accelerator.autocast(), accelerator.accumulate(text_encoder):
Contributor

What would be necessary in order for this to work without autocast?

There's some concern about its use: #511

Contributor Author

Speed is about 10% better without autocast, and VRAM use decreases to 5500 MB. It does need some additional casts, and I don't think it's quite functionally identical, since some operations that would be autocast to fp32 are computed in fp16. Results still look fine though, so it might be okay. The smaller VRAM use also makes it possible to increase the batch size or disable gradient checkpointing for an even bigger speedup.
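
For context, the extra casts roughly look like the sketch below inside the training loop (assuming the usual variables from the example script: batch, noise_scheduler, weight_dtype; this is an illustration, not the exact diff):

import torch
import torch.nn.functional as F

# Without autocast, inputs to the frozen half-precision modules are cast manually.
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# The text encoder holds the trainable embedding; its output is cast to the
# unet's dtype before the half-precision forward pass.
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

# Compute the loss in fp32 for numerical stability.
loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")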

Contributor

I checked a bit. The main problem is

accelerator.backward(loss)

If that can be done in fp16, it should work without autocast.

@keturn (Contributor) commented Oct 2, 2022

@isamu-isozaki is this the approach you've been using?

@isamu-isozaki (Contributor)

@keturn Pretty much! I didn't know about the autocast functionality, so I manually moved most of the parts between CPU and CUDA. The code is here. One thing I remember is that with 6 GB of VRAM it will OOM before that accelerator.accumulate part, so most models are better moved to the CPU.
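
(Roughly, that manual placement looks like the sketch below; this is an illustration of the idea, not the linked code.)

import torch

# Keep the frozen VAE on the CPU and only move the much smaller latents to the GPU.
vae.to("cpu")
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device)

with torch.no_grad():
    pixel_values = batch["pixel_values"].to("cpu", dtype=torch.float32)
    latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
    latents = latents.to(accelerator.device, dtype=weight_dtype)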

@isamu-isozaki (Contributor)

Hi! Great PR, @Ttl @keturn! It doesn't fit in 6 GB of VRAM as it is now, but once I did this

slice_size = unet.config.attention_head_dim // 2
unet.set_attention_slice(slice_size)

it fits. Thanks for this! Now my training will get way better.

@isamu-isozaki (Contributor)

@keturn Sorry, on second thought this is way different from my approach, but it's way better too!

@keturn (Contributor) commented Oct 3, 2022

If you add revision="fp16" to from_pretrained, do you still have to do the conversions to weight_dtype?
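
(For reference, loading from the half-precision weights branch looks roughly like this; the model id is just the usual example, and whether it removes the need for the later casts is exactly the open question.)

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", revision="fp16", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", revision="fp16", torch_dtype=torch.float16
)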

@patrickvonplaten (Contributor)

@patil-suraj can you take a look here?

@Ttl (Contributor, Author) commented Oct 4, 2022

I'm using locally saved weights and adding revision="fp16" doesn't seem to do anything in that case.

@isamu-isozaki (Contributor)

I tried revision="fp16" and got OOM for some reason. I will double-check later today.

@patil-suraj (Contributor) left a comment

Thanks a lot for the PR! As explained in the comments, we shouldn't cast all weights to half precision unless specified by the user, and autocast should not be used by default since it won't allow full-precision training.

Also, please note that the example scripts are just examples: they show how to do a certain task in a simple and easy way. For more customization, it is recommended that users modify the script themselves as they need. This helps keep the script simple so anyone can understand and adapt it.

If you just pass mixed_precision="fp16", accelerate should enable mixed precision without any code changes.

I'm not in favor of this. Hope you understand, thanks a lot!

Comment on lines +481 to +482
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
Contributor

The unet is not trained in textual inversion, so gradient checkpointing here is not necessary, as no grads are computed for it.

Comment on lines +484 to +488
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
Contributor

This should be enabled by a flag; we can't just assume the user wants to cast the weights to half precision. Also, in mixed-precision training the weights are usually not cast to half precision; only the forward pass runs in half precision.
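
(An opt-in version might look like the sketch below, within the existing script; the flag name is hypothetical and not an existing option.)

parser.add_argument(
    "--cast_frozen_weights",
    action="store_true",
    help="Cast the frozen vae/unet to the mixed-precision dtype to save VRAM.",
)

# Later, after the Accelerator has been created:
if args.cast_frozen_weights and args.mixed_precision in ("fp16", "bf16"):
    weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)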

  text_encoder.train()
  for step, batch in enumerate(train_dataloader):
-     with accelerator.accumulate(text_encoder):
+     with accelerator.autocast(), accelerator.accumulate(text_encoder):
Contributor

This will always do the training in half precision; what if the user wants to do fp32 training? We should not put autocast directly here.

output_states = ()

for resnet, attn in zip(self.resnets, self.attentions):
    if self.training and self.gradient_checkpointing:
Contributor

This should not be removed; gradient checkpointing is only required during training.

@patil-suraj closed this Oct 5, 2022
@Ttl (Contributor, Author) commented Oct 5, 2022

The idea behind casting the weights of the non-trained nets is that without it the fp32 weights are transferred to VRAM even when training in fp16. Since they are not trained, we don't need to keep an fp32 copy of them in VRAM.

self.training is controlled by the module's train() or eval() call. Since in this case we have set the unet to eval(), gradient checkpointing is not enabled unless the self.training check is removed. Gradient checkpointing is useful here because we still need to store activations in the unet for the backward pass, as it sits between our trainable weights and the loss calculation. I checked that enabling it saves 1080 MB of memory.
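
(In other words, something along these lines; a sketch of the interaction, not the diff itself:)

unet.requires_grad_(False)  # frozen: no parameter gradients
unet.eval()                 # sets unet.training = False

unet.enable_gradient_checkpointing()
# Inside the unet blocks the checkpointed path is guarded by
#     if self.training and self.gradient_checkpointing:
# so with the unet in eval() the checkpointing never actually runs, even though
# the unet activations are still needed for the backward pass from the loss to
# the trainable text-encoder embedding.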

I can maintain a local copy of the script that casts the non-trained weights to fp16, but it would be nice if the gradient checkpointing changes were merged. Would you be fine with that?

@patil-suraj (Contributor)

> Gradient checkpointing is useful here because we still need to store activations in the unet for the backward pass, as it sits between our trainable weights and the loss calculation. I checked that enabling it saves 1080 MB of memory.

That's a really good observation! Sorry, I rushed the review a bit. In this case keeping the gradient checkpointing changes makes sense, let me try it quickly and get back to you.

Thanks a lot!

@patil-suraj reopened this Oct 5, 2022
@patil-suraj (Contributor)

Also pinging @patrickvonplaten and @anton-l. Are the activations stored even when grads are disabled for the model?

Ttl added 3 commits October 5, 2022 16:41
Cast frozen modules to fp16/bf16 when using mixed precision.
Add gradient checkpoint command line option.
@Ttl (Contributor, Author) commented Oct 5, 2022

I added a commit that removes the autocast. It should work with fp32 and bf16 too, but I can't test those on my GPU. This PR does have a side effect: it saves fp16 weights for the unet and vae, since their fp32 weights were discarded when training in fp16. If you prefer, I can remove the training changes and keep only the gradient checkpointing change.

@patrickvonplaten (Contributor)

I think this PR is currently blocked by:

  • gradient_checkpointing being a bit flaky when the model is set to train mode
  • not being able to pass dropout down

cc @patil-suraj

@Thomas-MMJ

Any progress on the blockers?

@patrickvonplaten (Contributor)

@patil-suraj could you have another look here?

@patil-suraj (Contributor) left a comment

Sorry to only come back to this now!

  • The output of a gradient-checkpointing-enabled model in train mode is deterministic when dropout is zero (which is the default); I ran a few tests to confirm this.
  • We have dropout set to 0 by default and need to allow passing it to the model, but this could be added later.

So there is no blocker for this PR now. @Ttl, could you please adapt the PR to put the unet in train mode and not modify the gradient checkpointing parts? Then I'll open a follow-up PR to allow passing dropout to the model, so that we can set it to 0 and always make sure we have deterministic output from the unet.
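
(A sketch of that adaptation, assuming the example script's variable names:)

unet.requires_grad_(False)  # still frozen; no grads are computed for it

if args.gradient_checkpointing:
    # Put the unet in train mode so the `self.training and self.gradient_checkpointing`
    # guard passes; with dropout at 0 (the default) the outputs stay deterministic.
    unet.train()
    unet.enable_gradient_checkpointing()
    # The text encoder (a transformers model) can also checkpoint, if desired.
    text_encoder.gradient_checkpointing_enable()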

@patrickvonplaten (Contributor)

cc @patil-suraj here - could you maybe pick up the PR if the author doesn't reply anymore, so it doesn't get forgotten?

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

The github-actions bot added the "stale" label (Issues that haven't received updates) on Dec 10, 2022
@patrickvonplaten (Contributor)

cc @patil-suraj - could you maybe post some instructions here on how to proceed? Then someone else could pick it up

@patil-suraj (Contributor)

Sorry for being late again, I've posted instructions in this comment #687 (review)

@Ttl LMK if you are busy, then I'll make the necessary changes and merge :)

@Ttl (Contributor, Author) commented Dec 28, 2022

It's been quite a while since I last looked at this code, and I haven't used textual inversion much anymore. Feel free to make the necessary changes to get it merged if you want to.

@Ttl closed this Dec 28, 2022
@patil-suraj (Contributor)

Thanks, will open a PR then :)

PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
This commit simplifies the code to identify the model name for a
particular set of flags. This is achieved by introducing a json file
that stores the model names information. The models are uploaded in
gcloud with these names.

Signed-off-by: Gaurav Shukla <[email protected]>