Skip to content

Commit 3a7378a

Browse files
author
dotieuthien
committed
Fix code quality
1 parent 09c06e0 commit 3a7378a

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

scripts/convert_stable_diffusion_controlnet_to_onnx.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ def onnx_export(
169169
model,
170170
model_args,
171171
f=output_path.as_posix(),
172-
# export_params=True,
173172
input_names=ordered_input_names,
174173
output_names=output_names,
175174
dynamic_axes=dynamic_axes,
@@ -181,7 +180,6 @@ def onnx_export(
181180
@torch.no_grad()
182181
def convert_models(model_path: str, controlnet_path: list, output_path: str, opset: int, fp16: bool = False):
183182
dtype = torch.float16 if fp16 else torch.float32
184-
dtype = torch.float16
185183
if fp16 and torch.cuda.is_available():
186184
device = "cuda"
187185
elif fp16 and not torch.cuda.is_available():
@@ -231,8 +229,8 @@ def convert_models(model_path: str, controlnet_path: list, output_path: str, ops
231229
controlnets = torch.nn.ModuleList(controlnets)
232230
unet_controlnet = UNet2DConditionControlNetModel(pipeline.unet, controlnets)
233231
unet_in_channels = pipeline.unet.config.in_channels
234-
img_size = 512
235-
unet_sample_size = img_size // 8
232+
unet_sample_size = pipeline.unet.config.sample_size
233+
img_size = 8 * unet_sample_size
236234
unet_path = output_path / "unet" / "model.onnx"
237235
onnx_export(
238236
unet_controlnet,
@@ -256,7 +254,6 @@ def convert_models(model_path: str, controlnet_path: list, output_path: str, ops
256254
"sample": {0: "2B", 2: "H", 3: "W"},
257255
"encoder_hidden_states": {0: "2B"},
258256
"controlnet_conds": {1: "2B", 3: "8H", 4: "8W"},
259-
# "noise_pred": {0: "2B", 2: "H", 3: "W"}
260257
},
261258
opset=opset,
262259
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split

0 commit comments

Comments
 (0)