Skip to content

Commit 097d0f2

Browse files
committed
[SD][web] Add 64 max_length support in SD web
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 2257f87 commit 097d0f2

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

web/models/stable_diffusion/model_wrappers.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,37 @@
1010
"v1.4": "CompVis/stable-diffusion-v1-4",
1111
}
1212

13+
# clip has 2 variants of max length 77 or 64.
14+
model_clip_max_length = 64 if args.max_length == 64 else 77
15+
1316
model_input = {
1417
"v2.1": {
15-
"clip": (torch.randint(1, 2, (2, 77)),),
18+
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
1619
"vae": (torch.randn(1, 4, 96, 96),),
1720
"unet": (
1821
torch.randn(1, 4, 96, 96), # latents
1922
torch.tensor([1]).to(torch.float32), # timestep
20-
torch.randn(2, 77, 1024), # embedding
23+
torch.randn(2, model_clip_max_length, 1024), # embedding
2124
torch.tensor(1).to(torch.float32), # guidance_scale
2225
),
2326
},
2427
"v2.1base": {
25-
"clip": (torch.randint(1, 2, (2, 77)),),
28+
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
2629
"vae": (torch.randn(1, 4, 64, 64),),
2730
"unet": (
2831
torch.randn(1, 4, 64, 64), # latents
2932
torch.tensor([1]).to(torch.float32), # timestep
30-
torch.randn(2, 77, 1024), # embedding
33+
torch.randn(2, model_clip_max_length, 1024), # embedding
3134
torch.tensor(1).to(torch.float32), # guidance_scale
3235
),
3336
},
3437
"v1.4": {
35-
"clip": (torch.randint(1, 2, (2, 77)),),
38+
"clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
3639
"vae": (torch.randn(1, 4, 64, 64),),
3740
"unet": (
3841
torch.randn(1, 4, 64, 64),
3942
torch.tensor([1]).to(torch.float32), # timestep
40-
torch.randn(2, 77, 768),
43+
torch.randn(2, model_clip_max_length, 768),
4144
torch.tensor(1).to(torch.float32),
4245
),
4346
},

web/models/stable_diffusion/opt_params.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ def get_unet():
4646
bucket = "gs://shark_tank/stable_diffusion"
4747
model_name = "unet_8dec_fp16"
4848
if args.version == "v2.1base":
49-
model_name = "unet2base_8dec_fp16"
49+
if args.max_length == 64:
50+
model_name = "unet_19dec_v2p1base_fp16_64"
51+
else:
52+
model_name = "unet2base_8dec_fp16"
5053
if args.version == "v2.1":
5154
model_name = "unet2_14dec_fp16"
5255
iree_flags += [
@@ -149,7 +152,10 @@ def get_clip():
149152
bucket = "gs://shark_tank/stable_diffusion"
150153
model_name = "clip_18dec_fp32"
151154
if args.version == "v2.1base":
152-
model_name = "clip2base_18dec_fp32"
155+
if args.max_length == 64:
156+
model_name = "clip_19dec_v2p1base_fp32_64"
157+
else:
158+
model_name = "clip2base_18dec_fp32"
153159
if args.version == "v2.1":
154160
model_name = "clip2_18dec_fp32"
155161
iree_flags += [

0 commit comments

Comments
 (0)