Commit a23ad87

[Flax] Add Textual Inversion (#880)
* add textual inversion flax
* make style
* make style
* replicate vae and unet params
* make style
* minor
* save after end of training
* style
* Temporary fix
  Co-authored-by: Suraj Patil <[email protected]>
* Add Flax instruction
  Co-authored-by: Suraj Patil <[email protected]>
1 parent d3d22ce commit a23ad87

2 files changed: +647 -0 lines changed

examples/textual_inversion/README.md

Lines changed: 18 additions & 0 deletions
@@ -68,6 +68,24 @@ accelerate launch textual_inversion.py \
 
 A full training run takes ~1 hour on one V100 GPU.
 
+If you want to speed it up even more, Flax implementation is available:
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="path-to-dir-containing-images"
+
+python textual_inversion_flax.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATA_DIR \
+  --learnable_property="object" \
+  --placeholder_token="<cat-toy>" --initializer_token="toy" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --max_train_steps=3000 \
+  --learning_rate=5.0e-04 --scale_lr \
+  --output_dir="textual_inversion_cat"
+```
+It should be at least 70% faster than the PyTorch script with the same configuration.
 
 ### Inference
 
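Note on `--scale_lr`: the README command passes `--learning_rate=5.0e-04` together with `--scale_lr`, so the learning rate actually used may differ from the value on the command line. The sketch below assumes the flag multiplies the base learning rate by the total batch size (per-device batch size times number of local devices). That rule is an assumption, since `textual_inversion_flax.py` is not shown in this hunk. `effective_learning_rate` and its parameters are hypothetical names used only for illustration.

```python
def effective_learning_rate(
    base_lr: float,
    train_batch_size: int,
    num_devices: int,
    scale_lr: bool,
) -> float:
    """Learning rate after optional scaling.

    ASSUMPTION: --scale_lr multiplies the base rate by the total batch size
    (train_batch_size * num_devices). Check the script before relying on it.
    """
    if not scale_lr:
        return base_lr
    total_train_batch_size = train_batch_size * num_devices
    return base_lr * total_train_batch_size


# Values taken from the README command:
#   --learning_rate=5.0e-04 --train_batch_size=1 --scale_lr
# The device counts (1 and 8) are illustrative, not from the commit.
for devices in (1, 8):
    lr = effective_learning_rate(5.0e-04, 1, devices, scale_lr=True)
    print(f"{devices} device(s): effective learning rate = {lr:g}")
```

If that assumption holds, the same command gives a larger effective learning rate on a multi-device host than on a single device, which matters when comparing runs across hardware.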
