
Commit 007c914

[Lora] Model card (#2032)
* [Lora] up lora training
* finish
* finish
* finish model card
1 parent 3c07840 commit 007c914

File tree

1 file changed (+62, -26 lines)


examples/dreambooth/train_dreambooth_lora.py

Lines changed: 62 additions & 26 deletions
@@ -58,6 +58,34 @@
 logger = get_logger(__name__)
 
 
+def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+    """
+    model_card = f"""
+# LoRA DreamBooth - {repo_name}
+
+These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
+{img_str}
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
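For orientation, here is a minimal usage sketch of the new save_model_card helper from the hunk above. The repo name, base model, prompt, and output folder are hypothetical placeholders rather than values from this commit, and the PIL images stand in for real validation outputs.

import os
from PIL import Image

os.makedirs("output", exist_ok=True)  # hypothetical local repo folder
images = [Image.new("RGB", (512, 512)) for _ in range(2)]  # placeholder images

save_model_card(
    "my-user/dreambooth-lora-sks-dog",            # hypothetical repo name
    images=images,
    base_model="runwayml/stable-diffusion-v1-5",  # hypothetical base model
    prompt="a photo of sks dog",                  # hypothetical instance prompt
    repo_folder="output",
)
# Writes output/README.md: a YAML front-matter block (license, base_model,
# tags, inference) followed by the "LoRA DreamBooth" card body and image links.

In the training script itself, main() passes args.instance_prompt and args.output_dir (see the second hunk below), so the pushed card always reflects the prompt the weights were trained on.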
@@ -913,34 +941,42 @@ def main(args):
         unet = unet.to(torch.float32)
         unet.save_attn_procs(args.output_dir)
 
-        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+        # Final inference
+        # Load previous pipeline
+        pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        )
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+        pipeline = pipeline.to(accelerator.device)
+
+        # load attention processors
+        pipeline.unet.load_attn_procs(args.output_dir)
+
+        # run inference
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+        prompt = args.num_validation_images * [args.validation_prompt]
+        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "wandb":
+                tracker.log(
+                    {
+                        "test": [
+                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                            for i, image in enumerate(images)
+                        ]
+                    }
+                )
 
-        # Final inference
-        # Load previous pipeline
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
-        )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-        pipeline = pipeline.to(accelerator.device)
-
-        # load attention processors
-        pipeline.unet.load_attn_procs(args.output_dir)
-
-        # run inference
-        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-        prompt = args.num_validation_images * [args.validation_prompt]
-        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
-
-        for tracker in accelerator.trackers:
-            if tracker.name == "wandb":
-                tracker.log(
-                    {
-                        "test": [
-                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
-                        ]
-                    }
+        if args.push_to_hub:
+            save_model_card(
+                repo_name,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                prompt=args.instance_prompt,
+                repo_folder=args.output_dir,
             )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
         accelerator.end_training()
 
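The same load path works outside the training script. Below is a minimal standalone inference sketch, assuming a diffusers version that provides unet.load_attn_procs as used in the hunk above; the base-model id, output directory, seed, and prompt are hypothetical placeholders.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Load the base pipeline, then attach the LoRA attention processors that
# unet.save_attn_procs(args.output_dir) wrote at the end of training.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # hypothetical base model
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.unet.load_attn_procs("output")  # hypothetical training output dir

generator = torch.Generator(device="cuda").manual_seed(0)  # hypothetical seed
image = pipe("a photo of sks dog", num_inference_steps=25, generator=generator).images[0]
image.save("sks_dog.png")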
