
Commit a463463

add generate_sharktank for stable_diffusion model defaults (huggingface#742)
Co-authored-by: dan <[email protected]>
Co-authored-by: powderluv <[email protected]>
1 parent d17e8dc commit a463463

9 files changed: +400 additions, -44 deletions

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+rm -rf ./test_images
+mkdir test_images
+python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
+python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned --beta_models=True
+
+python build_tools/image_comparison.py -n ./test_images/*.png
+exit $?
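The script renders one batch of images against the default model DB and a second batch with --beta_models=True, then gates CI on build_tools/image_comparison.py. That comparison script is not part of this diff, so the following is only a sketch of what such a pairwise PNG check might look like; the pairing scheme, the tolerance, and the handling of the -n flag are all assumptions:

# Hypothetical sketch only; the real build_tools/image_comparison.py is not in this diff.
import sys
import numpy as np
from PIL import Image

def images_close(path_a, path_b, max_mean_abs_diff=2.0):
    # Compare two RGB images as float arrays; tolerate small per-pixel drift.
    a = np.asarray(Image.open(path_a).convert("RGB"), dtype=np.float32)
    b = np.asarray(Image.open(path_b).convert("RGB"), dtype=np.float32)
    return a.shape == b.shape and np.abs(a - b).mean() <= max_mean_abs_diff

if __name__ == "__main__":
    # The shell expands ./test_images/*.png into individual arguments; the -n
    # flag is skipped here since its semantics are not shown in the diff.
    paths = sorted(p for p in sys.argv[1:] if p.endswith(".png"))
    # Pair each default-db image with its beta-db counterpart; a nonzero exit
    # propagates to CI through the `exit $?` above.
    for first, second in zip(paths[0::2], paths[1::2]):
        if not images_close(first, second):
            sys.exit(1)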

generate_sharktank.py

Lines changed: 61 additions & 28 deletions
@@ -18,6 +18,12 @@
 import hashlib
 import numpy as np
 from pathlib import Path
+from shark.examples.shark_inference.stable_diffusion import (
+    model_wrappers as mw,
+)
+from shark.examples.shark_inference.stable_diffusion.stable_args import (
+    args,
+)


 def create_hash(file_name):
@@ -51,6 +57,32 @@ def save_torch_model(torch_model_list):

         model = None
         input = None
+        if model_type == "stable_diffusion":
+
+            args.use_tuned = False
+            args.import_mlir = True
+            args.use_tuned = False
+            args.local_tank_cache = WORKDIR
+
+            precision_values = ["fp16"]
+            seq_lengths = [64, 77]
+            for precision_value in precision_values:
+                args.precision = precision_value
+                for length in seq_lengths:
+                    model = mw.SharkifyStableDiffusionModel(
+                        model_id=torch_model_name,
+                        custom_weights="",
+                        precision=precision_value,
+                        max_len=length,
+                        width=512,
+                        height=512,
+                        use_base_vae=False,
+                        debug=True,
+                        sharktank_dir=WORKDIR,
+                        generate_vmfb=False,
+                    )
+                    model()
+            continue
         if model_type == "vision":
             model, input, _ = get_vision_model(torch_model_name)
         elif model_type == "hf":
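With precision_values = ["fp16"] and seq_lengths = [64, 77], each stable_diffusion row in the CSV triggers two wrapper builds; debug=True creates per-submodel artifact directories under WORKDIR, and generate_vmfb=False skips .vmfb compilation so only the tank artifacts are produced. A standalone sketch of one such build (the model_id here is a placeholder; real ids come from the torch model CSV):

# Standalone sketch mirroring the new stable_diffusion branch for one CSV row.
from shark.examples.shark_inference.stable_diffusion import model_wrappers as mw

WORKDIR = "./gen_shark_tank"  # stands in for the tank cache dir used above

for max_len in [64, 77]:
    sd_model = mw.SharkifyStableDiffusionModel(
        model_id="stabilityai/stable-diffusion-2-1-base",  # placeholder id
        custom_weights="",
        precision="fp16",
        max_len=max_len,
        width=512,
        height=512,
        use_base_vae=False,
        debug=True,             # make per-submodel dirs under sharktank_dir
        sharktank_dir=WORKDIR,
        generate_vmfb=False,    # dump artifacts only; skip vmfb compilation
    )
    sd_model()  # builds the clip/unet/vae wrappers, writing debug artifacts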
@@ -205,34 +237,35 @@ def is_valid_file(arg):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--torch_model_csv",
-        type=lambda x: is_valid_file(x),
-        default="./tank/torch_model_list.csv",
-        help="""Contains the file with torch_model name and args.
-             Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
-    )
-    parser.add_argument(
-        "--tf_model_csv",
-        type=lambda x: is_valid_file(x),
-        default="./tank/tf_model_list.csv",
-        help="Contains the file with tf model name and args.",
-    )
-    parser.add_argument(
-        "--tflite_model_csv",
-        type=lambda x: is_valid_file(x),
-        default="./tank/tflite/tflite_model_list.csv",
-        help="Contains the file with tflite model name and args.",
-    )
-    parser.add_argument(
-        "--ci_tank_dir",
-        type=bool,
-        default=False,
-    )
-    parser.add_argument("--upload", type=bool, default=False)
-
-    args = parser.parse_args()
+    # Note: all of these flags are overridden by the import of args from stable_args.py; the flags are duplicated here temporarily to preserve functionality.
+    # parser = argparse.ArgumentParser()
+    # parser.add_argument(
+    #     "--torch_model_csv",
+    #     type=lambda x: is_valid_file(x),
+    #     default="./tank/torch_model_list.csv",
+    #     help="""Contains the file with torch_model name and args.
+    #          Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
+    # )
+    # parser.add_argument(
+    #     "--tf_model_csv",
+    #     type=lambda x: is_valid_file(x),
+    #     default="./tank/tf_model_list.csv",
+    #     help="Contains the file with tf model name and args.",
+    # )
+    # parser.add_argument(
+    #     "--tflite_model_csv",
+    #     type=lambda x: is_valid_file(x),
+    #     default="./tank/tflite/tflite_model_list.csv",
+    #     help="Contains the file with tflite model name and args.",
+    # )
+    # parser.add_argument(
+    #     "--ci_tank_dir",
+    #     type=bool,
+    #     default=False,
+    # )
+    # parser.add_argument("--upload", type=bool, default=False)
+
+    # old_args = parser.parse_args()

     home = str(Path.home())
     if args.ci_tank_dir == True:
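The local parser is retired because stable_args.py parses the command line once at import, and every importer shares the same args object; keeping a second ArgumentParser here would re-parse sys.argv and conflict with it. A rough sketch of that shared-args pattern, with flags assumed from the usages above (the actual stable_args.py is not in this diff):

# stable_args.py-style singleton sketch: parse once at import, share everywhere.
import argparse

p = argparse.ArgumentParser()
# Flags generate_sharktank.py relies on above; assumed to live here now, along
# with --precision, --import_mlir, --use_tuned, --beta_models, etc.
p.add_argument("--ci_tank_dir", type=bool, default=False)
p.add_argument("--upload", type=bool, default=False)
p.add_argument("--local_tank_cache", type=str, default="")
args, _unknown = p.parse_known_args()  # tolerate flags owned by other tools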

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 38 additions & 5 deletions
@@ -1,10 +1,13 @@
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from transformers import CLIPTextModel
 from utils import compile_through_fx, get_opt_flags
 from resources import base_models
 from collections import defaultdict
 import torch
-import sys


 # These shapes are parameter dependent.
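Appending the module's own directory to sys.path keeps the flat sibling imports above (utils, resources, stable_args) resolvable when model_wrappers is imported from outside this directory, as generate_sharktank.py now does; the pre-existing import sys simply moves to the top so it runs before the path tweak.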
@@ -63,6 +66,9 @@ def __init__(
         batch_size: int = 1,
         use_base_vae: bool = False,
         use_tuned: bool = False,
+        debug: bool = False,
+        sharktank_dir: str = "",
+        generate_vmfb: bool = True,
     ):
         self.check_params(max_len, width, height)
         self.max_len = max_len
@@ -73,7 +79,8 @@
         self.precision = precision
         self.base_vae = use_base_vae
         self.model_name = (
-            str(batch_size)
+            "_"
+            + str(batch_size)
             + "_"
             + str(max_len)
             + "_"
@@ -84,6 +91,9 @@
             + precision
         )
         self.use_tuned = use_tuned
+        self.debug = debug
+        self.sharktank_dir = sharktank_dir
+        self.generate_vmfb = generate_vmfb
         # We need a better naming convention for the .vmfbs because despite
         # using the custom model variant the .vmfb names remain the same and
         # it'll always pick up the compiled .vmfb instead of compiling the
@@ -130,13 +140,20 @@ def forward(self, input):
         inputs = tuple(self.inputs["vae"])
         is_f16 = True if self.precision == "fp16" else False
         vae_name = "base_vae" if self.base_vae else "vae"
+        vae_model_name = vae_name + self.model_name
+        if self.debug:
+            os.makedirs(
+                os.path.join(self.sharktank_dir, vae_model_name), exist_ok=True
+            )
         shark_vae = compile_through_fx(
             vae,
             inputs,
             is_f16=is_f16,
-            model_name=vae_name + self.model_name,
             use_tuned=self.use_tuned,
+            model_name=vae_model_name,
             extra_args=get_opt_flags("vae", precision=self.precision),
+            debug=self.debug,
+            generate_vmfb=self.generate_vmfb,
         )
         return shark_vae

@@ -169,14 +186,22 @@ def forward(
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
         input_mask = [True, True, True, False]
+        unet_model_name = "unet" + self.model_name
+        if self.debug:
+            os.makedirs(
+                os.path.join(self.sharktank_dir, unet_model_name),
+                exist_ok=True,
+            )
         shark_unet = compile_through_fx(
             unet,
             inputs,
-            model_name="unet" + self.model_name,
+            model_name=unet_model_name,
             is_f16=is_f16,
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
+            debug=self.debug,
+            generate_vmfb=self.generate_vmfb,
         )
         return shark_unet

@@ -193,12 +218,20 @@ def forward(self, input):
                 return self.text_encoder(input)[0]

         clip_model = CLIPText()
+        clip_model_name = "clip" + self.model_name
+        if self.debug:
+            os.makedirs(
+                os.path.join(self.sharktank_dir, clip_model_name),
+                exist_ok=True,
+            )

         shark_clip = compile_through_fx(
             clip_model,
             tuple(self.inputs["clip"]),
-            model_name="clip" + self.model_name,
+            model_name=clip_model_name,
             extra_args=get_opt_flags("clip", precision="fp32"),
+            debug=self.debug,
+            generate_vmfb=self.generate_vmfb,
         )
         return shark_clip
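Taken together, the vae/unet/clip hunks give each submodel a stable name and, in debug mode, a matching artifact directory. Assuming the name fields elided above, the resulting layout under sharktank_dir is roughly:

    <sharktank_dir>/
        clip<model_name>/    e.g. clip_1_77_..._fp16/
        unet<model_name>/    e.g. unet_1_77_..._fp16/
        vae<model_name>/     e.g. vae_1_77_..._fp16/  (base_vae... when use_base_vae=True)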

shark/examples/shark_inference/stable_diffusion/opt_params.py

Lines changed: 4 additions & 1 deletion
@@ -1,8 +1,11 @@
 import sys
-from resources import models_db
+import resources
 from stable_args import args
 from utils import get_shark_model

+models_db = (
+    resources.beta_models_db if args.beta_models else resources.models_db
+)
 BATCH_SIZE = len(args.prompts)
 if BATCH_SIZE != 1:
     sys.exit("Only batch size 1 is supported.")

shark/examples/shark_inference/stable_diffusion/resources.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ def get_json_file(path):
 # it will run all the global vars.
 prompts_examples = get_json_file("resources/prompts.json")
 models_db = get_json_file("resources/model_db.json")
+beta_models_db = get_json_file("resources/beta_model_db.json")

 # The base_model contains the input configuration for the different
 # models and also helps in providing information for the variants.
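beta_model_db.json itself is not part of this diff, and neither is the body of get_json_file, whose signature appears in the hunk header. A minimal sketch of what it presumably does (resolve the path relative to resources.py and parse the JSON):

# Sketch only; the real get_json_file in resources.py is not shown in this diff.
import json
import os

def get_json_file(path):
    # Resolve relative to this file so imports work from any working directory.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(dir_path, path)) as f:
        return json.load(f)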
