From 78da8b12b9074686a4a6dfe44e7494cbabcede87 Mon Sep 17 00:00:00 2001
From: isamu-isozaki
Date: Mon, 26 Sep 2022 17:07:47 -0400
Subject: [PATCH 1/2] Added script to save during training

---
 .../textual_inversion/textual_inversion.py    | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index de5761646a00..4f4607674840 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -29,8 +29,21 @@
 logger = get_logger(__name__)
 
 
+def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+    print("Saving embeddings")
+    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--save_interval",
+        type=int,
+        default=500,
+        help="Interval to save learned_embeds.bin",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -539,6 +552,8 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
+                if global_step % args.save_interval == 0:
+                    save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -564,9 +579,7 @@ def main():
         )
         pipeline.save_pretrained(args.output_dir)
         # Also save the newly trained embeddings
-        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
-        learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-        torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+        save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
         if args.push_to_hub:
             repo.push_to_hub(

From 0a1e379d65878449999272491c6b47d02283849b Mon Sep 17 00:00:00 2001
From: isamu-isozaki
Date: Wed, 28 Sep 2022 09:41:18 -0400
Subject: [PATCH 2/2] Suggested changes

---
 examples/textual_inversion/textual_inversion.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 4f4607674840..936245bd5849 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -30,7 +30,7 @@
 
 
 def save_progress(text_encoder, placeholder_token_id, accelerator, args):
-    print("Saving embeddings")
+    logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
     torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
@@ -39,10 +39,10 @@ def save_progress(text_encoder, placeholder_token_id, accelerator, args):
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
-        "--save_interval",
+        "--save_steps",
         type=int,
         default=500,
-        help="Interval to save learned_embeds.bin",
+        help="Save learned_embeds.bin every X updates steps.",
     )
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -552,7 +552,7 @@ def main():
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
-                if global_step % args.save_interval == 0:
+                if global_step % args.save_steps == 0:
                     save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}