Skip to content

Commit 8cafe56

Browse files
authored
Added flags for metadata information. (huggingface#894)
1 parent 3eceeb7 commit 8cafe56

File tree

7 files changed: 133 additions and 9 deletions.

apps/stable_diffusion/scripts/txt2img.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
os.environ["AMD_ENABLE_LLPC"] = "1"
44

5+
import json
56
import torch
67
import re
78
import time
89
from pathlib import Path
10+
from PIL import PngImagePlugin
911
from datetime import datetime as dt
1012
from dataclasses import dataclass
1113
from csv import DictWriter
@@ -61,7 +63,29 @@ def save_output_img(output_img):
6163
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
6264
)
6365
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
64-
output_img.save(out_img_path, quality=95, subsampling=0)
66+
67+
if args.output_img_format == "jpg":
68+
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
69+
output_img.save(out_img_path, quality=95, subsampling=0)
70+
else:
71+
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
72+
pngInfo = PngImagePlugin.PngInfo()
73+
74+
if args.write_metadata_to_png:
75+
pngInfo.add_text(
76+
"parameters",
77+
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
78+
)
79+
80+
output_img.save(
81+
output_path / f"{out_img_name}.png", "PNG", pnginfo=pngInfo
82+
)
83+
84+
if args.output_img_format not in ["png", "jpg"]:
85+
print(
86+
f"[ERROR] Format {args.output_img_format} is not supported yet."
87+
"Image saved as png instead. Supported formats: png / jpg"
88+
)
6589

6690
new_entry = {
6791
"VARIANT": args.hf_model_id,
@@ -83,6 +107,11 @@ def save_output_img(output_img):
83107
dictwriter_obj.writerow(new_entry)
84108
csv_obj.close()
85109

110+
if args.save_metadata_to_json:
111+
del new_entry["OUTPUT"]
112+
with open(f"{output_path}/{out_img_name}.json", "w") as f:
113+
json.dump(new_entry, f, indent=4)
114+
86115

87116
txt2img_obj = None
88117
config_obj = None
@@ -106,6 +135,8 @@ def txt2img_inf(
106135
precision: str,
107136
device: str,
108137
max_length: int,
138+
save_metadata_to_json: bool,
139+
save_metadata_to_png: bool,
109140
):
110141
global txt2img_obj
111142
global config_obj
@@ -119,6 +150,8 @@ def txt2img_inf(
119150
args.scheduler = scheduler
120151
args.hf_model_id = custom_model_id if custom_model_id else model_id
121152
args.ckpt_loc = ckpt_file_obj.name if ckpt_file_obj else ""
153+
args.save_metadata_to_json = save_metadata_to_json
154+
args.write_metadata_to_png = save_metadata_to_png
122155
dtype = torch.float32 if precision == "fp32" else torch.half
123156
cpu_scheduling = not scheduler.startswith("Shark")
124157
new_config_obj = Config(

apps/stable_diffusion/src/utils/stable_args.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,20 @@ def path_expand(s):
270270
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
271271
)
272272

273+
p.add_argument(
274+
"--save_metadata_to_json",
275+
default=False,
276+
action=argparse.BooleanOptionalAction,
277+
help="flag for whether or not to save a generation information json file with the image.",
278+
)
279+
280+
p.add_argument(
281+
"--write_metadata_to_png",
282+
default=False,
283+
action=argparse.BooleanOptionalAction,
284+
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
285+
)
286+
273287
##############################################################################
274288
### Web UI flags
275289
##############################################################################

apps/stable_diffusion/web/index.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,17 @@ def resource_path(relative_path):
148148
step=0.1,
149149
label="CFG Scale",
150150
)
151+
with gr.Row():
152+
save_metadata_to_png = gr.Checkbox(
153+
label="Save prompt information to PNG",
154+
value=False,
155+
interactive=True,
156+
)
157+
save_metadata_to_json = gr.Checkbox(
158+
label="Save prompt information to JSON file",
159+
value=False,
160+
interactive=True,
161+
)
151162
with gr.Row():
152163
seed = gr.Number(value=-1, precision=0, label="Seed")
153164
available_devices = get_available_devices()
@@ -211,6 +222,8 @@ def resource_path(relative_path):
211222
precision,
212223
device,
213224
max_length,
225+
save_metadata_to_json,
226+
save_metadata_to_png,
214227
],
215228
outputs=[gallery, std_output],
216229
show_progress=args.progress_bar,
@@ -233,6 +246,8 @@ def resource_path(relative_path):
233246
precision,
234247
device,
235248
max_length,
249+
save_metadata_to_json,
250+
save_metadata_to_png,
236251
],
237252
outputs=[gallery, std_output],
238253
show_progress=args.progress_bar,

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from transformers import CLIPTextModel, CLIPTokenizer
1111
import torch
12-
from PIL import Image
12+
from PIL import Image, PngImagePlugin
1313
from diffusers import (
1414
LMSDiscreteScheduler,
1515
PNDMScheduler,
@@ -329,11 +329,27 @@ def end_profiling(device):
329329
progressive=True,
330330
)
331331
else:
332-
pil_images[i].save(output_path / f"{img_name}.png", "PNG")
332+
pngInfo = PngImagePlugin.PngInfo()
333+
334+
if args.write_metadata_to_png:
335+
model_name = ""
336+
if args.ckpt_loc:
337+
model_name = Path(args.ckpt_loc).name
338+
else:
339+
model_name = json_store["hf_model_id"]
340+
pngInfo.add_text(
341+
"parameters",
342+
f"{json_store['prompt']}\nNegative prompt: {json_store['negative prompt']}\nSteps:{json_store['steps']}, Sampler: {json_store['scheduler']}, CFG scale: {json_store['guidance_scale']}, Seed: {json_store['seed']}, Size: {args.width}x{args.height}, Model: {model_name}",
343+
)
344+
345+
pil_images[i].save(
346+
output_path / f"{img_name}.png", "PNG", pnginfo=pngInfo
347+
)
333348
if args.output_img_format not in ["png", "jpg"]:
334349
print(
335350
f"[ERROR] Format {args.output_img_format} is not supported yet."
336-
"saving image as png. Supported formats png / jpg"
351+
"Image saved as png instead. Supported formats: png / jpg"
337352
)
338-
with open(output_path / f"{img_name}.json", "w") as f:
339-
f.write(json.dumps(json_store, indent=4))
353+
if args.save_metadata_to_json:
354+
with open(output_path / f"{img_name}.json", "w") as f:
355+
f.write(json.dumps(json_store, indent=4))

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,20 @@ def is_valid_file(arg):
283283
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
284284
)
285285

286+
p.add_argument(
287+
"--save_metadata_to_json",
288+
default=True,
289+
action=argparse.BooleanOptionalAction,
290+
help="flag for whether or not to save a generation information json file with the image.",
291+
)
292+
293+
p.add_argument(
294+
"--write_metadata_to_png",
295+
default=False,
296+
action=argparse.BooleanOptionalAction,
297+
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
298+
)
299+
286300
##############################################################################
287301
### Web UI flags
288302
##############################################################################

web/models/stable_diffusion/main.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import torch
22
import os
3-
from PIL import Image
3+
from PIL import Image, PngImagePlugin
44
from tqdm.auto import tqdm
55
from models.stable_diffusion.cache_objects import model_cache
66
from models.stable_diffusion.stable_args import args
77
from models.stable_diffusion.utils import disk_space_check
88
from random import randint
9+
import json
910
import numpy as np
1011
import time
1112
import sys
@@ -92,11 +93,22 @@ def save_output_img(output_img):
9293
)
9394
else:
9495
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
95-
output_img.save(out_img_path, "PNG")
96+
pngInfo = PngImagePlugin.PngInfo()
97+
98+
if args.write_metadata_to_png:
99+
pngInfo.add_text(
100+
"parameters",
101+
f"{args.prompts}\nNegative prompt: {args.negative_prompts}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.variant}",
102+
)
103+
104+
output_img.save(
105+
output_path / f"{out_img_name}.png", "PNG", pnginfo=pngInfo
106+
)
107+
96108
if args.output_img_format not in ["png", "jpg"]:
97109
print(
98110
f"[ERROR] Format {args.output_img_format} is not supported yet."
99-
"saving image as png. Supported formats png / jpg"
111+
"Image saved as png instead. Supported formats: png / jpg"
100112
)
101113

102114
new_entry = {
@@ -117,6 +129,11 @@ def save_output_img(output_img):
117129
dictwriter_obj.writerow(new_entry)
118130
csv_obj.close()
119131

132+
if args.save_metadata_to_json:
133+
del new_entry["OUTPUT"]
134+
with open(f"{output_path}/{out_img_name}.json", "w") as f:
135+
json.dump(new_entry, f, indent=4)
136+
120137

121138
def stable_diff_inf(
122139
prompt: str,
@@ -209,6 +226,7 @@ def stable_diff_inf(
209226

210227
avg_ms = 0
211228
for i, t in tqdm(enumerate(scheduler.timesteps)):
229+
212230
step_start = time.time()
213231
timestep = torch.tensor([t]).to(dtype).detach().numpy()
214232
latent_model_input = scheduler.scale_model_input(latents, t)

web/models/stable_diffusion/stable_args.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,20 @@
226226
help="flag to clear all mlir and vmfb from common locations. Recompiling will take several minutes",
227227
)
228228

229+
p.add_argument(
230+
"--save_metadata_to_json",
231+
default=False,
232+
action=argparse.BooleanOptionalAction,
233+
help="flag for whether or not to save a generation information json file with the image.",
234+
)
235+
236+
p.add_argument(
237+
"--write_metadata_to_png",
238+
default=False,
239+
action=argparse.BooleanOptionalAction,
240+
help="flag for whether or not to save generation information in PNG chunk text to generated images.",
241+
)
242+
229243
##############################################################################
230244
### Web UI flags
231245
##############################################################################

0 commit comments

Comments
 (0)