Skip to content

Commit 45af40f

Browse files
author
Gaurav Shukla
committed
[SD][web] Add openjourney and dreamlike in SD web UI
Signed-off-by: Gaurav Shukla <[email protected]>
1 parent d11cf42 commit 45af40f

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def set_init_device_flags():
188188

189189
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
190190
if (
191-
args.variant == "openjourney"
191+
args.variant in ["openjourney", "dreamlike"]
192192
or args.precision != "fp16"
193193
or "vulkan" not in args.device
194194
or "rdna3" not in args.iree_vulkan_target_triple

web/index.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@
7777
label="Model Variant",
7878
value="stablediffusion",
7979
choices=[
80+
"stablediffusion",
8081
"anythingv3",
8182
"analogdiffusion",
82-
"stablediffusion",
83+
"openjourney",
84+
"dreamlike",
8385
],
8486
)
8587
scheduler_key = gr.Dropdown(

web/models/stable_diffusion/model_wrappers.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
# clip has 2 variants of max length 77 or 64.
1414
model_clip_max_length = 64 if args.max_length == 64 else 77
15-
if args.variant in ["anythingv3", "analogdiffusion"]:
15+
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
1616
model_clip_max_length = 77
1717
elif args.variant == "openjourney":
1818
model_clip_max_length = 64
@@ -64,6 +64,7 @@
6464
"anythingv3": "diffusers",
6565
"analogdiffusion": "main",
6666
"openjourney": "main",
67+
"dreamlike": "main",
6768
}
6869

6970

@@ -78,7 +79,12 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
7879
model_config[args.version], subfolder="text_encoder"
7980
)
8081

81-
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
82+
elif args.variant in [
83+
"anythingv3",
84+
"analogdiffusion",
85+
"openjourney",
86+
"dreamlike",
87+
]:
8288
text_encoder = CLIPTextModel.from_pretrained(
8389
model_variant[args.variant],
8490
subfolder="text_encoder",
@@ -133,7 +139,12 @@ def forward(self, input):
133139
)
134140
else:
135141
inputs = model_input[args.version]["vae"]
136-
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
142+
elif args.variant in [
143+
"anythingv3",
144+
"analogdiffusion",
145+
"openjourney",
146+
"dreamlike",
147+
]:
137148
if args.precision == "fp16":
138149
vae = vae.half().cuda()
139150
inputs = tuple(
@@ -184,7 +195,12 @@ def forward(self, input):
184195
)
185196
else:
186197
inputs = model_input[args.version]["vae"]
187-
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
198+
elif args.variant in [
199+
"anythingv3",
200+
"analogdiffusion",
201+
"openjourney",
202+
"dreamlike",
203+
]:
188204
if args.precision == "fp16":
189205
vae = vae.half().cuda()
190206
inputs = tuple(
@@ -242,7 +258,12 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
242258
)
243259
else:
244260
inputs = model_input[args.version]["unet"]
245-
elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
261+
elif args.variant in [
262+
"anythingv3",
263+
"analogdiffusion",
264+
"openjourney",
265+
"dreamlike",
266+
]:
246267
if args.precision == "fp16":
247268
unet = unet.half().cuda()
248269
inputs = tuple(

web/models/stable_diffusion/resources/model_db.json

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
"analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
88
"analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
99
"openjourney/untuned":"gs://shark_tank/sd_openjourney",
10-
"openjourney/tuned":"gs://shark_tank/sd_tuned"
10+
"openjourney/tuned":"gs://shark_tank/sd_tuned",
11+
"dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
1112
},
1213
{
1314
"stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
@@ -55,6 +56,13 @@
5556
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
5657
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
5758
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
58-
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64"
59+
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
60+
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
61+
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
62+
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
63+
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
64+
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
65+
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
66+
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
5967
}
6068
]

web/models/stable_diffusion/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,14 @@ def set_init_device_flags():
181181
args.device = "cpu"
182182

183183
# set max_length based on availability.
184-
if args.variant in ["anythingv3", "analogdiffusion"]:
184+
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
185185
args.max_length = 77
186186
elif args.variant == "openjourney":
187187
args.max_length = 64
188188

189189
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
190190
if (
191-
args.variant == "openjourney"
191+
args.variant in ["openjourney", "dreamlike"]
192192
or args.precision != "fp16"
193193
or "vulkan" not in args.device
194194
or "rdna3" not in args.iree_vulkan_target_triple

0 commit comments

Comments (0)