File tree Expand file tree Collapse file tree 2 files changed +33
-2
lines changed
apps/stable_diffusion/src/utils Expand file tree Collapse file tree 2 files changed +33
-2
lines changed Original file line number Diff line number Diff line change @@ -91,7 +91,12 @@ def load_lower_configs():
9191 spec = spec .split ("-" )[0 ]
9292
9393 if args .annotation_model == "vae" :
94- config_name = f"{ args .annotation_model } _{ args .precision } _{ device } .json"
94+ if not spec or spec in ["rdna3" , "sm_80" ]:
95+ config_name = (
96+ f"{ args .annotation_model } _{ args .precision } _{ device } .json"
97+ )
98+ else :
99+ config_name = f"{ args .annotation_model } _{ args .precision } _{ device } _{ spec } .json"
95100 else :
96101 if not spec or spec in ["rdna3" , "sm_80" ]:
97102 config_name = f"{ args .annotation_model } _{ version } _{ args .precision } _{ device } .json"
Original file line number Diff line number Diff line change @@ -240,7 +240,12 @@ def set_init_device_flags():
240240
241241 # Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
242242 if (
243- args .ckpt_loc != ""
243+ args .hf_model_id
244+ in [
245+ "runwayml/stable-diffusion-inpainting" ,
246+ "stabilityai/stable-diffusion-2-inpainting" ,
247+ ]
248+ or args .ckpt_loc != ""
244249 or args .precision != "fp16"
245250 or args .height != 512
246251 or args .width != 512
@@ -288,6 +293,27 @@ def set_init_device_flags():
288293 elif args .height != 512 or args .width != 512 or args .batch_size != 1 :
289294 args .import_mlir = True
290295
296+ elif args .use_tuned and args .hf_model_id in [
297+ "dreamlike-art/dreamlike-diffusion-1.0" ,
298+ "prompthero/openjourney" ,
299+ "stabilityai/stable-diffusion-2-1" ,
300+ ]:
301+ args .import_mlir = True
302+
303+ elif (
304+ args .use_tuned
305+ and "vulkan" in args .device
306+ and "rdna2" in args .iree_vulkan_target_triple
307+ ):
308+ args .import_mlir = True
309+
310+ elif (
311+ args .use_tuned
312+ and "cuda" in args .device
313+ and get_cuda_sm_cc () == "sm_89"
314+ ):
315+ args .import_mlir = True
316+
291317
292318# Utility to get list of devices available.
293319def get_available_devices ():
You can’t perform that action at this time.
0 commit comments