Skip to content

Commit 7f31142

Browse files
Added script to save during textual inversion training. Issue 524 (#645)
* Added script to save during training * Suggested changes
1 parent 765506c commit 7f31142

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,21 @@
2929
logger = get_logger(__name__)
3030

3131

32+
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
33+
logger.info("Saving embeddings")
34+
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
35+
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
36+
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
37+
38+
3239
def parse_args():
3340
parser = argparse.ArgumentParser(description="Simple example of a training script.")
41+
parser.add_argument(
42+
"--save_steps",
43+
type=int,
44+
default=500,
45+
help="Save learned_embeds.bin every X updates steps.",
46+
)
3447
parser.add_argument(
3548
"--pretrained_model_name_or_path",
3649
type=str,
@@ -542,6 +555,8 @@ def main():
542555
if accelerator.sync_gradients:
543556
progress_bar.update(1)
544557
global_step += 1
558+
if global_step % args.save_steps == 0:
559+
save_progress(text_encoder, placeholder_token_id, accelerator, args)
545560

546561
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
547562
progress_bar.set_postfix(**logs)
@@ -567,9 +582,7 @@ def main():
567582
)
568583
pipeline.save_pretrained(args.output_dir)
569584
# Also save the newly trained embeddings
570-
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
571-
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
572-
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
585+
save_progress(text_encoder, placeholder_token_id, accelerator, args)
573586

574587
if args.push_to_hub:
575588
repo.push_to_hub(

0 commit comments

Comments
 (0)