-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[Flax] Add finetune Stable Diffusion #999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Very cool PR @duongna21 , amazing, will take a look soon ! |
|
The documentation is not available anymore as the PR was closed or merged. |
patil-suraj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks very good, thanks a lot for adding this example. Left some comments about dataloader and dtype.
Also, let's update the readme with an example command to show how to run these examples.
|
@patil-suraj Thanks for very helpful comments. Addressed them! |
patil-suraj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the comments, good for merge now!
|
@duongna21 Thank you for this contribution. I did have to make a change in order to get it working with TPUs. device_type = jax.devices()[0].device_kind
weight_dtype = torch.float32
if 'TPU' in device_type:
weight_dtype = jnp.float32
if args.mixed_precision == "fp16":
weight_dtype = jnp.float16
if args.mixed_precision == "bf16":
weight_dtype = jnp.bfloat16
else:
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16 |
What does this PR do?
How to run (76% faster than PyTorch example with same args on Tesla A100)
Prompt:
robotic cat with wingsWho can review?
cc @patrickvonplaten @patil-suraj