Skip to content

Commit 47a119a — "[SD] Add CUDA A100 tuned model (huggingface#773)"

Browse files · authored · 1 parent: ee56559

File tree

5 files changed

+60
-17
lines changed

5 files changed

+60
-17
lines changed

shark/examples/shark_inference/stable_diffusion/opt_params.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,13 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
6262
def get_unet():
6363
# Tuned model is present only for `fp16` precision.
6464
is_tuned = "tuned" if args.use_tuned else "untuned"
65-
bucket_key = f"{args.variant}/{is_tuned}"
66-
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
65+
if "vulkan" not in args.device and is_tuned:
66+
bucket_key = f"{args.variant}/{is_tuned}/{args.device}"
67+
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"
68+
else:
69+
bucket_key = f"{args.variant}/{is_tuned}"
70+
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
71+
6772
bucket, model_name, iree_flags = get_params(
6873
bucket_key, model_key, "unet", is_tuned, args.precision
6974
)
@@ -74,7 +79,9 @@ def get_unet():
7479

7580
def get_vae():
7681
# Tuned model is present only for `fp16` precision.
77-
is_tuned = "tuned" if args.use_tuned else "untuned"
82+
is_tuned = (
83+
"tuned" if (args.use_tuned and "vulkan" in args.device) else "untuned"
84+
)
7885
is_base = "/base" if args.use_base_vae else ""
7986
bucket_key = f"{args.variant}/{is_tuned}"
8087
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"

shark/examples/shark_inference/stable_diffusion/resources/model_db.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{
33
"stablediffusion/untuned":"gs://shark_tank/stable_diffusion",
44
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
5+
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
56
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
67
"anythingv3/tuned":"gs://shark_tank/sd_tuned",
78
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
@@ -23,6 +24,7 @@
2324
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
2425
"stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_19dec_v2p1base_fp16_64",
2526
"stablediffusion/v2_1base/unet/fp16/length_64/tuned":"unet_19dec_v2p1base_fp16_64_tuned",
27+
"stablediffusion/v2_1base/unet/fp16/length_64/tuned/cuda":"unet_19dec_v2p1base_fp16_64_cuda_tuned",
2628
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae2base_19dec_fp16",
2729
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
2830
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",

shark/examples/shark_inference/stable_diffusion/sd_annotation.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
from utils import set_init_device_flags
1313

1414

15-
# Downloads the model (Unet or VAE fp16) from shark_tank
1615
set_init_device_flags()
16+
device = (
17+
args.device if "://" not in args.device else args.device.split("://")[0]
18+
)
19+
20+
# Downloads the model (Unet or VAE fp16) from shark_tank
1721
shark_args.local_tank_cache = args.local_tank_cache
1822
bucket_key = f"{args.variant}/untuned"
19-
use_winograd = True
2023
if args.annotation_model == "unet":
2124
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
2225
elif args.annotation_model == "vae":
@@ -34,29 +37,29 @@
3437

3538
# Downloads the tuned config files from shark_tank
3639
config_bucket = "gs://shark_tank/sd_tuned/configs/"
37-
if use_winograd:
38-
config_name = f"{args.annotation_model}_winograd.json"
40+
if args.use_winograd:
41+
config_name = f"{args.annotation_model}_winograd_{device}.json"
3942
full_gs_url = config_bucket + config_name
4043
winograd_config_dir = f"{WORKDIR}configs/" + config_name
4144
download_public_file(full_gs_url, winograd_config_dir, True)
4245

4346
if args.annotation_model == "unet":
4447
if args.variant in ["anythingv3", "analogdiffusion"]:
4548
args.max_length = 77
46-
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
49+
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}_{device}.json"
4750
full_gs_url = config_bucket + config_name
4851
lowering_config_dir = f"{WORKDIR}configs/" + config_name
4952
download_public_file(full_gs_url, lowering_config_dir, True)
5053

5154
# Annotate the model with Winograd attribute on selected conv ops
52-
if use_winograd:
55+
if args.use_winograd:
5356
with create_context() as ctx:
5457
winograd_model = model_annotation(
5558
ctx,
5659
input_contents=mlir_model,
5760
config_path=winograd_config_dir,
5861
search_op="conv",
59-
winograd=use_winograd,
62+
winograd=args.use_winograd,
6063
)
6164
with open(
6265
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
@@ -65,19 +68,30 @@
6568

6669
# For Unet annotate the model with tuned lowering configs
6770
if args.annotation_model == "unet":
68-
if use_winograd:
71+
if args.use_winograd:
6972
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
7073
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
7174
else:
7275
input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
7376
dump_after = "iree-flow-pad-linalg-ops"
7477

7578
# Dump IR after padding/img2col/winograd passes
79+
device_spec_args = ""
80+
if device == "cuda":
81+
from shark.iree_utils.gpu_utils import get_iree_gpu_args
82+
83+
gpu_flags = get_iree_gpu_args()
84+
for flag in gpu_flags:
85+
device_spec_args += flag + " "
86+
elif device == "vulkan":
87+
device_spec_args = (
88+
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
89+
)
7690
run_cmd(
7791
f"iree-compile {input_mlir} "
7892
"--iree-input-type=tm_tensor "
79-
f"--iree-hal-target-backends={iree_target_map(args.device)} "
80-
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
93+
f"--iree-hal-target-backends={iree_target_map(device)} "
94+
f"{device_spec_args}"
8195
"--iree-stream-resource-index-bits=64 "
8296
"--iree-vm-target-index-bits=64 "
8397
"--iree-flow-enable-padding-linalg-ops "

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,4 +247,11 @@ def path_expand(s):
247247
help="Options are unet and vae.",
248248
)
249249

250+
p.add_argument(
251+
"--use_winograd",
252+
default=False,
253+
action=argparse.BooleanOptionalAction,
254+
help="Apply Winograd on selected conv ops.",
255+
)
256+
250257
args = p.parse_args()

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
set_iree_vulkan_runtime_flags,
88
get_vulkan_target_triple,
99
)
10+
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
1011

1112

1213
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -46,6 +47,8 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
4647

4748
# Set local shark_tank cache directory.
4849
shark_args.local_tank_cache = args.local_tank_cache
50+
if "cuda" in args.device:
51+
shark_args.enable_tf32 = True
4952

5053
mlir_model, func_name, inputs, golden_out = download_model(
5154
model_name,
@@ -185,22 +188,32 @@ def set_init_device_flags():
185188
elif args.variant == "openjourney":
186189
args.max_length = 64
187190

188-
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
191+
# Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
189192
if (
190193
args.variant in ["openjourney", "dreamlike"]
191194
or args.precision != "fp16"
192195
or "vulkan" not in args.device
193196
or "rdna3" not in args.iree_vulkan_target_triple
194197
):
195198
args.use_tuned = False
196-
print("Tuned models are currently not supported for this setting.")
197199

198200
elif args.use_base_vae and args.variant != "stablediffusion":
199201
args.use_tuned = False
200-
print("Tuned models are currently not supported for this setting.")
202+
203+
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
204+
if (
205+
args.variant == "stablediffusion"
206+
and args.precision == "fp16"
207+
and "cuda" in args.device
208+
and get_cuda_sm_cc() == "sm_80"
209+
and args.version == "v2_1base"
210+
):
211+
args.use_tuned = True
201212

202213
if args.use_tuned:
203-
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
214+
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
215+
else:
216+
print("Tuned models are currently not supported for this setting.")
204217

205218

206219
# Utility to get list of devices available.

Comments (0) — this commit has no comments.