Skip to content

Commit d913453

Browse files
[WEB] Update models to 8dec and also default values (huggingface#620)
1. Update the models to 8 dec. 2. precision is default to `fp16` in CLI. 3. version is default to `v2.1base` in CLI as well as web. 4. The default scheduler is set to `EulerDiscrete` now. Signed-Off-by: Gaurav Shukla <[email protected]> Signed-off-by: Gaurav Shukla <[email protected]>
1 parent 08e373a commit d913453

File tree

5 files changed

+22
-11
lines changed

5 files changed

+22
-11
lines changed

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
p.add_argument(
2424
"--version",
2525
type=str,
26-
default="v1.4",
26+
default="v2.1base",
2727
help="Specify version of stable diffusion model",
2828
)
2929

@@ -48,7 +48,7 @@
4848
)
4949

5050
p.add_argument(
51-
"--precision", type=str, default="fp32", help="precision to run the model."
51+
"--precision", type=str, default="fp16", help="precision to run the model."
5252
)
5353

5454
p.add_argument(

web/index.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ def resource_path(relative_path):
101101
)
102102
version = gr.Radio(
103103
label="Version",
104-
value="v1.4",
104+
value="v2.1base",
105105
choices=["v1.4", "v2.1base"],
106106
)
107107
with gr.Row():
108108
scheduler_key = gr.Dropdown(
109109
label="Scheduler",
110-
value="DPMSolverMultistep",
110+
value="EulerDiscrete",
111111
choices=[
112112
"DDIM",
113113
"PNDM",
@@ -174,9 +174,9 @@ def resource_path(relative_path):
174174
outputs=[generated_img, std_output],
175175
)
176176

177+
shark_web.queue()
177178
shark_web.launch(
178179
share=False,
179180
server_name="0.0.0.0",
180181
server_port=8080,
181-
enable_queue=True,
182182
)

web/models/stable_diffusion/model_wrappers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
model_config = {
77
"v2": "stabilityai/stable-diffusion-2",
8+
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
89
"v1.4": "CompVis/stable-diffusion-v1-4",
910
}
1011

@@ -19,6 +20,16 @@
1920
torch.tensor(1).to(torch.float32), # guidance_scale
2021
),
2122
},
23+
"v2.1base": {
24+
"clip": (torch.randint(1, 2, (1, 77)),),
25+
"vae": (torch.randn(1, 4, 64, 64),),
26+
"unet": (
27+
torch.randn(1, 4, 64, 64), # latents
28+
torch.tensor([1]).to(torch.float32), # timestep
29+
torch.randn(2, 77, 1024), # embedding
30+
torch.tensor(1).to(torch.float32), # guidance_scale
31+
),
32+
},
2233
"v1.4": {
2334
"clip": (torch.randint(1, 2, (1, 77)),),
2435
"vae": (torch.randn(1, 4, 64, 64),),

web/models/stable_diffusion/opt_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def get_unet(args):
2222
return get_shark_model(args, bucket, model_name, iree_flags)
2323
else:
2424
bucket = "gs://shark_tank/stable_diffusion"
25-
model_name = "unet_1dec_fp16"
25+
model_name = "unet_8dec_fp16"
2626
if args.version == "v2.1base":
2727
model_name = "unet2base_8dec_fp16"
2828
iree_flags += [
@@ -56,7 +56,7 @@ def get_vae(args):
5656
)
5757
if args.precision == "fp16":
5858
bucket = "gs://shark_tank/stable_diffusion"
59-
model_name = "vae_1dec_fp16"
59+
model_name = "vae_8dec_fp16"
6060
if args.version == "v2.1base":
6161
model_name = "vae2base_8dec_fp16"
6262
iree_flags += [
@@ -119,7 +119,7 @@ def get_clip(args):
119119
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
120120
)
121121
bucket = "gs://shark_tank/stable_diffusion"
122-
model_name = "clip_1dec_fp32"
122+
model_name = "clip_8dec_fp32"
123123
if args.version == "v2.1base":
124124
model_name = "clip2base_8dec_fp32"
125125
iree_flags += [

web/models/stable_diffusion/stable_args.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
p.add_argument(
2626
"--version",
2727
type=str,
28-
default="v1.4",
28+
default="v2.1base",
2929
help="Specify version of stable diffusion model",
3030
)
3131

@@ -60,8 +60,8 @@
6060
p.add_argument(
6161
"--scheduler",
6262
type=str,
63-
default="DPMSolverMultistep",
64-
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep]",
63+
default="EulerDiscrete",
64+
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
6565
)
6666

6767
p.add_argument(

0 commit comments

Comments
 (0)