Skip to content

Commit 749a2c2

Browse files
authored
add support for choosing vulkan device (huggingface#439)
1 parent 29a317d commit 749a2c2

File tree

6 files changed

+91
-14
lines changed

6 files changed

+91
-14
lines changed

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313

1414
def _compile_module(shark_module, model_name, extra_args=[]):
1515
if args.load_vmfb or args.save_vmfb:
16-
extended_name = "{}_{}".format(model_name, args.device)
16+
device = (
17+
args.device
18+
if "://" not in args.device
19+
else "-".join(args.device.split("://"))
20+
)
21+
extended_name = "{}_{}".format(model_name, device)
1722
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
1823
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
1924
print("Loading flatbuffer from {}".format(vmfb_path))

shark/iree_utils/_common.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,64 @@ def run_cmd(cmd):
3737
sys.exit("Exiting program due to error running:", cmd)
3838

3939

40-
IREE_DEVICE_MAP = {
40+
def iree_device_map(device):
41+
42+
from iree.runtime import get_driver, get_device
43+
44+
def get_all_devices(driver_name):
45+
driver = get_driver(driver_name)
46+
device_list_src = driver.query_available_devices()
47+
device_list = []
48+
for device_dict in device_list_src:
49+
device_list.append(f"{driver_name}://{device_dict['path']}")
50+
device_list.sort()
51+
return device_list
52+
53+
# only supported for vulkan as of now
54+
if "vulkan://" in device:
55+
device_list = get_all_devices("vulkan")
56+
_, d_index = device.split("://")
57+
matched_index = None
58+
match_with_index = False
59+
if 0 <= len(d_index) <= 2:
60+
try:
61+
d_index = int(d_index)
62+
except:
63+
print(
64+
f"{d_index} is not valid index or uri. Will choose device 0"
65+
)
66+
d_index = 0
67+
match_with_index = True
68+
69+
if len(device_list) > 1:
70+
print("List of available vulkan devices:")
71+
for i, d in enumerate(device_list):
72+
print(f"vulkan://{i} => {d}")
73+
if (match_with_index and d_index == i) or (
74+
not match_with_index and d == device
75+
):
76+
matched_index = i
77+
print(
78+
f"Choosing device vulkan://{matched_index}\nTo choose another device please specify device index or uri accordingly."
79+
)
80+
return get_device(device_list[matched_index])
81+
elif len(device_list) == 1:
82+
print(f"Found one vulkan device: {device_list[0]}. Using this.")
83+
return get_device(device_list[0])
84+
else:
85+
print(
86+
f"No device found! returning device corresponding to driver name: vulkan"
87+
)
88+
return _IREE_DEVICE_MAP["vulkan"]
89+
else:
90+
return _IREE_DEVICE_MAP[device]
91+
92+
93+
def get_supported_device_list():
94+
return list(_IREE_DEVICE_MAP.keys())
95+
96+
97+
_IREE_DEVICE_MAP = {
4198
"cpu": "local-task",
4299
"cuda": "cuda",
43100
"vulkan": "vulkan",
@@ -46,7 +103,14 @@ def run_cmd(cmd):
46103
"intel-gpu": "level_zero",
47104
}
48105

49-
IREE_TARGET_MAP = {
106+
107+
def iree_target_map(device):
108+
if "://" in device:
109+
device = device.split("://")[0]
110+
return _IREE_TARGET_MAP[device]
111+
112+
113+
_IREE_TARGET_MAP = {
50114
"cpu": "llvm-cpu",
51115
"cuda": "cuda",
52116
"vulkan": "vulkan",
@@ -58,6 +122,9 @@ def run_cmd(cmd):
58122
# Finds whether the required drivers are installed for the given device.
59123
def check_device_drivers(device):
60124
"""Checks necessary drivers present for gpu and vulkan devices"""
125+
if "://" in device:
126+
device = device.split("://")[0]
127+
61128
if device == "cuda":
62129
try:
63130
subprocess.check_output("nvidia-smi")

shark/iree_utils/benchmark_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
16-
from shark.iree_utils._common import run_cmd, IREE_DEVICE_MAP
16+
from shark.iree_utils._common import run_cmd, iree_device_map
1717
import numpy as np
1818
import os
1919
import re
@@ -69,7 +69,7 @@ def build_benchmark_args(
6969
# TODO: Replace name of train with actual train fn name.
7070
fn_name = "train"
7171
benchmark_cl.append(f"--entry_function={fn_name}")
72-
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
72+
benchmark_cl.append(f"--device={iree_device_map(device)}")
7373
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
7474
for mlir_input in mlir_input_types:
7575
benchmark_cl.append(f"--function_input={mlir_input}")
@@ -96,7 +96,7 @@ def build_benchmark_args_non_tensor_input(
9696
# TODO: The function named can be passed as one of the args.
9797
if function_name:
9898
benchmark_cl.append(f"--entry_function={function_name}")
99-
benchmark_cl.append(f"--device={IREE_DEVICE_MAP[device]}")
99+
benchmark_cl.append(f"--device={iree_device_map(device)}")
100100
for input in inputs:
101101
benchmark_cl.append(f"--function_input={input}")
102102
time_extractor = "| awk 'END{{print $2 $3}}'"

shark/iree_utils/compile_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import iree.runtime as ireert
1515
import iree.compiler as ireec
16-
from shark.iree_utils._common import IREE_DEVICE_MAP, IREE_TARGET_MAP
16+
from shark.iree_utils._common import iree_device_map, iree_target_map
1717
from shark.iree_utils.benchmark_utils import *
1818
import numpy as np
1919
import os
@@ -224,15 +224,15 @@ def compile_module_to_flatbuffer(
224224
# Currently for MHLO/TOSA.
225225
flatbuffer_blob = ireec.compile_str(
226226
module,
227-
target_backends=[IREE_TARGET_MAP[device]],
227+
target_backends=[iree_target_map(device)],
228228
extra_args=args,
229229
input_type=input_type,
230230
)
231231
else:
232232
# Currently for Torch.
233233
flatbuffer_blob = ireec.compile_str(
234234
module,
235-
target_backends=[IREE_TARGET_MAP[device]],
235+
target_backends=[iree_target_map(device)],
236236
extra_args=args,
237237
)
238238

@@ -241,7 +241,12 @@ def compile_module_to_flatbuffer(
241241

242242
def get_iree_module(flatbuffer_blob, device, func_name):
243243
# Returns the compiled module and the configs.
244-
config = ireert.Config(IREE_DEVICE_MAP[device])
244+
device = iree_device_map(device)
245+
if type(device) == ireert.HalDevice:
246+
config = ireert.Config(device=device)
247+
else:
248+
driver_name = device.split("://")[0] if "://" in device else device
249+
config = ireert.Config(driver_name=driver_name)
245250
vm_module = ireert.VmModule.from_flatbuffer(
246251
config.vm_instance, flatbuffer_blob
247252
)
@@ -291,7 +296,8 @@ def export_iree_module_to_vmfb(
291296
module, device, mlir_dialect, func_name, model_config_path, extra_args
292297
)
293298
if module_name is None:
294-
module_name = f"{mlir_dialect}_{func_name}_{device}"
299+
device_name = device.split("://")[0]
300+
module_name = f"{mlir_dialect}_{func_name}_{device_name}"
295301
filename = os.path.join(directory, module_name + ".vmfb")
296302
print(f"Saved vmfb in {filename}.")
297303
with open(filename, "wb") as f:

shark/shark_importer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
8787

8888
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
8989
from iree.compiler import tflite as tflitec
90-
from shark.iree_utils._common import IREE_TARGET_MAP
9190

9291
self.mlir_model = tflitec.compile_file(
9392
self.raw_model_file, # in tflite, it is a path to .tflite file, not a tflite interpreter

tank/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from shark.iree_utils._common import (
22
check_device_drivers,
33
device_driver_info,
4-
IREE_DEVICE_MAP,
4+
get_supported_device_list,
55
)
66
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
77
from parameterized import parameterized
@@ -59,7 +59,7 @@ def get_valid_test_params():
5959
"""
6060
device_list = [
6161
device
62-
for device in IREE_DEVICE_MAP.keys()
62+
for device in get_supported_device_list()
6363
if not check_device_drivers(device)
6464
]
6565
dynamic_list = (True, False)

0 commit comments

Comments
 (0)