@@ -169,7 +169,6 @@ def onnx_export(
169169 model ,
170170 model_args ,
171171 f = output_path .as_posix (),
172- # export_params=True,
173172 input_names = ordered_input_names ,
174173 output_names = output_names ,
175174 dynamic_axes = dynamic_axes ,
@@ -181,7 +180,6 @@ def onnx_export(
181180@torch .no_grad ()
182181def convert_models (model_path : str , controlnet_path : list , output_path : str , opset : int , fp16 : bool = False ):
183182 dtype = torch .float16 if fp16 else torch .float32
184- dtype = torch .float16
185183 if fp16 and torch .cuda .is_available ():
186184 device = "cuda"
187185 elif fp16 and not torch .cuda .is_available ():
@@ -231,8 +229,8 @@ def convert_models(model_path: str, controlnet_path: list, output_path: str, ops
231229 controlnets = torch .nn .ModuleList (controlnets )
232230 unet_controlnet = UNet2DConditionControlNetModel (pipeline .unet , controlnets )
233231 unet_in_channels = pipeline .unet .config .in_channels
234- img_size = 512
235- unet_sample_size = img_size // 8
232+ unet_sample_size = pipeline . unet . config . sample_size
233+ img_size = 8 * unet_sample_size
236234 unet_path = output_path / "unet" / "model.onnx"
237235 onnx_export (
238236 unet_controlnet ,
@@ -256,7 +254,6 @@ def convert_models(model_path: str, controlnet_path: list, output_path: str, ops
256254 "sample" : {0 : "2B" , 2 : "H" , 3 : "W" },
257255 "encoder_hidden_states" : {0 : "2B" },
258256 "controlnet_conds" : {1 : "2B" , 3 : "8H" , 4 : "8W" },
259- # "noise_pred": {0: "2B", 2: "H", 3: "W"}
260257 },
261258 opset = opset ,
262259 use_external_data_format = True , # UNet is > 2GB, so the weights need to be split
0 commit comments