Skip to content

Commit 898bc9e

Browse files
author
Prashant Kumar
committed
Add the stable diffusion v2.1 version.
1 parent e67ea31 commit 898bc9e

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from opt_params import get_unet, get_vae, get_clip
1616
import time
1717
import sys
18-
from model_wrappers import get_vae_mlir
1918
from shark.iree_utils.compile_utils import dump_isas
2019

2120
# Helper function to profile the vulkan device.
@@ -43,7 +42,7 @@ def end_profiling(device):
4342
neg_prompt = args.negative_prompts
4443
height = 512 # default height of Stable Diffusion
4544
width = 512 # default width of Stable Diffusion
46-
if args.version == "v2":
45+
if args.version == "v2.1":
4746
height = 768
4847
width = 768
4948

@@ -75,13 +74,13 @@ def end_profiling(device):
7574
"CompVis/stable-diffusion-v1-4",
7675
subfolder="scheduler",
7776
)
78-
if args.version == "v2":
77+
if args.version == "v2.1":
7978
tokenizer = CLIPTokenizer.from_pretrained(
80-
"stabilityai/stable-diffusion-2", subfolder="tokenizer"
79+
"stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
8180
)
8281

8382
scheduler = DPMSolverMultistepScheduler.from_pretrained(
84-
"stabilityai/stable-diffusion-2",
83+
"stabilityai/stable-diffusion-2-1",
8584
subfolder="scheduler",
8685
)
8786

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import torch
66

77
model_config = {
8-
"v2": "stabilityai/stable-diffusion-2",
8+
"v2.1": "stabilityai/stable-diffusion-2-1",
99
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
1010
"v1.4": "CompVis/stable-diffusion-v1-4",
1111
}
1212

1313
model_input = {
14-
"v2": {
14+
"v2.1": {
1515
"clip": (torch.randint(1, 2, (1, 77)),),
1616
"vae": (torch.randn(1, 4, 96, 96),),
1717
"unet": (
@@ -52,7 +52,7 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
5252
text_encoder = CLIPTextModel.from_pretrained(
5353
"openai/clip-vit-large-patch14"
5454
)
55-
if args.version == "v2":
55+
if args.version != "v1.4":
5656
text_encoder = CLIPTextModel.from_pretrained(
5757
model_config[args.version], subfolder="text_encoder"
5858
)

shark/examples/shark_inference/stable_diffusion/opt_params.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def get_unet():
3030
model_name = "unet_8dec_fp16"
3131
if args.version == "v2.1base":
3232
model_name = "unet2base_8dec_fp16"
33+
if args.version == "v2.1":
34+
model_name = "unet2_14dec_fp16"
3335
iree_flags += [
3436
"--iree-flow-enable-padding-linalg-ops",
3537
"--iree-flow-linalg-ops-padding-size=32",
@@ -79,6 +81,8 @@ def get_vae():
7981
model_name = "vae_8dec_fp16"
8082
if args.version == "v2.1base":
8183
model_name = "vae2base_8dec_fp16"
84+
if args.version == "v2.1":
85+
model_name = "vae2_14dec_fp16"
8286
iree_flags += [
8387
"--iree-flow-enable-padding-linalg-ops",
8488
"--iree-flow-linalg-ops-padding-size=32",
@@ -144,6 +148,8 @@ def get_clip():
144148
model_name = "clip_8dec_fp32"
145149
if args.version == "v2.1base":
146150
model_name = "clip2base_8dec_fp32"
151+
if args.version == "v2.1":
152+
model_name = "clip2_14dec_fp32"
147153
iree_flags += [
148154
"--iree-flow-linalg-ops-padding-size=16",
149155
"--iree-flow-enable-padding-linalg-ops",

0 commit comments

Comments
 (0)