Skip to content

Commit d973ba1

Browse files
authored
Add conditions to force use --import_mlir (huggingface#1028)
1 parent 0198b18 commit d973ba1

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

apps/stable_diffusion/src/utils/sd_annotation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ def load_lower_configs():
9191
spec = spec.split("-")[0]
9292

9393
if args.annotation_model == "vae":
94-
config_name = f"{args.annotation_model}_{args.precision}_{device}.json"
94+
if not spec or spec in ["rdna3", "sm_80"]:
95+
config_name = (
96+
f"{args.annotation_model}_{args.precision}_{device}.json"
97+
)
98+
else:
99+
config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json"
95100
else:
96101
if not spec or spec in ["rdna3", "sm_80"]:
97102
config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json"

apps/stable_diffusion/src/utils/utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def set_init_device_flags():
240240

241241
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
242242
if (
243-
args.ckpt_loc != ""
243+
args.hf_model_id
244+
in [
245+
"runwayml/stable-diffusion-inpainting",
246+
"stabilityai/stable-diffusion-2-inpainting",
247+
]
248+
or args.ckpt_loc != ""
244249
or args.precision != "fp16"
245250
or args.height != 512
246251
or args.width != 512
@@ -288,6 +293,27 @@ def set_init_device_flags():
288293
elif args.height != 512 or args.width != 512 or args.batch_size != 1:
289294
args.import_mlir = True
290295

296+
elif args.use_tuned and args.hf_model_id in [
297+
"dreamlike-art/dreamlike-diffusion-1.0",
298+
"prompthero/openjourney",
299+
"stabilityai/stable-diffusion-2-1",
300+
]:
301+
args.import_mlir = True
302+
303+
elif (
304+
args.use_tuned
305+
and "vulkan" in args.device
306+
and "rdna2" in args.iree_vulkan_target_triple
307+
):
308+
args.import_mlir = True
309+
310+
elif (
311+
args.use_tuned
312+
and "cuda" in args.device
313+
and get_cuda_sm_cc() == "sm_89"
314+
):
315+
args.import_mlir = True
316+
291317

292318
# Utility to get list of devices available.
293319
def get_available_devices():

0 commit comments

Comments
 (0)