22
33os .environ ["AMD_ENABLE_LLPC" ] = "1"
44
5+ import json
56import torch
67import re
78import time
89from pathlib import Path
10+ from PIL import PngImagePlugin
911from datetime import datetime as dt
1012from dataclasses import dataclass
1113from csv import DictWriter
@@ -61,7 +63,29 @@ def save_output_img(output_img):
6163 f"{ prompt_slice } _{ args .seed } _{ dt .now ().strftime ('%y%m%d_%H%M%S' )} "
6264 )
6365 out_img_path = Path (generated_imgs_path , f"{ out_img_name } .jpg" )
64- output_img .save (out_img_path , quality = 95 , subsampling = 0 )
66+
67+ if args .output_img_format == "jpg" :
68+ out_img_path = Path (generated_imgs_path , f"{ out_img_name } .jpg" )
69+ output_img .save (out_img_path , quality = 95 , subsampling = 0 )
70+ else :
71+ out_img_path = Path (generated_imgs_path , f"{ out_img_name } .png" )
72+ pngInfo = PngImagePlugin .PngInfo ()
73+
74+ if args .write_metadata_to_png :
75+ pngInfo .add_text (
76+ "parameters" ,
77+ f"{ args .prompts [0 ]} \n Negative prompt: { args .negative_prompts [0 ]} \n Steps:{ args .steps } , Sampler: { args .scheduler } , CFG scale: { args .guidance_scale } , Seed: { args .seed } , Size: { args .width } x{ args .height } , Model: { args .hf_model_id } " ,
78+ )
79+
80+ output_img .save (
81+ output_path / f"{ out_img_name } .png" , "PNG" , pnginfo = pngInfo
82+ )
83+
84+ if args .output_img_format not in ["png" , "jpg" ]:
85+ print (
86+ f"[ERROR] Format { args .output_img_format } is not supported yet."
87+ "Image saved as png instead. Supported formats: png / jpg"
88+ )
6589
6690 new_entry = {
6791 "VARIANT" : args .hf_model_id ,
@@ -83,6 +107,11 @@ def save_output_img(output_img):
83107 dictwriter_obj .writerow (new_entry )
84108 csv_obj .close ()
85109
110+ if args .save_metadata_to_json :
111+ del new_entry ["OUTPUT" ]
112+ with open (f"{ output_path } /{ out_img_name } .json" , "w" ) as f :
113+ json .dump (new_entry , f , indent = 4 )
114+
86115
87116txt2img_obj = None
88117config_obj = None
@@ -106,6 +135,8 @@ def txt2img_inf(
106135 precision : str ,
107136 device : str ,
108137 max_length : int ,
138+ save_metadata_to_json : bool ,
139+ save_metadata_to_png : bool ,
109140):
110141 global txt2img_obj
111142 global config_obj
@@ -119,6 +150,8 @@ def txt2img_inf(
119150 args .scheduler = scheduler
120151 args .hf_model_id = custom_model_id if custom_model_id else model_id
121152 args .ckpt_loc = ckpt_file_obj .name if ckpt_file_obj else ""
153+ args .save_metadata_to_json = save_metadata_to_json
154+ args .write_metadata_to_png = save_metadata_to_png
122155 dtype = torch .float32 if precision == "fp32" else torch .half
123156 cpu_scheduling = not scheduler .startswith ("Shark" )
124157 new_config_obj = Config (
0 commit comments