
Conversation

@patil-suraj (Contributor)

This PR is a continuation of #687 by @Ttl. It makes the following changes (a rough sketch follows the list):

  • Adds a --gradient_checkpointing argument that enables gradient checkpointing for the text_encoder and unet. For this, the unet is kept in train mode; since its dropout is 0, this should not affect the results.
  • Removes freeze_params and instead uses requires_grad_ to disable gradients.
  • Removes the redundant call to accelerator.wait_for_everyone() after every epoch; it is only needed before saving the pipeline.
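
A minimal sketch of what these changes might look like in the textual-inversion training script; the checkpoint name, argument parsing, and exact freezing granularity are illustrative rather than copied from this PR:

```python
import argparse

from accelerate import Accelerator
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", action="store_true")
args = parser.parse_args()

accelerator = Accelerator()

# Example checkpoint; any Stable Diffusion checkpoint with these subfolders works.
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# requires_grad_ replaces the previous freeze_params helper.
vae.requires_grad_(False)
unet.requires_grad_(False)
# Textual inversion only trains the token embeddings, so the rest of the
# text encoder is frozen as well.
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

if args.gradient_checkpointing:
    # The unet stays in train mode so checkpointing is active; its dropout
    # is 0, so this does not change the results.
    unet.train()
    unet.enable_gradient_checkpointing()
    text_encoder.gradient_checkpointing_enable()

# ... training loop ...

# wait_for_everyone() is only needed once, right before saving the pipeline,
# not after every epoch.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    pass  # save the pipeline / learned embeddings here
```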

vae.to(accelerator.device, dtype=weight_dtype)

# Keep vae and unet in eval mode as we don't train these
vae.eval()
@patil-suraj (Contributor, Author)

vae is already in eval mode by default when loaded using from_pretrained.

@HuggingFaceDocBuilderDev commented Dec 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj (Contributor, Author)

Did a trial run and verified that keeping the unet in train mode when gradient checkpointing is enabled does not affect the results.

@pcuenca (Member) left a comment


Looks very nice!

@patil-suraj merged commit 9ea7052 into main on Dec 29, 2022
@patil-suraj deleted the ti-update branch on December 29, 2022 at 14:02
@patrickvonplaten (Contributor)

Thanks!
