Commit 69de9b2

[Textual Inversion] Do not update other embeddings (#1665)
1 parent 3ce6380 commit 69de9b2

1 file changed: +8 −10 lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 8 additions & 10 deletions
@@ -548,6 +548,9 @@ def main():
     progress_bar.set_description("Steps")
     global_step = 0
 
+    # keep original embeddings as reference
+    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
+
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
@@ -585,20 +588,15 @@ def main():
                 loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
-                else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
-
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
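
For context, the removed code tried to freeze all other token embeddings by zeroing their gradients before optimizer.step(); the new code instead snapshots the embedding matrix before training and copies the frozen rows back after each step. One plausible reason for the change, not stated in the commit, is that an optimizer with decoupled weight decay such as AdamW can still move rows whose gradient is zero, so gradient masking alone does not keep them fixed. Below is a minimal, self-contained sketch of the new pattern on a toy embedding layer; the names (embedding, placeholder_id, the dummy loss) are illustrative and not taken from the training script.

import torch

# Toy vocabulary with one newly added "placeholder" token at the last index (illustrative).
vocab_size, dim = 10, 4
placeholder_id = vocab_size - 1

embedding = torch.nn.Embedding(vocab_size, dim)

# Keep a copy of the original embedding matrix, mirroring orig_embeds_params in the diff.
orig_embeds = embedding.weight.data.clone()

optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2, weight_decay=1e-2)

# Boolean mask over token ids: True for every row that must stay frozen.
index_no_updates = torch.arange(vocab_size) != placeholder_id

for _ in range(100):
    tokens = torch.randint(0, vocab_size, (8,))
    loss = embedding(tokens).pow(2).mean()  # dummy loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Restore every row except the placeholder token, as the updated script does.
    with torch.no_grad():
        embedding.weight[index_no_updates] = orig_embeds[index_no_updates]

# All frozen rows are bit-identical to the originals; only the placeholder row can change.
assert torch.equal(embedding.weight.data[index_no_updates], orig_embeds[index_no_updates])

Without the restore step, the frozen rows in this sketch would drift after a few iterations because AdamW's weight decay shrinks every row of the dense weight tensor, which is exactly what the old gradient-masking approach could not prevent.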
