diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 66986e6c8f00..f516fd1b638b 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -58,6 +58,34 @@
 logger = get_logger(__name__)
 
 
+def save_model_card(repo_name, images=None, base_model: str = None, prompt: str = None, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+    """
+    model_card = f"""
+# LoRA DreamBooth - {repo_name}
+
+These are LoRA adaptation weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below.\n
+{img_str}
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -913,34 +941,42 @@ def main(args):
         unet = unet.to(torch.float32)
         unet.save_attn_procs(args.output_dir)
 
-        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+        # Final inference
+        # Load previous pipeline
+        pipeline = DiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        )
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+        pipeline = pipeline.to(accelerator.device)
+
+        # load attention processors
+        pipeline.unet.load_attn_procs(args.output_dir)
+
+        # run inference
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+        prompt = args.num_validation_images * [args.validation_prompt]
+        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+
+        for tracker in accelerator.trackers:
+            if tracker.name == "wandb":
+                tracker.log(
+                    {
+                        "test": [
+                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                            for i, image in enumerate(images)
+                        ]
+                    }
+                )
 
-        # Final inference
-        # Load previous pipeline
-        pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
-        )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-        pipeline = pipeline.to(accelerator.device)
-
-        # load attention processors
-        pipeline.unet.load_attn_procs(args.output_dir)
-
-        # run inference
-        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-        prompt = args.num_validation_images * [args.validation_prompt]
-        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
-
-        for tracker in accelerator.trackers:
-            if tracker.name == "wandb":
-                tracker.log(
-                    {
-                        "test": [
-                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
-                        ]
-                    }
+        if args.push_to_hub:
+            save_model_card(
+                repo_name,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                prompt=args.instance_prompt,
+                repo_folder=args.output_dir,
                 )
+            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
     accelerator.end_training()
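
To sanity-check the card that the new `save_model_card` helper generates without running a full training job, here is a minimal sketch. It assumes the script is importable as a module and uses throwaway PIL images, a hypothetical repo name, and a scratch folder; none of these values come from the PR itself.

```python
import os

from PIL import Image

from train_dreambooth_lora import save_model_card  # assumes the script is on sys.path

# Hypothetical inputs; in the real run these come from training.
os.makedirs("card_demo", exist_ok=True)
images = [Image.new("RGB", (64, 64), color) for color in ("red", "blue")]

save_model_card(
    "user/dreambooth-lora-demo",  # hypothetical repo name
    images=images,
    base_model="runwayml/stable-diffusion-v1-5",
    prompt="a photo of sks dog",
    repo_folder="card_demo",
)

# The helper writes image_0.png, image_1.png, and README.md into repo_folder.
with open(os.path.join("card_demo", "README.md")) as f:
    print(f.read())
```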
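
Once the run pushes to the Hub, the LoRA weights saved by `save_attn_procs` can be consumed the same way the script's new final-inference block does. A minimal sketch, assuming a hypothetical Hub repo id and that the base model matches the `base_model` recorded in the card:

```python
import torch

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Load the base model the LoRA was trained against (example id; substitute your own).
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Load the LoRA attention processors from the pushed repo (hypothetical id)
# or from a local output_dir, mirroring pipeline.unet.load_attn_procs above.
pipe.unet.load_attn_procs("user/dreambooth-lora-demo")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```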