@@ -37,7 +37,64 @@ def run_cmd(cmd):
3737 sys .exit ("Exiting program due to error running:" , cmd )
3838
3939
40- IREE_DEVICE_MAP = {
40+ def iree_device_map (device ):
41+
42+ from iree .runtime import get_driver , get_device
43+
44+ def get_all_devices (driver_name ):
45+ driver = get_driver (driver_name )
46+ device_list_src = driver .query_available_devices ()
47+ device_list = []
48+ for device_dict in device_list_src :
49+ device_list .append (f"{ driver_name } ://{ device_dict ['path' ]} " )
50+ device_list .sort ()
51+ return device_list
52+
53+ # only supported for vulkan as of now
54+ if "vulkan://" in device :
55+ device_list = get_all_devices ("vulkan" )
56+ _ , d_index = device .split ("://" )
57+ matched_index = None
58+ match_with_index = False
59+ if 0 <= len (d_index ) <= 2 :
60+ try :
61+ d_index = int (d_index )
62+ except :
63+ print (
64+ f"{ d_index } is not valid index or uri. Will choose device 0"
65+ )
66+ d_index = 0
67+ match_with_index = True
68+
69+ if len (device_list ) > 1 :
70+ print ("List of available vulkan devices:" )
71+ for i , d in enumerate (device_list ):
72+ print (f"vulkan://{ i } => { d } " )
73+ if (match_with_index and d_index == i ) or (
74+ not match_with_index and d == device
75+ ):
76+ matched_index = i
77+ print (
78+ f"Choosing device vulkan://{ matched_index } \n To choose another device please specify device index or uri accordingly."
79+ )
80+ return get_device (device_list [matched_index ])
81+ elif len (device_list ) == 1 :
82+ print (f"Found one vulkan device: { device_list [0 ]} . Using this." )
83+ return get_device (device_list [0 ])
84+ else :
85+ print (
86+ f"No device found! returning device corresponding to driver name: vulkan"
87+ )
88+ return _IREE_DEVICE_MAP ["vulkan" ]
89+ else :
90+ return _IREE_DEVICE_MAP [device ]
91+
92+
93+ def get_supported_device_list ():
94+ return list (_IREE_DEVICE_MAP .keys ())
95+
96+
97+ _IREE_DEVICE_MAP = {
4198 "cpu" : "local-task" ,
4299 "cuda" : "cuda" ,
43100 "vulkan" : "vulkan" ,
@@ -46,7 +103,14 @@ def run_cmd(cmd):
46103 "intel-gpu" : "level_zero" ,
47104}
48105
49- IREE_TARGET_MAP = {
106+
107+ def iree_target_map (device ):
108+ if "://" in device :
109+ device = device .split ("://" )[0 ]
110+ return _IREE_TARGET_MAP [device ]
111+
112+
113+ _IREE_TARGET_MAP = {
50114 "cpu" : "llvm-cpu" ,
51115 "cuda" : "cuda" ,
52116 "vulkan" : "vulkan" ,
@@ -58,6 +122,9 @@ def run_cmd(cmd):
58122# Finds whether the required drivers are installed for the given device.
59123def check_device_drivers (device ):
60124 """Checks necessary drivers present for gpu and vulkan devices"""
125+ if "://" in device :
126+ device = device .split ("://" )[0 ]
127+
61128 if device == "cuda" :
62129 try :
63130 subprocess .check_output ("nvidia-smi" )
0 commit comments