Skip to content

Commit 2af1102

Browse files
authored
[SD] Merge configs of different max lengthes from the same variant to one config file (huggingface#1019)
1 parent c4b4728 commit 2af1102

File tree

4 files changed

+53
-65
lines changed

4 files changed

+53
-65
lines changed

apps/stable_diffusion/src/models/opt_params.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010

1111
hf_model_variant_map = {
12-
"Linaqruf/anything-v3.0": ["anythingv3", "v2_1base"],
13-
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
14-
"prompthero/openjourney": ["openjourney", "v2_1base"],
15-
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
12+
"Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
13+
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
14+
"prompthero/openjourney": ["openjourney", "v1_4"],
15+
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
1616
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
1717
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
1818
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],

apps/stable_diffusion/src/utils/resources/model_db.json

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -53,41 +53,41 @@
5353
"stablediffusion/inpaint_v2/vae_encode/fp16/length_77/untuned":"vae_encode_inpaint_fp16",
5454
"stablediffusion/inpaint_v2/vae/fp16/length_77/untuned":"vae_inpaint_fp16",
5555
"stablediffusion/inpaint_v2/clip/fp32/length_77/untuned":"clip_inpaint_fp32",
56-
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
57-
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
58-
"anythingv3/v2_1base/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
59-
"anythingv3/v2_1base/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
60-
"anythingv3/v2_1base/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
61-
"anythingv3/v2_1base/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
62-
"anythingv3/v2_1base/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
63-
"anythingv3/v2_1base/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
64-
"anythingv3/v2_1base/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
65-
"anythingv3/v2_1base/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
66-
"anythingv3/v2_1base/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
67-
"analogdiffusion/v2_1base/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
68-
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
69-
"analogdiffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
70-
"analogdiffusion/v2_1base/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
71-
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
72-
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
73-
"analogdiffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
74-
"analogdiffusion/v2_1base/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
75-
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
76-
"analogdiffusion/v2_1base/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
77-
"analogdiffusion/v2_1base/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
78-
"openjourney/v2_1base/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
79-
"openjourney/v2_1base/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
80-
"openjourney/v2_1base/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
81-
"openjourney/v2_1base/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
82-
"openjourney/v2_1base/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
83-
"openjourney/v2_1base/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
84-
"openjourney/v2_1base/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
85-
"dreamlike/v2_1base/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
86-
"dreamlike/v2_1base/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
87-
"dreamlike/v2_1base/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
88-
"dreamlike/v2_1base/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
89-
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
90-
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
91-
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
56+
"anythingv3/v1_4/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
57+
"anythingv3/v1_4/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
58+
"anythingv3/v1_4/unet/fp16/length_77/tuned/cuda":"av3_unet_19dec_fp16_cuda_tuned",
59+
"anythingv3/v1_4/unet/fp32/length_77/untuned":"av3_unet_19dec_fp32",
60+
"anythingv3/v1_4/vae/fp16/length_77/untuned":"av3_vae_19dec_fp16",
61+
"anythingv3/v1_4/vae/fp16/length_77/tuned":"av3_vae_19dec_fp16_tuned",
62+
"anythingv3/v1_4/vae/fp16/length_77/tuned/cuda":"av3_vae_19dec_fp16_cuda_tuned",
63+
"anythingv3/v1_4/vae/fp16/length_77/untuned/base":"av3_vaebase_22dec_fp16",
64+
"anythingv3/v1_4/vae/fp32/length_77/untuned":"av3_vae_19dec_fp32",
65+
"anythingv3/v1_4/vae/fp32/length_77/untuned/base":"av3_vaebase_22dec_fp32",
66+
"anythingv3/v1_4/clip/fp32/length_77/untuned":"av3_clip_19dec_fp32",
67+
"analogdiffusion/v1_4/unet/fp16/length_77/untuned":"ad_unet_19dec_fp16",
68+
"analogdiffusion/v1_4/unet/fp16/length_77/tuned":"ad_unet_19dec_fp16_tuned",
69+
"analogdiffusion/v1_4/unet/fp16/length_77/tuned/cuda":"ad_unet_19dec_fp16_cuda_tuned",
70+
"analogdiffusion/v1_4/unet/fp32/length_77/untuned":"ad_unet_19dec_fp32",
71+
"analogdiffusion/v1_4/vae/fp16/length_77/untuned":"ad_vae_19dec_fp16",
72+
"analogdiffusion/v1_4/vae/fp16/length_77/tuned":"ad_vae_19dec_fp16_tuned",
73+
"analogdiffusion/v1_4/vae/fp16/length_77/tuned/cuda":"ad_vae_19dec_fp16_cuda_tuned",
74+
"analogdiffusion/v1_4/vae/fp16/length_77/untuned/base":"ad_vaebase_22dec_fp16",
75+
"analogdiffusion/v1_4/vae/fp32/length_77/untuned":"ad_vae_19dec_fp32",
76+
"analogdiffusion/v1_4/vae/fp32/length_77/untuned/base":"ad_vaebase_22dec_fp32",
77+
"analogdiffusion/v1_4/clip/fp32/length_77/untuned":"ad_clip_19dec_fp32",
78+
"openjourney/v1_4/unet/fp16/length_64/untuned":"oj_unet_22dec_fp16_64",
79+
"openjourney/v1_4/unet/fp32/length_64/untuned":"oj_unet_22dec_fp32_64",
80+
"openjourney/v1_4/vae/fp16/length_77/untuned":"oj_vae_22dec_fp16",
81+
"openjourney/v1_4/vae/fp16/length_77/untuned/base":"oj_vaebase_22dec_fp16",
82+
"openjourney/v1_4/vae/fp32/length_77/untuned":"oj_vae_22dec_fp32",
83+
"openjourney/v1_4/vae/fp32/length_77/untuned/base":"oj_vaebase_22dec_fp32",
84+
"openjourney/v1_4/clip/fp32/length_64/untuned":"oj_clip_22dec_fp32_64",
85+
"dreamlike/v1_4/unet/fp16/length_77/untuned":"dl_unet_23dec_fp16_77",
86+
"dreamlike/v1_4/unet/fp32/length_77/untuned":"dl_unet_23dec_fp32_77",
87+
"dreamlike/v1_4/vae/fp16/length_77/untuned":"dl_vae_23dec_fp16",
88+
"dreamlike/v1_4/vae/fp16/length_77/untuned/base":"dl_vaebase_23dec_fp16",
89+
"dreamlike/v1_4/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
90+
"dreamlike/v1_4/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
91+
"dreamlike/v1_4/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
9292
}
9393
]

apps/stable_diffusion/src/utils/sd_annotation.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,7 @@ def load_lower_configs():
8181

8282
variant, version = get_variant_version(args.hf_model_id)
8383

84-
config_bucket = "gs://shark_tank/sd_tuned/configs/"
85-
config_version = version
86-
config_max_length = args.max_length
87-
if variant in ["anythingv3", "analogdiffusion"]:
88-
config_max_length = 77
89-
config_version = "v1_4"
90-
if args.annotation_model == "vae":
91-
config_max_length = 77
84+
config_bucket = "gs://shark_tank/sd_tuned_configs/"
9285

9386
device, device_spec_args = get_device_args()
9487
spec = ""
@@ -97,10 +90,14 @@ def load_lower_configs():
9790
if device == "vulkan":
9891
spec = spec.split("-")[0]
9992

100-
if not spec or spec in ["rdna3", "sm_80"]:
101-
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}.json"
93+
if args.annotation_model == "vae":
94+
config_name = f"{args.annotation_model}_{args.precision}_{device}.json"
10295
else:
103-
config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{config_max_length}_{device}_{spec}.json"
96+
if not spec or spec in ["rdna3", "sm_80"]:
97+
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"
98+
else:
99+
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}_{spec}.json"
100+
104101
full_gs_url = config_bucket + config_name
105102
lowering_config_dir = f"{WORKDIR}configs/" + config_name
106103
print("Loading lowering config file from ", lowering_config_dir)
@@ -193,7 +190,7 @@ def annotate_with_lower_configs(
193190
return bytecode
194191

195192

196-
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
193+
def sd_model_annotation(mlir_model, model_name):
197194
device = get_device()
198195
if args.annotation_model == "unet" and device == "vulkan":
199196
use_winograd = True
@@ -222,4 +219,4 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
222219

223220
if __name__ == "__main__":
224221
mlir_model, model_name = load_model_from_tank()
225-
sd_model_annotation(mlir_model, model_name, model_from_tank=True)
222+
sd_model_annotation(mlir_model, model_name)

apps/stable_diffusion/src/utils/utils.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,7 @@ def set_init_device_flags():
236236

237237
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
238238
if (
239-
args.hf_model_id == "prompthero/openjourney"
240-
or args.ckpt_loc != ""
239+
args.ckpt_loc != ""
241240
or args.precision != "fp16"
242241
or args.height != 512
243242
or args.width != 512
@@ -246,22 +245,14 @@ def set_init_device_flags():
246245
):
247246
args.use_tuned = False
248247

249-
elif (
250-
"vulkan" in args.device
251-
and "rdna3" not in args.iree_vulkan_target_triple
248+
elif "vulkan" in args.device and not any(
249+
x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"]
252250
):
253251
args.use_tuned = False
254252

255253
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
256254
args.use_tuned = False
257255

258-
elif (
259-
"cuda" in args.device
260-
and get_cuda_sm_cc() == "sm_89"
261-
and args.hf_model_id != "stabilityai/stable-diffusion-2-1-base"
262-
):
263-
args.use_tuned = False
264-
265256
elif args.use_base_vae and args.hf_model_id not in [
266257
"stabilityai/stable-diffusion-2-1-base",
267258
"CompVis/stable-diffusion-v1-4",

0 commit comments

Comments
 (0)