40 changes: 40 additions & 0 deletions examples/dreambooth/README.md
@@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```

### Training on an 8 GB GPU:

By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
tensors from VRAM to either CPU or NVMe, which allows training with less VRAM.

DeepSpeed needs to be enabled with `accelerate config`. During configuration,
answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
mixed precision, and both parameters and optimizer state offloaded to CPU, it's
possible to train with under 8 GB of VRAM, at the cost of requiring significantly
more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
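
For reference, the configuration described above can also be expressed programmatically. The snippet below is a minimal sketch using Accelerate's `DeepSpeedPlugin`; the interactive `accelerate config` flow remains the recommended path, and exact option names may vary between Accelerate versions:

```python
from accelerate import Accelerator, DeepSpeedPlugin

# ZeRO stage 2 with parameters and optimizer state offloaded to CPU,
# roughly matching the `accelerate config` answers described above.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
    gradient_accumulation_steps=1,
)

accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
```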

Changing the default Adam optimizer to DeepSpeed's optimized version of Adam,
`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
it requires a CUDA toolkit whose version matches the one PyTorch was built with. The 8-bit
optimizer does not seem to be compatible with DeepSpeed at the moment.
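
The swap itself can be sketched as below; `use_deepspeed_cpu_adam` is a hypothetical flag used for illustration, not an existing option of `train_dreambooth.py`:

```python
import torch


def get_optimizer_class(use_deepspeed_cpu_adam: bool):
    """Pick the optimizer class (illustrative sketch, not the script's actual flag handling)."""
    if use_deepspeed_cpu_adam:
        # Building DeepSpeed's fused CPU Adam requires a CUDA toolkit that
        # matches the version PyTorch was built against.
        from deepspeed.ops.adam import DeepSpeedCPUAdam

        return DeepSpeedCPUAdam
    return torch.optim.AdamW


# e.g. optimizer = get_optimizer_class(True)(unet.parameters(), lr=5e-6)
```

Launching the training then looks like this: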

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 --gradient_checkpointing \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800 \
--mixed_precision=fp16
```

## Inference

22 changes: 15 additions & 7 deletions examples/dreambooth/train_dreambooth.py
@@ -471,9 +471,17 @@ def collate_fn(examples):
unet, optimizer, train_dataloader, lr_scheduler
)

# Move text_encode and vae to gpu
text_encoder.to(accelerator.device)
vae.to(accelerator.device)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move text_encoder and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision,
# as these models are only used for inference; keeping weights in full precision is not required.
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -509,11 +517,11 @@ def collate_fn(examples):
with accelerator.accumulate(unet):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215

# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -539,12 +547,12 @@ def collate_fn(examples):
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")

# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
**Contributor:**

Why cast to float32 here? Do we always want to compute the loss in full precision?
cc @patrickvonplaten

**Contributor Author:**

It's mixed-precision best practice to calculate large reductions in higher precision. This computes the mean of batch_size * 4 * 64 * 64 halves, and `mse_loss` is one of the operations that would automatically be cast to fp32 under fp16 autocast. I'm not sure if it's necessary at low batch sizes, as it does seem to work without it, but it doesn't really affect memory consumption since it's only one operation, and it should give some safety.

**Contributor:**

I see, makes sense!
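
A tiny illustration of the practice described above, with made-up tensor shapes (a batch of 2 latents of size 4×64×64): the model runs in fp16, but the final MSE reduction is accumulated in fp32.

```python
import torch
import torch.nn.functional as F

# Pretend these came out of the fp16 forward pass.
noise_pred = torch.randn(2, 4, 64, 64, dtype=torch.float16)
noise = torch.randn(2, 4, 64, 64, dtype=torch.float16)

# Upcast before the large reduction so the mean over 2 * 4 * 64 * 64
# half-precision values is computed in full precision.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
```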


accelerator.backward(loss)
if accelerator.sync_gradients: