Skip to content

Commit a06f8f0

Browse files
committed
Caching kernel_bundle after create_program_from_spirv()
1 parent aa4eb5b commit a06f8f0

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ repos:
2323
- id: blacken-docs
2424
additional_dependencies: [black==22.10]
2525
- repo: https://github.com/pycqa/isort
26-
rev: 5.10.1
26+
rev: 5.12.0
2727
hooks:
2828
- id: isort
2929
name: isort (python)

numba_dpex/core/kernel_interface/dispatcher.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ def __init__(
9595
capacity=config.CACHE_SIZE,
9696
pyfunc=self.pyfunc,
9797
)
98+
self._kernel_bundle_cache = LRUCache(
99+
name="KernelBundleCache",
100+
capacity=config.CACHE_SIZE,
101+
pyfunc=self.pyfunc,
102+
)
98103
else:
99104
self._cache = NullCache()
100105
self._cache_hits = 0
@@ -587,6 +592,7 @@ def __call__(self, *args):
587592
# redundant. We should avoid these checks for the specialized case.
588593
exec_queue = self._determine_kernel_launch_queue(args, argtypes)
589594
backend = exec_queue.backend
595+
device = exec_queue.sycl_device
590596

591597
if exec_queue.backend not in [
592598
dpctl.backend_type.opencl,
@@ -626,12 +632,25 @@ def __call__(self, *args):
626632
cache=self._cache,
627633
)
628634

629-
# create a sycl::KernelBundle
630-
kernel_bundle = dpctl_prog.create_program_from_spirv(
631-
exec_queue,
632-
device_driver_ir_module,
633-
" ".join(self._create_sycl_kernel_bundle_flags),
635+
kernel_bundle_key = build_key(
636+
tuple(argtypes),
637+
self.pyfunc,
638+
dpex_kernel_target.target_context.codegen(),
639+
backend=backend,
640+
device_type=device.device_type,
634641
)
642+
643+
kernel_bundle = self._kernel_bundle_cache.get(kernel_bundle_key)
644+
645+
if kernel_bundle is None:
646+
# create a sycl::KernelBundle
647+
kernel_bundle = dpctl_prog.create_program_from_spirv(
648+
exec_queue,
649+
device_driver_ir_module,
650+
" ".join(self._create_sycl_kernel_bundle_flags),
651+
)
652+
self._kernel_bundle_cache.put(kernel_bundle_key, kernel_bundle)
653+
635654
# get the sycl::kernel
636655
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel_module_name)
637656

numba_dpex/tests/test_device_array_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def data_parallel_sum(a, b, c):
2626

2727

2828
@skip_no_opencl_cpu
29-
class TestArrayArgsGPU:
29+
class TestArrayArgsCPU:
3030
def test_device_array_args_cpu(self):
3131
c = np.ones_like(a)
3232

@@ -37,7 +37,7 @@ def test_device_array_args_cpu(self):
3737

3838

3939
@skip_no_opencl_gpu
40-
class TestArrayArgsCPU:
40+
class TestArrayArgsGPU:
4141
def test_device_array_args_gpu(self):
4242
c = np.ones_like(a)
4343

0 commit comments

Comments
 (0)