Skip to content

Commit 8d0c8ea

Browse files
author
Diptorup Deb
authored
Merge pull request #896 from chudur-budur/github-886
Solving perferomance regression issue by caching the kernel_bundle
2 parents 8a891bb + 3958d13 commit 8d0c8ea

File tree

4 files changed

+93
-28
lines changed

4 files changed

+93
-28
lines changed

numba_dpex/core/caching.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from numba_dpex.core.types import USMNdArray
1414

1515

16-
def build_key(argtypes, pyfunc, codegen, backend=None, device_type=None):
17-
"""Constructs a key from python function, context, backend and the device
18-
type.
16+
def build_key(
17+
argtypes, pyfunc, codegen, backend=None, device_type=None, exec_queue=None
18+
):
19+
"""Constructs a key from python function, context, backend, the device
20+
type and execution queue.
1921
2022
Compute index key for the given argument types and codegen. It includes a
2123
description of the OS, target architecture and hashes of the bytecode for
@@ -32,6 +34,7 @@ def build_key(argtypes, pyfunc, codegen, backend=None, device_type=None):
3234
Defaults to None.
3335
device_type (enum, optional): A 'device_type' enum.
3436
Defaults to None.
37+
exec_queue (dpctl._sycl_queue.SyclQueue', optional): A SYCL queue object.
3538
3639
Returns:
3740
tuple: A tuple of return type, argtpes, magic_tuple of codegen
@@ -64,6 +67,7 @@ def build_key(argtypes, pyfunc, codegen, backend=None, device_type=None):
6467
codegen.magic_tuple(),
6568
backend,
6669
device_type,
70+
exec_queue,
6771
(
6872
hashlib.sha256(codebytes).hexdigest(),
6973
hashlib.sha256(cvarbytes).hexdigest(),

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,21 @@ def __init__(
8888
# caching related attributes
8989
if not config.ENABLE_CACHE:
9090
self._cache = NullCache()
91+
self._kernel_bundle_cache = NullCache()
9192
elif enable_cache:
9293
self._cache = LRUCache(
9394
name="SPIRVKernelCache",
9495
capacity=config.CACHE_SIZE,
9596
pyfunc=self.pyfunc,
9697
)
98+
self._kernel_bundle_cache = LRUCache(
99+
name="KernelBundleCache",
100+
capacity=config.CACHE_SIZE,
101+
pyfunc=self.pyfunc,
102+
)
97103
else:
98104
self._cache = NullCache()
105+
self._kernel_bundle_cache = NullCache()
99106
self._cache_hits = 0
100107

101108
if array_access_specifiers:
@@ -627,12 +634,26 @@ def __call__(self, *args):
627634
cache=self._cache,
628635
)
629636

630-
# create a sycl::KernelBundle
631-
kernel_bundle = dpctl_prog.create_program_from_spirv(
632-
exec_queue,
633-
device_driver_ir_module,
634-
" ".join(self._create_sycl_kernel_bundle_flags),
637+
kernel_bundle_key = build_key(
638+
tuple(argtypes),
639+
self.pyfunc,
640+
dpex_kernel_target.target_context.codegen(),
641+
exec_queue=exec_queue,
635642
)
643+
644+
artifact = self._kernel_bundle_cache.get(kernel_bundle_key)
645+
646+
if artifact is None:
647+
# create a sycl::KernelBundle
648+
kernel_bundle = dpctl_prog.create_program_from_spirv(
649+
exec_queue,
650+
device_driver_ir_module,
651+
" ".join(self._create_sycl_kernel_bundle_flags),
652+
)
653+
self._kernel_bundle_cache.put(kernel_bundle_key, kernel_bundle)
654+
else:
655+
kernel_bundle = artifact
656+
636657
# get the sycl::kernel
637658
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel_module_name)
638659

numba_dpex/tests/kernel_tests/test_arg_types.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import sys
6+
57
import dpctl
8+
import dpctl.tensor as dpt
69
import numpy as np
710
import pytest
811

912
import numba_dpex as dpex
10-
from numba_dpex.tests._helper import filter_strings
1113

1214
global_size = 1054
1315
local_size = 1
@@ -35,15 +37,39 @@ def input_arrays(request):
3537
return a, b, c[0]
3638

3739

38-
@pytest.mark.parametrize("filter_str", filter_strings)
39-
def test_kernel_arg_types(filter_str, input_arrays):
40-
kernel = dpex.kernel(mul_kernel)
41-
a, actual, c = input_arrays
40+
def test_kernel_arg_types(input_arrays):
41+
usm_type = "device"
42+
43+
a, b, c = input_arrays
4244
expected = a * c
43-
device = dpctl.SyclDevice(filter_str)
44-
with dpctl.device_context(device):
45-
kernel[global_size, local_size](a, actual, c)
46-
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=0)
45+
46+
queue = dpctl.SyclQueue(dpctl.select_default_device())
47+
48+
da = dpt.usm_ndarray(
49+
a.shape,
50+
dtype=a.dtype,
51+
buffer=usm_type,
52+
buffer_ctor_kwargs={"queue": queue},
53+
)
54+
da.usm_data.copy_from_host(a.reshape((-1)).view("|u1"))
55+
56+
db = dpt.usm_ndarray(
57+
b.shape,
58+
dtype=b.dtype,
59+
buffer=usm_type,
60+
buffer_ctor_kwargs={"queue": queue},
61+
)
62+
db.usm_data.copy_from_host(b.reshape((-1)).view("|u1"))
63+
64+
kernel = dpex.kernel(mul_kernel)
65+
kernel[dpex.NdRange(dpex.Range(global_size), dpex.Range(local_size))](
66+
da, db, c
67+
)
68+
69+
result = np.zeros_like(b)
70+
db.usm_data.copy_to_host(result.reshape((-1)).view("|u1"))
71+
72+
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=0)
4773

4874

4975
def check_bool_kernel(A, test):
@@ -53,14 +79,28 @@ def check_bool_kernel(A, test):
5379
A[0] = 222
5480

5581

56-
@pytest.mark.parametrize("filter_str", filter_strings)
57-
def test_bool_type(filter_str):
58-
kernel = dpex.kernel(check_bool_kernel)
82+
def test_bool_type():
83+
usm_type = "device"
5984
a = np.array([2], np.int64)
6085

61-
device = dpctl.SyclDevice(filter_str)
62-
with dpctl.device_context(device):
63-
kernel[a.size, dpex.DEFAULT_LOCAL_SIZE](a, True)
64-
assert a[0] == 111
65-
kernel[a.size, dpex.DEFAULT_LOCAL_SIZE](a, False)
66-
assert a[0] == 222
86+
queue = dpctl.SyclQueue(dpctl.select_default_device())
87+
88+
da = dpt.usm_ndarray(
89+
a.shape,
90+
dtype=a.dtype,
91+
buffer=usm_type,
92+
buffer_ctor_kwargs={"queue": queue},
93+
)
94+
da.usm_data.copy_from_host(a.reshape((-1)).view("|u1"))
95+
96+
kernel = dpex.kernel(check_bool_kernel)
97+
98+
kernel[dpex.Range(a.size)](da, True)
99+
result = np.zeros_like(a)
100+
da.usm_data.copy_to_host(result.reshape((-1)).view("|u1"))
101+
assert result[0] == 111
102+
103+
kernel[dpex.Range(a.size)](da, False)
104+
result = np.zeros_like(a)
105+
da.usm_data.copy_to_host(result.reshape((-1)).view("|u1"))
106+
assert result[0] == 222

numba_dpex/tests/test_device_array_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def data_parallel_sum(a, b, c):
2626

2727

2828
@skip_no_opencl_cpu
29-
class TestArrayArgsGPU:
29+
class TestArrayArgsCPU:
3030
def test_device_array_args_cpu(self):
3131
c = np.ones_like(a)
3232

@@ -37,7 +37,7 @@ def test_device_array_args_cpu(self):
3737

3838

3939
@skip_no_opencl_gpu
40-
class TestArrayArgsCPU:
40+
class TestArrayArgsGPU:
4141
def test_device_array_args_gpu(self):
4242
c = np.ones_like(a)
4343

0 commit comments

Comments
 (0)