|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import glob |
| 4 | +from pathlib import Path |
| 5 | +import gradio as gr |
| 6 | +from PIL import Image |
| 7 | +from apps.stable_diffusion.scripts import inpaint_inf |
| 8 | +from apps.stable_diffusion.src import args |
| 9 | +from apps.stable_diffusion.web.ui.utils import ( |
| 10 | + available_devices, |
| 11 | + nodlogo_loc, |
| 12 | +) |
| 13 | + |
| 14 | + |
| 15 | +with gr.Blocks(title="Inpainting") as inpaint_web: |
| 16 | + with gr.Row(elem_id="ui_title"): |
| 17 | + nod_logo = Image.open(nodlogo_loc) |
| 18 | + with gr.Row(): |
| 19 | + with gr.Column(scale=1, elem_id="demo_title_outer"): |
| 20 | + gr.Image( |
| 21 | + value=nod_logo, |
| 22 | + show_label=False, |
| 23 | + interactive=False, |
| 24 | + elem_id="top_logo", |
| 25 | + ).style(width=150, height=50) |
| 26 | + with gr.Row(elem_id="ui_body"): |
| 27 | + with gr.Row(): |
| 28 | + with gr.Column(scale=1, min_width=600): |
| 29 | + with gr.Row(): |
| 30 | + ckpt_path = ( |
| 31 | + Path(args.ckpt_dir) |
| 32 | + if args.ckpt_dir |
| 33 | + else Path(Path.cwd(), "models") |
| 34 | + ) |
| 35 | + ckpt_path.mkdir(parents=True, exist_ok=True) |
| 36 | + types = ( |
| 37 | + "*.ckpt", |
| 38 | + "*.safetensors", |
| 39 | + ) # the tuple of file types |
| 40 | + ckpt_files = ["None"] |
| 41 | + for extn in types: |
| 42 | + files = glob.glob(os.path.join(ckpt_path, extn)) |
| 43 | + ckpt_files.extend(files) |
| 44 | + custom_model = gr.Dropdown( |
| 45 | + label=f"Models (Custom Model path: {ckpt_path})", |
| 46 | + value=args.ckpt_loc if args.ckpt_loc else "None", |
| 47 | + choices=ckpt_files |
| 48 | + + [ |
| 49 | + "runwayml/stable-diffusion-inpainting", |
| 50 | + "stabilityai/stable-diffusion-2-inpainting", |
| 51 | + ], |
| 52 | + ) |
| 53 | + hf_model_id = gr.Textbox( |
| 54 | + placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3", |
| 55 | + value="", |
| 56 | + label="HuggingFace Model ID", |
| 57 | + lines=3, |
| 58 | + ) |
| 59 | + |
| 60 | + with gr.Group(elem_id="prompt_box_outer"): |
| 61 | + prompt = gr.Textbox( |
| 62 | + label="Prompt", |
| 63 | + value=args.prompts[0], |
| 64 | + lines=1, |
| 65 | + elem_id="prompt_box", |
| 66 | + ) |
| 67 | + negative_prompt = gr.Textbox( |
| 68 | + label="Negative Prompt", |
| 69 | + value=args.negative_prompts[0], |
| 70 | + lines=1, |
| 71 | + elem_id="negative_prompt_box", |
| 72 | + ) |
| 73 | + |
| 74 | + init_image = gr.Image( |
| 75 | + label="Masked Image", |
| 76 | + source="upload", |
| 77 | + tool="sketch", |
| 78 | + type="filepath", |
| 79 | + ) |
| 80 | + |
| 81 | + with gr.Accordion(label="Advanced Options", open=False): |
| 82 | + with gr.Row(): |
| 83 | + scheduler = gr.Dropdown( |
| 84 | + label="Scheduler", |
| 85 | + value="PNDM", |
| 86 | + choices=[ |
| 87 | + "DDIM", |
| 88 | + "PNDM", |
| 89 | + "DPMSolverMultistep", |
| 90 | + "EulerAncestralDiscrete", |
| 91 | + ], |
| 92 | + ) |
| 93 | + with gr.Group(): |
| 94 | + save_metadata_to_png = gr.Checkbox( |
| 95 | + label="Save prompt information to PNG", |
| 96 | + value=args.write_metadata_to_png, |
| 97 | + interactive=True, |
| 98 | + ) |
| 99 | + save_metadata_to_json = gr.Checkbox( |
| 100 | + label="Save prompt information to JSON file", |
| 101 | + value=args.save_metadata_to_json, |
| 102 | + interactive=True, |
| 103 | + ) |
| 104 | + with gr.Row(): |
| 105 | + height = gr.Slider( |
| 106 | + 384, 786, value=args.height, step=8, label="Height" |
| 107 | + ) |
| 108 | + width = gr.Slider( |
| 109 | + 384, 786, value=args.width, step=8, label="Width" |
| 110 | + ) |
| 111 | + precision = gr.Radio( |
| 112 | + label="Precision", |
| 113 | + value=args.precision, |
| 114 | + choices=[ |
| 115 | + "fp16", |
| 116 | + "fp32", |
| 117 | + ], |
| 118 | + visible=False, |
| 119 | + ) |
| 120 | + max_length = gr.Radio( |
| 121 | + label="Max Length", |
| 122 | + value=args.max_length, |
| 123 | + choices=[ |
| 124 | + 64, |
| 125 | + 77, |
| 126 | + ], |
| 127 | + visible=False, |
| 128 | + ) |
| 129 | + with gr.Row(): |
| 130 | + steps = gr.Slider( |
| 131 | + 1, 100, value=args.steps, step=1, label="Steps" |
| 132 | + ) |
| 133 | + with gr.Row(): |
| 134 | + guidance_scale = gr.Slider( |
| 135 | + 0, |
| 136 | + 50, |
| 137 | + value=args.guidance_scale, |
| 138 | + step=0.1, |
| 139 | + label="CFG Scale", |
| 140 | + ) |
| 141 | + batch_count = gr.Slider( |
| 142 | + 1, |
| 143 | + 100, |
| 144 | + value=args.batch_count, |
| 145 | + step=1, |
| 146 | + label="Batch Count", |
| 147 | + interactive=True, |
| 148 | + ) |
| 149 | + batch_size = gr.Slider( |
| 150 | + 1, |
| 151 | + 4, |
| 152 | + value=args.batch_size, |
| 153 | + step=1, |
| 154 | + label="Batch Size", |
| 155 | + interactive=False, |
| 156 | + visible=False, |
| 157 | + ) |
| 158 | + with gr.Row(): |
| 159 | + seed = gr.Number( |
| 160 | + value=args.seed, precision=0, label="Seed" |
| 161 | + ) |
| 162 | + device = gr.Dropdown( |
| 163 | + label="Device", |
| 164 | + value=available_devices[0], |
| 165 | + choices=available_devices, |
| 166 | + ) |
| 167 | + with gr.Row(): |
| 168 | + random_seed = gr.Button("Randomize Seed") |
| 169 | + random_seed.click( |
| 170 | + None, |
| 171 | + inputs=[], |
| 172 | + outputs=[seed], |
| 173 | + _js="() => Math.floor(Math.random() * 4294967295)", |
| 174 | + ) |
| 175 | + stable_diffusion = gr.Button("Generate Image(s)") |
| 176 | + |
| 177 | + with gr.Column(scale=1, min_width=600): |
| 178 | + with gr.Group(): |
| 179 | + gallery = gr.Gallery( |
| 180 | + label="Generated images", |
| 181 | + show_label=False, |
| 182 | + elem_id="gallery", |
| 183 | + ).style(grid=[2]) |
| 184 | + std_output = gr.Textbox( |
| 185 | + value="Nothing to show.", |
| 186 | + lines=1, |
| 187 | + show_label=False, |
| 188 | + ) |
| 189 | + output_dir = args.output_dir if args.output_dir else Path.cwd() |
| 190 | + output_dir = Path(output_dir, "generated_imgs") |
| 191 | + output_loc = gr.Textbox( |
| 192 | + label="Saving Images at", |
| 193 | + value=output_dir, |
| 194 | + interactive=False, |
| 195 | + ) |
| 196 | + kwargs = dict( |
| 197 | + fn=inpaint_inf, |
| 198 | + inputs=[ |
| 199 | + prompt, |
| 200 | + negative_prompt, |
| 201 | + init_image, |
| 202 | + height, |
| 203 | + width, |
| 204 | + steps, |
| 205 | + guidance_scale, |
| 206 | + seed, |
| 207 | + batch_count, |
| 208 | + batch_size, |
| 209 | + scheduler, |
| 210 | + custom_model, |
| 211 | + hf_model_id, |
| 212 | + precision, |
| 213 | + device, |
| 214 | + max_length, |
| 215 | + save_metadata_to_json, |
| 216 | + save_metadata_to_png, |
| 217 | + ], |
| 218 | + outputs=[gallery, std_output], |
| 219 | + show_progress=args.progress_bar, |
| 220 | + ) |
| 221 | + |
| 222 | + prompt.submit(**kwargs) |
| 223 | + negative_prompt.submit(**kwargs) |
| 224 | + stable_diffusion.click(**kwargs) |
0 commit comments