Skip to content

Commit e9864cb

Browse files
authored
Modify the annotation OTF to return bytecode module (huggingface#980)
1 parent 83c69ec commit e9864cb

File tree

2 files changed: +25 additions, −25 deletions

apps/stable_diffusion/src/utils/sd_annotation.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import io
23
from shark.model_annotation import model_annotation, create_context
34
from shark.iree_utils._common import iree_target_map, run_cmd
45
from shark.shark_downloader import (
@@ -97,10 +98,15 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
9798
search_op="conv",
9899
winograd=True,
99100
)
100-
with open(out_file_path, "w") as f:
101-
f.write(str(winograd_model))
102-
f.close()
103-
return winograd_model, out_file_path
101+
102+
bytecode_stream = io.BytesIO()
103+
winograd_model.operation.write_bytecode(bytecode_stream)
104+
bytecode = bytecode_stream.getvalue()
105+
106+
with open(out_file_path, "w") as f:
107+
f.write(str(winograd_model))
108+
f.close()
109+
return bytecode, out_file_path
104110

105111

106112
def dump_after_mlir(input_mlir, model_name, use_winograd):
@@ -176,10 +182,15 @@ def annotate_with_lower_configs(
176182
)
177183
else:
178184
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
185+
186+
bytecode_stream = io.BytesIO()
187+
tuned_model.operation.write_bytecode(bytecode_stream)
188+
bytecode = bytecode_stream.getvalue()
189+
179190
with open(out_file_path, "w") as f:
180191
f.write(str(tuned_model))
181192
f.close()
182-
return tuned_model, out_file_path
193+
return bytecode, out_file_path
183194

184195

185196
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
@@ -215,7 +226,7 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
215226
mlir_model, lowering_config_dir, model_name, use_winograd
216227
)
217228
print(f"Saved the annotated mlir in {output_path}.")
218-
return tuned_model, output_path
229+
return tuned_model
219230

220231

221232
if __name__ == "__main__":

apps/stable_diffusion/src/utils/utils.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,26 +96,19 @@ def compile_through_fx(
9696
)
9797

9898
if use_tuned:
99-
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
100-
if not os.path.exists(tuned_model_path):
101-
if "vae" in model_name.split("_")[0]:
102-
args.annotation_model = "vae"
103-
104-
tuned_model, tuned_model_path = sd_model_annotation(
105-
mlir_module, model_name
106-
)
107-
del mlir_module, tuned_model
108-
gc.collect()
109-
110-
with open(tuned_model_path, "rb") as f:
111-
mlir_module = f.read()
112-
f.close()
99+
if "vae" in model_name.split("_")[0]:
100+
args.annotation_model = "vae"
101+
mlir_module = sd_model_annotation(mlir_module, model_name)
113102

114103
shark_module = SharkInference(
115104
mlir_module,
116105
device=args.device,
117106
mlir_dialect="linalg",
118107
)
108+
109+
del mlir_module
110+
gc.collect()
111+
119112
return _compile_module(shark_module, model_name, extra_args)
120113

121114

@@ -253,11 +246,7 @@ def set_init_device_flags():
253246
):
254247
args.use_tuned = False
255248

256-
elif "cuda" in args.device and get_cuda_sm_cc() not in [
257-
"sm_80",
258-
"sm_84",
259-
"sm_86",
260-
]:
249+
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
261250
args.use_tuned = False
262251

263252
elif args.use_base_vae and args.hf_model_id not in [

Commit comments (0)