-
Notifications
You must be signed in to change notification settings - Fork 6.5k
DreamBooth DeepSpeed support for under 8 GB VRAM training #735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -471,9 +471,17 @@ def collate_fn(examples): | |
| unet, optimizer, train_dataloader, lr_scheduler | ||
| ) | ||
|
|
||
| # Move text_encode and vae to gpu | ||
| text_encoder.to(accelerator.device) | ||
| vae.to(accelerator.device) | ||
| weight_dtype = torch.float32 | ||
| if args.mixed_precision == "fp16": | ||
| weight_dtype = torch.float16 | ||
| elif args.mixed_precision == "bf16": | ||
| weight_dtype = torch.bfloat16 | ||
|
|
||
| # Move text_encode and vae to gpu. | ||
| # For mixed precision training we cast the text_encoder and vae weights to half-precision | ||
| # as these models are only used for inference, keeping weights in full precision is not required. | ||
| text_encoder.to(accelerator.device, dtype=weight_dtype) | ||
| vae.to(accelerator.device, dtype=weight_dtype) | ||
Ttl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
| num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | ||
|
|
@@ -509,11 +517,11 @@ def collate_fn(examples): | |
| with accelerator.accumulate(unet): | ||
| # Convert images to latent space | ||
| with torch.no_grad(): | ||
| latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | ||
| latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() | ||
| latents = latents * 0.18215 | ||
|
|
||
| # Sample noise that we'll add to the latents | ||
| noise = torch.randn(latents.shape).to(latents.device) | ||
| noise = torch.randn_like(latents) | ||
| bsz = latents.shape[0] | ||
| # Sample a random timestep for each image | ||
| timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) | ||
|
|
@@ -539,12 +547,12 @@ def collate_fn(examples): | |
| loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
|
|
||
| # Compute prior loss | ||
| prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() | ||
| prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") | ||
|
|
||
| # Add the prior loss to the instance loss. | ||
| loss = loss + args.prior_loss_weight * prior_loss | ||
| else: | ||
| loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why cast to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's mixed precision best practice to calculate large reduction in higher precision. This calculates mean of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, makes sense! |
||
|
|
||
| accelerator.backward(loss) | ||
| if accelerator.sync_gradients: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.