@@ -548,6 +548,9 @@ def main():
     progress_bar.set_description("Steps")
     global_step = 0
 
+    # keep original embeddings as reference
+    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
+
     for epoch in range(args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
@@ -585,20 +588,15 @@ def main():
                 loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
-                else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
-
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
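
For context, a minimal self-contained sketch of the idea behind this change, assuming a toy torch.nn.Embedding and made-up sizes and token index (these names are illustrative, not part of the patch): instead of zeroing gradients for the frozen rows before the optimizer step, the full step runs and every embedding row except the placeholder token is restored from a saved copy afterwards.

    # Sketch only: train a single embedding row by restoring all other rows
    # from a saved copy after each optimizer step.
    import torch

    embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
    placeholder_token_id = 7  # hypothetical index of the newly added token

    # Keep the original weights as a reference, as the diff does before the loop.
    orig_embeds = embedding.weight.data.clone()
    optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

    for _ in range(3):
        tokens = torch.randint(0, 10, (8,))
        loss = embedding(tokens).pow(2).mean()  # dummy loss touching the sampled rows
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Restore every row except the placeholder token, so only that row trains.
        index_no_updates = torch.arange(embedding.num_embeddings) != placeholder_token_id
        with torch.no_grad():
            embedding.weight[index_no_updates] = orig_embeds[index_no_updates]

    # All rows except placeholder_token_id still equal their original values.
    assert torch.allclose(embedding.weight.data[index_no_updates], orig_embeds[index_no_updates])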