Skip to content

Commit b99856e

Browse files
committed
Migrate parfor to SPIRVKernelDispatcher
1 parent 1336f76 commit b99856e

File tree

6 files changed

+54
-20
lines changed

6 files changed

+54
-20
lines changed

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,14 @@
2626
from numba.parfors import parfor
2727

2828
from numba_dpex.core import config
29+
from numba_dpex.core.decorators import kernel
2930
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
31+
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
3032
from numba_dpex.kernel_api_impl.spirv import spirv_generator
33+
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
34+
SPIRVKernelDispatcher,
35+
_SPIRVKernelCompileResult,
36+
)
3137

3238
from ..descriptor import dpex_kernel_target
3339
from ..types import DpnpNdArray
@@ -46,6 +52,7 @@ def __init__(
4652
queue: dpctl.SyclQueue,
4753
local_accessors=None,
4854
work_group_size=None,
55+
kernel_module=None,
4956
):
5057
self.name = name
5158
self.kernel = kernel
@@ -55,6 +62,7 @@ def __init__(
5562
self.queue = queue
5663
self.local_accessors = local_accessors
5764
self.work_group_size = work_group_size
65+
self.kernel_module = kernel_module
5866

5967

6068
def _print_block(block):
@@ -369,6 +377,8 @@ def create_kernel_for_parfor(
369377
)
370378
kernel_ir = kernel_template.kernel_ir
371379

380+
kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)
381+
372382
if config.DEBUG_ARRAY_OPT:
373383
print("kernel_ir dump ", type(kernel_ir))
374384
kernel_ir.dump()
@@ -469,6 +479,11 @@ def create_kernel_for_parfor(
469479
debug=flags.debuginfo,
470480
)
471481

482+
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
483+
types.void(*kernel_param_types) # kernel signature
484+
)
485+
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module
486+
472487
flags.noalias = old_alias
473488

474489
if config.DEBUG_ARRAY_OPT:
@@ -481,6 +496,7 @@ def create_kernel_for_parfor(
481496
kernel_args=parfor_args,
482497
kernel_arg_types=func_arg_types,
483498
queue=exec_queue,
499+
kernel_module=kernel_module,
484500
)
485501

486502

numba_dpex/core/parfors/kernel_templates/range_kernel_template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
self._param_dict = param_dict
5252

5353
self._kernel_txt = self._generate_kernel_stub_as_string()
54-
self._kernel_ir = self._generate_kernel_ir()
54+
self._py_func, self._kernel_ir = self._generate_kernel_ir()
5555

5656
def _generate_kernel_stub_as_string(self):
5757
"""Generates a stub dpex kernel for the parfor as a string.
@@ -111,7 +111,7 @@ def _generate_kernel_ir(self):
111111
exec(self._kernel_txt, globls, locls)
112112
kernel_fn = locls[self._kernel_name]
113113

114-
return compiler.run_frontend(kernel_fn)
114+
return kernel_fn, compiler.run_frontend(kernel_fn)
115115

116116
@property
117117
def kernel_ir(self):

numba_dpex/core/parfors/kernel_templates/reduction_template.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
self._typemap = typemap
4949

5050
self._kernel_txt = self._generate_kernel_stub_as_string()
51-
self._kernel_ir = self._generate_kernel_ir()
51+
self._py_func, self._kernel_ir = self._generate_kernel_ir()
5252

5353
def _generate_kernel_stub_as_string(self):
5454
"""Generate reduction main kernel template"""
@@ -163,7 +163,7 @@ def _generate_kernel_ir(self):
163163
exec(self._kernel_txt, globls, locls)
164164
kernel_fn = locls[self._kernel_name]
165165

166-
return compiler.run_frontend(kernel_fn)
166+
return kernel_fn, compiler.run_frontend(kernel_fn)
167167

168168
@property
169169
def kernel_ir(self):
@@ -234,7 +234,7 @@ def __init__(
234234
self._reductionKernelVar = reductionKernelVar
235235

236236
self._kernel_txt = self._generate_kernel_stub_as_string()
237-
self._kernel_ir = self._generate_kernel_ir()
237+
self._py_func, self._kernel_ir = self._generate_kernel_ir()
238238

239239
def _generate_kernel_stub_as_string(self):
240240
"""Generate reduction remainder kernel template"""
@@ -322,7 +322,7 @@ def _generate_kernel_ir(self):
322322
exec(self._kernel_txt, globls, locls)
323323
kernel_fn = locls[self._kernel_name]
324324

325-
return compiler.run_frontend(kernel_fn)
325+
return kernel_fn, compiler.run_frontend(kernel_fn)
326326

327327
@property
328328
def kernel_ir(self):

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@
3131
create_reduction_remainder_kernel_for_parfor,
3232
)
3333

34-
# A global list of kernels to keep the objects alive indefinitely.
35-
keep_alive_kernels = []
36-
3734

3835
def _getvar(lowerer, x):
3936
"""Returns the LLVM Value corresponding to a Numba IR variable.
@@ -154,14 +151,12 @@ def _submit_parfor_kernel(
154151
kernel_fn: ParforKernel,
155152
global_range,
156153
local_range,
154+
debug=False,
157155
):
158156
"""
159157
Adds a call to submit a kernel function into the function body of the
160158
current Numba JIT compiled function.
161159
"""
162-
# Ensure that the Python arguments are kept alive for the duration of
163-
# the kernel execution
164-
keep_alive_kernels.append(kernel_fn.kernel)
165160
kl_builder = KernelLaunchIRBuilder(
166161
lowerer.context, lowerer.builder, kernel_dmm
167162
)
@@ -188,19 +183,17 @@ def _submit_parfor_kernel(
188183
else:
189184
kernel_args.append(_getvar(lowerer, arg))
190185

191-
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
192-
kernel_ref = lowerer.builder.inttoptr(
193-
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
194-
cgutils.voidptr_t,
195-
)
196-
197-
kl_builder.set_kernel(kernel_ref)
198186
kl_builder.set_queue(queue_ref)
199187
kl_builder.set_range(global_range, local_range)
200188
kl_builder.set_arguments(
201189
kernel_fn.kernel_arg_types, kernel_args=kernel_args
202190
)
203191
kl_builder.set_dependent_events([])
192+
kl_builder.set_kernel_from_spirv(
193+
kernel_fn.kernel_module,
194+
debug=debug,
195+
)
196+
204197
event_ref = kl_builder.submit()
205198

206199
sycl.dpctl_event_wait(lowerer.builder, event_ref)
@@ -278,6 +271,7 @@ def _reduction_codegen(
278271
parfor_kernel,
279272
global_range,
280273
local_range,
274+
debug=flags.debuginfo,
281275
)
282276

283277
parfor_kernel = create_reduction_remainder_kernel_for_parfor(
@@ -297,6 +291,7 @@ def _reduction_codegen(
297291
parfor_kernel,
298292
global_range,
299293
local_range,
294+
debug=flags.debuginfo,
300295
)
301296

302297
reductionKernelVar.copy_final_sum_to_host(parfor_kernel)
@@ -418,6 +413,7 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
418413
parfor_kernel,
419414
global_range,
420415
local_range,
416+
debug=flags.debuginfo,
421417
)
422418

423419
# TODO: free the kernel at this point

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@
1919
)
2020
from numba.core.typing import signature
2121

22+
from numba_dpex.core.decorators import kernel
2223
from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables
2324
from numba_dpex.core.types import DpctlSyclQueue
2425
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
2526
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
27+
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
28+
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
29+
SPIRVKernelDispatcher,
30+
_SPIRVKernelCompileResult,
31+
)
2632

2733
from .kernel_builder import _print_body # saved for debug
2834
from .kernel_builder import (
@@ -113,6 +119,8 @@ def create_reduction_main_kernel_for_parfor(
113119
)
114120
kernel_ir = kernel_template.kernel_ir
115121

122+
kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)
123+
116124
for i, name in enumerate(reductionKernelVar.parfor_params):
117125
try:
118126
tmp = reductionKernelVar.parfor_redvars_to_redarrs[name][0]
@@ -171,6 +179,11 @@ def create_reduction_main_kernel_for_parfor(
171179
].queue
172180
exec_queue = dpctl.get_device_cached_queue(ty_queue.sycl_device)
173181

182+
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
183+
types.void(*kernel_param_types) # kernel signature
184+
)
185+
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module
186+
174187
sycl_kernel = _compile_kernel_parfor(
175188
exec_queue,
176189
kernel_name,
@@ -195,6 +208,7 @@ def create_reduction_main_kernel_for_parfor(
195208
queue=exec_queue,
196209
local_accessors=set(local_accessors_dict.values()),
197210
work_group_size=reductionKernelVar.work_group_size,
211+
kernel_module=kernel_module,
198212
)
199213

200214

@@ -290,6 +304,8 @@ def create_reduction_remainder_kernel_for_parfor(
290304
)
291305
kernel_ir = kernel_template.kernel_ir
292306

307+
kernel_dispatcher: SPIRVKernelDispatcher = kernel(kernel_template._py_func)
308+
293309
var_table = get_name_var_table(kernel_ir.blocks)
294310
new_var_dict = {}
295311
reserved_names = (
@@ -388,6 +404,11 @@ def create_reduction_remainder_kernel_for_parfor(
388404
debug=flags.debuginfo,
389405
)
390406

407+
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
408+
types.void(*kernel_param_types) # kernel signature
409+
)
410+
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module
411+
391412
flags.noalias = old_alias
392413

393414
return ParforKernel(
@@ -397,4 +418,5 @@ def create_reduction_remainder_kernel_for_parfor(
397418
kernel_args=reductionKernelVar.parfor_params,
398419
kernel_arg_types=reductionKernelVar.func_arg_types,
399420
queue=exec_queue,
421+
kernel_module=kernel_module,
400422
)

numba_dpex/kernel_api_impl/spirv/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from numba.core.types import void
2727
from numba.core.typing.typeof import Purpose, typeof
2828

29-
from numba_dpex import config
29+
from numba_dpex.core import config
3030
from numba_dpex.core.descriptor import dpex_kernel_target
3131
from numba_dpex.core.exceptions import (
3232
ExecutionQueueInferenceError,

0 commit comments

Comments
 (0)