@@ -471,9 +471,17 @@ def collate_fn(examples):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    # Move text_encode and vae to gpu
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to GPU.
+    # For mixed-precision training we cast the text_encoder and vae weights to half precision,
+    # as these models are only used for inference; keeping weights in full precision is not required.
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
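This hunk keeps only the trainable unet in full precision; the frozen text_encoder and vae are run purely for inference, so their weights can be stored in half precision to save memory. A minimal standalone sketch of the same pattern (the cast_frozen helper and DTYPE_MAP name are hypothetical, not part of the script):

import torch

# Map the script's --mixed_precision flag to a torch dtype; fp32 is the fallback.
DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16}

def cast_frozen(model: torch.nn.Module, mixed_precision: str, device: str) -> torch.nn.Module:
    # Frozen, inference-only modules can live in half precision on the device.
    weight_dtype = DTYPE_MAP.get(mixed_precision, torch.float32)
    return model.to(device, dtype=weight_dtype)

frozen_model = cast_frozen(torch.nn.Linear(4, 4), "fp16", "cpu")
print(next(frozen_model.parameters()).dtype)  # torch.float16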
@@ -509,11 +517,11 @@ def collate_fn(examples):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     # Sample a random timestep for each image
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
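Two details in this hunk: the pixel batch must be cast to the vae's (now possibly half-precision) dtype before encoding, and torch.randn_like replaces randn(...).to(...) because it inherits the latents' shape, device, and dtype in a single call. A self-contained illustration, with random tensors standing in for the real batch and latents:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16

# The dataloader yields fp32 pixels; match the vae weights before encoding.
pixel_values = torch.rand(4, 3, 512, 512, device=device).to(dtype=weight_dtype)

# Stand-in for vae.encode(...).latent_dist.sample(): 4 channels, 8x downsampled.
latents = torch.randn(4, 4, 64, 64, device=device, dtype=weight_dtype) * 0.18215

noise = torch.randn_like(latents)  # same shape, device, and dtype as latents
assert noise.shape == latents.shape and noise.dtype == latents.dtype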
@@ -539,12 +547,12 @@ def collate_fn(examples):
                     loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
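Upcasting the predictions and targets to float32 before the MSE keeps the loss numerically stable when the forward pass produced fp16 activations; note also that reduction="mean" over the whole tensor is equivalent to the per-sample .mean([1, 2, 3]).mean() it replaces whenever all samples have the same number of elements. A small check of the upcast, using random tensors in place of the model outputs:

import torch
import torch.nn.functional as F

noise_pred = torch.randn(4, 4, 64, 64, dtype=torch.float16)
noise = torch.randn(4, 4, 64, 64, dtype=torch.float16)

# Reducing directly in fp16 accumulates rounding error in the running sum;
# upcasting first performs the squared-error mean in fp32.
loss_fp16 = F.mse_loss(noise_pred, noise, reduction="mean")
loss_fp32 = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
print(loss_fp16.dtype, loss_fp32.dtype)  # torch.float16 torch.float32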