
Conversation

@duongna21
Contributor

What does this PR do?

  • Add a Flax example for fine-tuning Stable Diffusion.
  • Runs well on a Tesla A100 (40 GB) with a maximum batch size of 3.
  • EMA is not added because my one-line tree_map implementation of EMA made training much slower and did not show a visible improvement in the results (a sketch of that update follows this list).
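For reference, a minimal sketch of the tree_map-based EMA update mentioned above (names and the decay value are illustrative, not taken from this PR):

import jax

def ema_update(ema_params, params, decay=0.9999):
    # Exponential moving average over two matching pytrees of parameters.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )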

How to run (76% faster than the PyTorch example with the same args on a Tesla A100)

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-pokemon-model" 

Prompt: robotic cat with wings

[image: generated sample]

Who can review?

cc @patrickvonplaten @patil-suraj

@patil-suraj patil-suraj self-assigned this Oct 26, 2022
@patil-suraj
Contributor

Very cool PR @duongna21, amazing! Will take a look soon.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 26, 2022

The documentation is not available anymore as the PR was closed or merged.

Contributor

@patil-suraj patil-suraj left a comment


Looks very good, thanks a lot for adding this example. Left some comments about the dataloader and dtype.

Also, let's update the README with an example command to show how to run these examples.

@duongna21
Contributor Author

@patil-suraj Thanks for the very helpful comments. Addressed them!

Contributor

@patil-suraj patil-suraj left a comment


Thanks for addressing the comments, good for merge now!

@patil-suraj patil-suraj merged commit abe0582 into huggingface:main Oct 27, 2022
@duongna21 duongna21 deleted the add-finetune-sd-flax branch October 27, 2022 14:19
@entrpn
Contributor

entrpn commented Oct 27, 2022

@duongna21 Thank you for this contribution. I did have to make the following change to get it working on TPUs:

import jax
import jax.numpy as jnp
import torch

# Choose JAX dtypes when running on TPU and torch dtypes otherwise;
# args.mixed_precision comes from the training script's CLI flags.
device_type = jax.devices()[0].device_kind

weight_dtype = torch.float32
if "TPU" in device_type:
    weight_dtype = jnp.float32
    if args.mixed_precision == "fp16":
        weight_dtype = jnp.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = jnp.bfloat16
else:
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
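As a hedged illustration (not part of this PR), a dtype selected this way could then be applied to a Flax params pytree on the TPU/JAX path; cast_params is a hypothetical helper:

import jax
import jax.numpy as jnp

def cast_params(params, dtype=jnp.bfloat16):
    # Cast every floating-point leaf of a params pytree to the chosen dtype.
    return jax.tree_util.tree_map(
        lambda p: p.astype(dtype) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )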

@duongna21
Contributor Author

@entrpn Thanks a lot! This will be fixed in #1038.
