Skip to content

Commit 1eddce8

Browse files
patil-suraj and Prathik Rao
authored and committed
[dreambooth] allow fine-tuning text encoder (huggingface#883)
* allow fine-tuning text encoder * fix a few things * update readme
1 parent ec7a08d commit 1eddce8

File tree

2 files changed

+76
-12
lines changed

2 files changed

+76
-12
lines changed

examples/dreambooth/README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,39 @@ accelerate launch train_dreambooth.py \
160160
--mixed_precision=fp16
161161
```
162162

163+
### Fine-tune text encoder with the UNet.
164+
165+
The script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.
166+
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
167+
168+
___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB of VRAM.___
169+
170+
```bash
171+
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
172+
export INSTANCE_DIR="path-to-instance-images"
173+
export CLASS_DIR="path-to-class-images"
174+
export OUTPUT_DIR="path-to-save-model"
175+
176+
accelerate launch train_dreambooth.py \
177+
--pretrained_model_name_or_path=$MODEL_NAME \
178+
--train_text_encoder \
179+
--instance_data_dir=$INSTANCE_DIR \
180+
--class_data_dir=$CLASS_DIR \
181+
--output_dir=$OUTPUT_DIR \
182+
--with_prior_preservation --prior_loss_weight=1.0 \
183+
--instance_prompt="a photo of sks dog" \
184+
--class_prompt="a photo of dog" \
185+
--resolution=512 \
186+
--train_batch_size=1 \
187+
--use_8bit_adam \
188+
--gradient_checkpointing \
189+
--learning_rate=2e-6 \
190+
--lr_scheduler="constant" \
191+
--lr_warmup_steps=0 \
192+
--num_class_images=200 \
193+
--max_train_steps=800
194+
```
195+
163196
## Inference
164197

165198
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.

examples/dreambooth/train_dreambooth.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import itertools
23
import math
34
import os
45
from pathlib import Path
@@ -100,6 +101,7 @@ def parse_args():
100101
parser.add_argument(
101102
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
102103
)
104+
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
103105
parser.add_argument(
104106
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
105107
)
@@ -320,6 +322,15 @@ def main():
320322
logging_dir=logging_dir,
321323
)
322324

325+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
326+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
327+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
328+
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
329+
raise ValueError(
330+
"Gradient accumulation is not supported when training the text encoder in distributed training. "
331+
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
332+
)
333+
323334
if args.seed is not None:
324335
set_seed(args.seed)
325336

@@ -385,8 +396,14 @@ def main():
385396
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
386397
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
387398

399+
vae.requires_grad_(False)
400+
if not args.train_text_encoder:
401+
text_encoder.requires_grad_(False)
402+
388403
if args.gradient_checkpointing:
389404
unet.enable_gradient_checkpointing()
405+
if args.train_text_encoder:
406+
text_encoder.gradient_checkpointing_enable()
390407

391408
if args.scale_lr:
392409
args.learning_rate = (
@@ -406,8 +423,11 @@ def main():
406423
else:
407424
optimizer_class = torch.optim.AdamW
408425

426+
params_to_optimize = (
427+
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
428+
)
409429
optimizer = optimizer_class(
410-
unet.parameters(), # only optimize unet
430+
params_to_optimize,
411431
lr=args.learning_rate,
412432
betas=(args.adam_beta1, args.adam_beta2),
413433
weight_decay=args.adam_weight_decay,
@@ -467,9 +487,14 @@ def collate_fn(examples):
467487
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
468488
)
469489

470-
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
471-
unet, optimizer, train_dataloader, lr_scheduler
472-
)
490+
if args.train_text_encoder:
491+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
492+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
493+
)
494+
else:
495+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
496+
unet, optimizer, train_dataloader, lr_scheduler
497+
)
473498

474499
weight_dtype = torch.float32
475500
if args.mixed_precision == "fp16":
@@ -480,8 +505,9 @@ def collate_fn(examples):
480505
# Move text_encoder and vae to gpu.
481506
# For mixed precision training we cast the text_encoder and vae weights to half-precision
482507
# as these models are only used for inference, keeping weights in full precision is not required.
483-
text_encoder.to(accelerator.device, dtype=weight_dtype)
484508
vae.to(accelerator.device, dtype=weight_dtype)
509+
if not args.train_text_encoder:
510+
text_encoder.to(accelerator.device, dtype=weight_dtype)
485511

486512
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
487513
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -516,9 +542,8 @@ def collate_fn(examples):
516542
for step, batch in enumerate(train_dataloader):
517543
with accelerator.accumulate(unet):
518544
# Convert images to latent space
519-
with torch.no_grad():
520-
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
521-
latents = latents * 0.18215
545+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
546+
latents = latents * 0.18215
522547

523548
# Sample noise that we'll add to the latents
524549
noise = torch.randn_like(latents)
@@ -532,8 +557,7 @@ def collate_fn(examples):
532557
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
533558

534559
# Get the text embedding for conditioning
535-
with torch.no_grad():
536-
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
560+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
537561

538562
# Predict the noise residual
539563
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -556,7 +580,12 @@ def collate_fn(examples):
556580

557581
accelerator.backward(loss)
558582
if accelerator.sync_gradients:
559-
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
583+
params_to_clip = (
584+
itertools.chain(unet.parameters(), text_encoder.parameters())
585+
if args.train_text_encoder
586+
else unet.parameters()
587+
)
588+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
560589
optimizer.step()
561590
lr_scheduler.step()
562591
optimizer.zero_grad()
@@ -578,7 +607,9 @@ def collate_fn(examples):
578607
# Create the pipeline using the trained modules and save it.
579608
if accelerator.is_main_process:
580609
pipeline = StableDiffusionPipeline.from_pretrained(
581-
args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
610+
args.pretrained_model_name_or_path,
611+
unet=accelerator.unwrap_model(unet),
612+
text_encoder=accelerator.unwrap_model(text_encoder),
582613
)
583614
pipeline.save_pretrained(args.output_dir)
584615

0 commit comments

Comments
 (0)