Skip to content

Commit 47c3773

Browse files
patil-surajPrathik Rao
authored andcommitted
[dreambooth] fix applying clip_grad_norm_ (huggingface#686)
fix applying clip grad norm
1 parent cf83856 commit 47c3773

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,8 @@ def collate_fn(examples):
566566
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
567567

568568
accelerator.backward(loss)
569-
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
569+
if accelerator.sync_gradients:
570+
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
570571
optimizer.step()
571572
lr_scheduler.step()
572573
optimizer.zero_grad()

0 commit comments

Comments
 (0)