Skip to content

Commit 1d38d49

Browse files
authored
Fix iree flags due to the change in shark-runtime (huggingface#944)
1 parent a783c08 commit 1d38d49

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

shark/examples/shark_inference/upscaler/opt_params.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,15 @@
1313

1414

1515
unet_flag = [
16-
"--iree-flow-enable-padding-linalg-ops",
17-
"--iree-flow-linalg-ops-padding-size=32",
18-
"--iree-flow-enable-conv-img2col-transform",
16+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
1917
]
2018

2119
vae_flag = [
22-
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
23-
"--iree-flow-enable-padding-linalg-ops",
24-
"--iree-flow-linalg-ops-padding-size=16",
20+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
2521
]
2622

2723
clip_flag = [
28-
"--iree-flow-linalg-ops-padding-size=16",
29-
"--iree-flow-enable-padding-linalg-ops",
24+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
3025
]
3126

3227
bucket = "gs://shark_tank/stable_diffusion/"

shark/iree_utils/compile_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,17 @@ def get_iree_common_args():
8080
def get_model_specific_args():
8181
ms_args = []
8282
if shark_args.enable_conv_transform == True:
83-
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
83+
ms_args += [
84+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
85+
]
8486
if shark_args.enable_img2col_transform == True:
85-
ms_args += ["--iree-flow-enable-conv-img2col-transform"]
87+
ms_args += [
88+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
89+
]
8690
if shark_args.use_winograd == True:
87-
ms_args += ["--iree-flow-enable-conv-winograd-transform"]
91+
ms_args += [
92+
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
93+
]
8894
return ms_args
8995

9096

0 commit comments

Comments
 (0)