@@ -29,8 +29,21 @@
 logger = get_logger(__name__)
 
 
+def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+    logger.info("Saving embeddings")
+    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save learned_embeds.bin every X update steps.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -542,6 +555,8 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
+                if global_step % args.save_steps == 0:
+                    save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -567,9 +582,7 @@ def main():
         )
         pipeline.save_pretrained(args.output_dir)
         # Also save the newly trained embeddings
-        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
-        learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-        torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+        save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
         if args.push_to_hub:
             repo.push_to_hub(
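For context, here is a minimal sketch of how the `learned_embeds.bin` file written by `save_progress()` could be loaded back for inference. The CLIP checkpoint name and the file path below are illustrative assumptions, not part of this diff:

```python
# Minimal sketch: load the {placeholder_token: embedding} dict written by
# save_progress() and register it with a CLIP tokenizer/text encoder.
# The model name and path here are assumptions for illustration only.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The checkpoint maps the placeholder token string to its trained embedding.
learned_embeds_dict = torch.load("learned_embeds.bin", map_location="cpu")
placeholder_token, embedding = next(iter(learned_embeds_dict.items()))

# Add the placeholder token, grow the embedding matrix accordingly, and
# copy the trained vector into the newly created row.
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
```

Because `save_progress()` stores only the single learned embedding row rather than the whole text encoder, each intermediate checkpoint stays a few kilobytes in size.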