58 | 58 | logger = get_logger(__name__) |
59 | 59 |
60 | 60 |
| 61 | +def save_model_card(repo_name, images=None, base_model=None, prompt=None, repo_folder=None): |
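| | +    """Save the sample images into repo_folder and write a README.md model card with YAML metadata for the Hub.""" |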
| 62 | +    img_str = "" |
| 63 | +    for i, image in enumerate(images): |
| 64 | +        image.save(os.path.join(repo_folder, f"image_{i}.png")) |
| 65 | +        img_str += f"![img_{i}](./image_{i}.png)\n" |
| 66 | + |
| 67 | + yaml = f""" |
| 68 | +--- |
| 69 | +license: creativeml-openrail-m |
| 70 | +base_model: {base_model} |
| 71 | +tags: |
| 72 | +- stable-diffusion |
| 73 | +- stable-diffusion-diffusers |
| 74 | +- text-to-image |
| 75 | +- diffusers |
| 76 | +inference: true |
| 77 | +--- |
| 78 | + """ |
| 79 | + model_card = f""" |
| 80 | +# LoRA DreamBooth - {repo_name} |
| 81 | +
| 82 | +These are LoRA adaptation weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below.\n |
| 83 | +{img_str} |
| 84 | +""" |
| 85 | +    with open(os.path.join(repo_folder, "README.md"), "w") as f: |
| 86 | +        f.write(yaml + model_card) |
| 87 | + |
| 88 | + |
61 | 89 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): |
62 | 90 |     text_encoder_config = PretrainedConfig.from_pretrained( |
63 | 91 |         pretrained_model_name_or_path, |
@@ -913,34 +941,42 @@ def main(args): |
913 | 941 |         unet = unet.to(torch.float32) |
914 | 942 |         unet.save_attn_procs(args.output_dir) |
915 | 943 |
916 | | -        if args.push_to_hub: |
917 | | -            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) |
| 944 | +        # Final inference |
| 945 | +        # Reload the pretrained base pipeline |
| 946 | +        pipeline = DiffusionPipeline.from_pretrained( |
| 947 | +            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype |
| 948 | +        ) |
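| | +        # Multistep DPM-Solver is a fast sampler, so ~25 inference steps suffice for clean images |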
| 949 | +        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
| 950 | +        pipeline = pipeline.to(accelerator.device) |
| 951 | + |
| 952 | +        # Load the trained LoRA attention-processor weights into the pipeline's UNet |
| 953 | +        pipeline.unet.load_attn_procs(args.output_dir) |
| 954 | + |
| 955 | +        # Run final inference on the validation prompt |
| 956 | +        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None |
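| | +        # Build a batch: one copy of the validation prompt per requested image |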
| 957 | +        prompt = args.num_validation_images * [args.validation_prompt] |
| 958 | +        images = pipeline(prompt, num_inference_steps=25, generator=generator).images |
| 959 | + |
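| | +        # Log the final samples only when wandb is the active tracker; other trackers skip this block |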
| 960 | +        for tracker in accelerator.trackers: |
| 961 | +            if tracker.name == "wandb": |
| 962 | +                tracker.log( |
| 963 | +                    { |
| 964 | +                        "test": [ |
| 965 | +                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") |
| 966 | +                            for i, image in enumerate(images) |
| 967 | +                        ] |
| 968 | +                    } |
| 969 | +                ) |
918 | 970 |
919 | | -        # Final inference |
920 | | -        # Load previous pipeline |
921 | | -        pipeline = DiffusionPipeline.from_pretrained( |
922 | | -            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype |
923 | | -        ) |
924 | | -        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
925 | | -        pipeline = pipeline.to(accelerator.device) |
926 | | - |
927 | | -        # load attention processors |
928 | | -        pipeline.unet.load_attn_procs(args.output_dir) |
929 | | - |
930 | | -        # run inference |
931 | | -        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) |
932 | | -        prompt = args.num_validation_images * [args.validation_prompt] |
933 | | -        images = pipeline(prompt, num_inference_steps=25, generator=generator).images |
934 | | - |
935 | | -        for tracker in accelerator.trackers: |
936 | | -            if tracker.name == "wandb": |
937 | | -                tracker.log( |
938 | | -                    { |
939 | | -                        "test": [ |
940 | | -                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) |
941 | | -                        ] |
942 | | -                    } |
| 971 | +        if args.push_to_hub: |
| 972 | +            save_model_card( |
| 973 | +                repo_name, |
| 974 | +                images=images, |
| 975 | +                base_model=args.pretrained_model_name_or_path, |
| 976 | +                prompt=args.instance_prompt, |
| 977 | +                repo_folder=args.output_dir, |
943 | 978 |             ) |
| 979 | +            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) |
944 | 980 |
945 | 981 |     accelerator.end_training() |
946 | 982 |