Commit 71efcab

Remove autocast
1 parent 209789e commit 71efcab

File tree

1 file changed: +7 -6 lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 7 additions & 6 deletions
@@ -526,14 +526,14 @@ def main():
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.autocast(), accelerator.accumulate(text_encoder):
+            with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215

                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -542,15 +542,16 @@ def main():

                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=weight_dtype)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)

                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                # Calculate loss in fp32
+                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                 accelerator.backward(loss)

                 # Zero out the gradients for all token embeddings except the newly added
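
Taken together, the hunks drop accelerator.autocast() and manage precision by hand: the frozen models run in weight_dtype, tensors are cast to that dtype at each boundary, and the loss is upcast to fp32. The removed .detach() was redundant inside torch.no_grad(), and replacing mean([1, 2, 3]).mean() with reduction="mean" gives the same value when every sample has the same number of elements. Below is a minimal, self-contained sketch of that pattern, not code from this repository; frozen_encoder, frozen_predictor, and trainable are illustrative stand-ins for the frozen VAE/UNet and the trainable text-encoder embeddings.

# A minimal sketch of the manual mixed-precision pattern adopted above.
# The module and variable names here are illustrative stand-ins, not names
# from the training script.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype of the frozen weights; the real script derives this from the
# accelerator's mixed_precision setting ("fp16" -> torch.float16, etc.)
weight_dtype = torch.float16 if device == "cuda" else torch.float32

# Frozen modules live entirely in weight_dtype; the trainable part stays fp32.
frozen_encoder = torch.nn.Linear(8, 8).to(device, dtype=weight_dtype).requires_grad_(False)
frozen_predictor = torch.nn.Linear(8, 8).to(device, dtype=weight_dtype).requires_grad_(False)
trainable = torch.nn.Embedding(10, 8).to(device)  # fp32 parameters

pixel_values = torch.randn(4, 8, device=device)   # batch arrives in fp32
input_ids = torch.randint(0, 10, (4,), device=device)

with torch.no_grad():
    # Cast the batch to the frozen model's dtype before the forward pass,
    # mirroring batch["pixel_values"].to(dtype=weight_dtype) in the diff.
    latents = frozen_encoder(pixel_values.to(dtype=weight_dtype))

# Sample noise directly in weight_dtype, as the diff does.
noise = torch.randn(latents.shape).to(latents.device, dtype=weight_dtype)
noisy_latents = (latents + noise).to(dtype=weight_dtype)  # toy forward diffusion

# The trainable module computes in fp32; cast its output down so the frozen
# predictor sees one consistent dtype (cf. encoder_hidden_states.to(...)).
cond = trainable(input_ids).to(dtype=weight_dtype)

pred = frozen_predictor(noisy_latents + cond)

# Without autocast the forward pass may be fp16 end to end, so upcast before
# the loss: an fp16 MSE reduction can lose precision or overflow.
loss = F.mse_loss(pred.float(), noise.float(), reduction="mean")
loss.backward()
assert trainable.weight.grad is not None  # gradients reach only the fp32 params

The design choice here is to keep only the trainable parameters in fp32 and pay one explicit cast per dtype boundary, rather than letting autocast decide per-op precision; that makes the dtype of every tensor in the loop predictable.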
