
Commit d11cf42

Add support for dreamlike diffusion (huggingface#725)
* Add support for dreamlike diffusion
* model wrapper to support 77 dreamlike
* lint fix
1 parent c3c1e3b commit d11cf42

3 files changed (+37, -8 lines)


shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 26 additions & 5 deletions
@@ -12,7 +12,7 @@
 
 # clip has 2 variants of max length 77 or 64.
 model_clip_max_length = 64 if args.max_length == 64 else 77
-if args.variant in ["anythingv3", "analogdiffusion"]:
+if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
     model_clip_max_length = 77
 elif args.variant == "openjourney":
     model_clip_max_length = 64
@@ -64,6 +64,7 @@
     "anythingv3": "diffusers",
     "analogdiffusion": "main",
     "openjourney": "main",
+    "dreamlike": "main",
 }
 
 
@@ -78,7 +79,12 @@ def get_clip_mlir(model_name="clip_text", extra_args=[]):
             model_config[args.version], subfolder="text_encoder"
         )
 
-    elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
+    elif args.variant in [
+        "anythingv3",
+        "analogdiffusion",
+        "openjourney",
+        "dreamlike",
+    ]:
         text_encoder = CLIPTextModel.from_pretrained(
             model_variant[args.variant],
             subfolder="text_encoder",
@@ -133,7 +139,12 @@ def forward(self, input):
             )
         else:
             inputs = model_input[args.version]["vae"]
-    elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
+    elif args.variant in [
+        "anythingv3",
+        "analogdiffusion",
+        "openjourney",
+        "dreamlike",
+    ]:
         if args.precision == "fp16":
             vae = vae.half().cuda()
         inputs = tuple(
@@ -184,7 +195,12 @@ def forward(self, input):
             )
         else:
             inputs = model_input[args.version]["vae"]
-    elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
+    elif args.variant in [
+        "anythingv3",
+        "analogdiffusion",
+        "openjourney",
+        "dreamlike",
+    ]:
         if args.precision == "fp16":
             vae = vae.half().cuda()
         inputs = tuple(
@@ -242,7 +258,12 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
             )
         else:
             inputs = model_input[args.version]["unet"]
-    elif args.variant in ["anythingv3", "analogdiffusion", "openjourney"]:
+    elif args.variant in [
+        "anythingv3",
+        "analogdiffusion",
+        "openjourney",
+        "dreamlike",
+    ]:
         if args.precision == "fp16":
             unet = unet.half().cuda()
         inputs = tuple(
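
The pattern throughout model_wrappers.py is a per-variant dispatch: the new "dreamlike" entry joins the variants whose text encoder, VAE, and UNet are loaded from model_variant[args.variant] and which use the 77-token CLIP context. A minimal sketch of that selection logic follows; the helper names (VARIANT_REVISIONS, resolve_clip_max_length) are illustrative and not part of the file itself.

VARIANT_REVISIONS = {
    "anythingv3": "diffusers",
    "analogdiffusion": "main",
    "openjourney": "main",
    "dreamlike": "main",  # revision added by this commit
}

def resolve_clip_max_length(variant: str, requested: int) -> int:
    # Mirrors the rules in the diff: dreamlike joins the 77-token group,
    # openjourney stays at 64, everything else follows the requested value.
    max_length = 64 if requested == 64 else 77
    if variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
        max_length = 77
    elif variant == "openjourney":
        max_length = 64
    return max_length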

shark/examples/shark_inference/stable_diffusion/resources/model_db.json

Lines changed: 10 additions & 2 deletions
@@ -7,7 +7,8 @@
     "analogdiffusion/untuned":"gs://shark_tank/sd_analog_diffusion",
     "analogdiffusion/tuned":"gs://shark_tank/sd_tuned",
     "openjourney/untuned":"gs://shark_tank/sd_openjourney",
-    "openjourney/tuned":"gs://shark_tank/sd_tuned"
+    "openjourney/tuned":"gs://shark_tank/sd_tuned",
+    "dreamlike/untuned":"gs://shark_tank/sd_dreamlike_diffusion"
   },
   {
     "stablediffusion/v1_4/unet/fp16/length_77/untuned":"unet_8dec_fp16",
@@ -55,6 +56,13 @@
     "openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
     "openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
     "openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
-    "openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64"
+    "openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
+    "dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
+    "dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
+    "dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
+    "dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
+    "dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
+    "dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
+    "dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
   }
 ]
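
The keys added to model_db.json follow the same scheme as the existing entries: <variant>/<version>/<model>/<precision>/length_<n>/<tuning>, optionally with a trailing /base for the base VAE artifact. An illustrative sketch of how such a key could be assembled (this is not the repo's actual lookup code):

def make_db_key(variant, version, model, precision, max_length, tuned=False, base=False):
    # Illustrative only: composes a key in the format used by model_db.json.
    key = f"{variant}/{version}/{model}/{precision}/length_{max_length}/"
    key += "tuned" if tuned else "untuned"
    if base:
        key += "/base"
    return key

# make_db_key("dreamlike", "v2_1base", "unet", "fp16", 77)
# -> "dreamlike/v2_1base/unet/fp16/length_77/untuned", which maps to the
#    artifact name "dl_unet_23dec_fp16_77" in the table above.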

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def set_init_device_flags():
         args.device = "cpu"
 
     # set max_length based on availability.
-    if args.variant in ["anythingv3", "analogdiffusion"]:
+    if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
         args.max_length = 77
     elif args.variant == "openjourney":
         args.max_length = 64
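
Every per-model dreamlike artifact registered above uses length_77, which is why set_init_device_flags() pins args.max_length to 77 for this variant. A quick, illustrative way to confirm that from the database (assuming the working directory is the stable_diffusion example folder):

import json

# Collect every dreamlike entry from both tables in model_db.json.
with open("resources/model_db.json") as f:
    tables = json.load(f)

dreamlike_keys = [k for table in tables for k in table if k.startswith("dreamlike/")]
print(dreamlike_keys)
# Expected: "dreamlike/untuned" plus unet/vae/clip keys, all tagged length_77.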
