From 4e5c448e169cb0fcebd01c3c0e7653a0aadf36dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Wed, 5 Oct 2022 13:30:52 +0300 Subject: [PATCH 1/4] Support deepspeed --- examples/dreambooth/train_dreambooth.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 73db1d6fc0d7..42bea8c22140 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -490,9 +490,15 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler ) + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + # Move text_encode and vae to gpu - text_encoder.to(accelerator.device) - vae.to(accelerator.device) + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -528,11 +534,11 @@ def collate_fn(examples): with accelerator.accumulate(unet): # Convert images to latent space with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) + noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -540,11 +546,11 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype) # Get the text embedding for conditioning with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -558,12 +564,12 @@ def collate_fn(examples): loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: From 9ea00783552616ad904f8d3a6601727d5ef80aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Wed, 5 Oct 2022 17:17:37 +0300 Subject: [PATCH 2/4] Dreambooth DeepSpeed documentation --- examples/dreambooth/README.md | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index e6dbf9667e44..41575e9421ef 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -119,6 +119,43 @@ accelerate launch train_dreambooth.py \ --max_train_steps=800 ``` +### Training on a 8 GB GPU: + +By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some +tensors from VRAM to either CPU or NVME allowing to train with less VRAM. + +DeepSpeed needs to be enabled with `accelerate config`. During configuration +answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 +mixed precision and offloading both parameters and optimizer state to cpu it's +possible to train on under 8 GB VRAM with a drawback of requiring significantly +more RAM (about 25 GB). + +8-bit optimizer does not seem to be compatible with DeepSpeed at the moment. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="path-to-save-model" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --output_dir=$OUTPUT_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=1 --gradient_checkpointing \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --num_class_images=200 \ + --max_train_steps=800 \ + --mixed_precision=fp16 +``` ## Inference From 2bee486eba8003cbea1b881167a9140df1c8ab3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Sat, 8 Oct 2022 09:36:15 +0300 Subject: [PATCH 3/4] Remove unnecessary casts, documentation Due to recent commits some casts to half precision are not necessary anymore. Mention that DeepSpeed's version of Adam is about 2x faster. --- examples/dreambooth/README.md | 5 ++++- examples/dreambooth/train_dreambooth.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 4eaca09b8e4f..e8b55a167c65 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -130,7 +130,10 @@ mixed precision and offloading both parameters and optimizer state to cpu it's possible to train on under 8 GB VRAM with a drawback of requiring significantly more RAM (about 25 GB). -8-bit optimizer does not seem to be compatible with DeepSpeed at the moment. +Changing the default Adam optimizer to DeepSpeed's special version of Adam +`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling +it requires CUDA toolchain with the same version as pytorch. 8-bit optimizer +does not seem to be compatible with DeepSpeed at the moment. ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 3ac29c4243c7..82675d3b446a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -527,11 +527,11 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample From f115e8f22dd3195f62f1b38c51a2ca765284aa0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20Forst=C3=A9n?= Date: Mon, 10 Oct 2022 20:53:27 +0300 Subject: [PATCH 4/4] Review comments --- examples/dreambooth/README.md | 2 +- examples/dreambooth/train_dreambooth.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index e8b55a167c65..9ff90ea809a7 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -128,7 +128,7 @@ DeepSpeed needs to be enabled with `accelerate config`. During configuration answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16 mixed precision and offloading both parameters and optimizer state to cpu it's possible to train on under 8 GB VRAM with a drawback of requiring significantly -more RAM (about 25 GB). +more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options. Changing the default Adam optimizer to DeepSpeed's special version of Adam `deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 82675d3b446a..c3b875a5e95d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -477,7 +477,9 @@ def collate_fn(examples): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move text_encode and vae to gpu + # Move text_encode and vae to gpu. + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) @@ -519,7 +521,7 @@ def collate_fn(examples): latents = latents * 0.18215 # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype) + noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)