@@ -526,14 +526,14 @@ def main():
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.autocast(), accelerator.accumulate(text_encoder):
+            with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -542,15 +542,16 @@ def main():

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                # Calculate loss in fp32
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                 accelerator.backward(loss)

                 # Zero out the gradients for all token embeddings except the newly added
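
The dtype handling in this diff follows one pattern: frozen models run in a reduced-precision weight_dtype, every tensor fed to them is cast with .to(dtype=weight_dtype), and only the final loss is upcast to fp32 before the reduction. Below is a minimal, self-contained sketch of that pattern; the toy Linear/Embedding modules stand in for the real VAE, UNet, and text encoder and are purely illustrative, not the diffusers models or this script.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32

# Frozen stand-ins kept in weight_dtype; the trainable module stays in fp32.
vae = torch.nn.Linear(16, 4).to(device, dtype=weight_dtype).requires_grad_(False)
unet = torch.nn.Linear(4, 4).to(device, dtype=weight_dtype).requires_grad_(False)
text_encoder = torch.nn.Embedding(10, 4).to(device)

pixel_values = torch.randn(2, 16, device=device)        # batch arrives in fp32
input_ids = torch.randint(0, 10, (2,), device=device)

with torch.no_grad():
    # Cast the input to the frozen model's dtype before encoding.
    latents = vae(pixel_values.to(dtype=weight_dtype))

noise = torch.randn(latents.shape, device=device, dtype=weight_dtype)
noisy_latents = (latents + noise).to(dtype=weight_dtype)

# The trainable module's fp32 output is cast down before the frozen model sees it.
encoder_hidden_states = text_encoder(input_ids).to(dtype=weight_dtype)
noise_pred = unet(noisy_latents + encoder_hidden_states)

# Upcast both operands so the MSE reduction happens in fp32.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss.backward()
print(loss.item(), text_encoder.weight.grad.dtype)  # gradients accumulate in fp32
```

Upcasting only the loss keeps the heavy frozen forward passes in half precision while avoiding fp16 round-off in the reduction; the commit also simplifies the per-sample mean over [1, 2, 3] to a plain reduction="mean", which is numerically equivalent for equally sized samples.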