Skip to content

Commit abe0582

Browse files
authored
[Flax] Add finetune Stable Diffusion (#999)
* [Flax] Add finetune Stable Diffusion * temporary fix * drop_last and seed * add dtype for mixed precision training * style * Add Flax example
1 parent 3be9fa9 commit abe0582

File tree

2 files changed

+594
-0
lines changed

2 files changed

+594
-0
lines changed

examples/text_to_image/README.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ accelerate launch train_text_to_image.py \
6262
--output_dir="sd-pokemon-model"
6363
```
6464

65+
Or use the Flax implementation if you need a speedup
66+
67+
```bash
68+
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
69+
export dataset_name="lambdalabs/pokemon-blip-captions"
70+
71+
python train_text_to_image_flax.py \
72+
--pretrained_model_name_or_path=$MODEL_NAME \
73+
--dataset_name=$dataset_name \
74+
--resolution=512 --center_crop --random_flip \
75+
--train_batch_size=1 \
76+
--mixed_precision="fp16" \
77+
--max_train_steps=15000 \
78+
--learning_rate=1e-05 \
79+
--max_grad_norm=1 \
80+
--output_dir="sd-pokemon-model"
81+
```
82+
6583

6684
To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
6785
If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script.
@@ -86,6 +104,24 @@ accelerate launch train_text_to_image.py \
86104
--output_dir="sd-pokemon-model"
87105
```
88106

107+
Or use the Flax implementation if you need a speedup
108+
109+
```bash
110+
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
111+
export TRAIN_DIR="path_to_your_dataset"
112+
113+
python train_text_to_image_flax.py \
114+
--pretrained_model_name_or_path=$MODEL_NAME \
115+
--train_data_dir=$TRAIN_DIR \
116+
--resolution=512 --center_crop --random_flip \
117+
--train_batch_size=1 \
118+
--mixed_precision="fp16" \
119+
--max_train_steps=15000 \
120+
--learning_rate=1e-05 \
121+
--max_grad_norm=1 \
122+
--output_dir="sd-pokemon-model"
123+
```
124+
89125
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
90126

91127

0 commit comments

Comments
 (0)