|
21 | 21 | rename_labels, |
22 | 22 | replace_var_names, |
23 | 23 | ) |
| 24 | +from numba.core.target_extension import target_override |
24 | 25 | from numba.core.typing import signature |
25 | 26 | from numba.parfors import parfor |
26 | 27 |
|
27 | 28 | from numba_dpex.core import config |
| 29 | +from numba_dpex.core.types.kernel_api.index_space_ids import ItemType |
28 | 30 | from numba_dpex.kernel_api_impl.spirv import spirv_generator |
29 | 31 |
|
30 | 32 | from ..descriptor import dpex_kernel_target |
@@ -66,18 +68,18 @@ def _print_body(body_dict): |
66 | 68 | def _compile_kernel_parfor( |
67 | 69 | sycl_queue, kernel_name, func_ir, argtypes, debug=False |
68 | 70 | ): |
69 | | - |
70 | | - cres = compile_numba_ir_with_dpex( |
71 | | - pyfunc=func_ir, |
72 | | - pyfunc_name=kernel_name, |
73 | | - args=argtypes, |
74 | | - return_type=None, |
75 | | - debug=debug, |
76 | | - is_kernel=True, |
77 | | - typing_context=dpex_kernel_target.typing_context, |
78 | | - target_context=dpex_kernel_target.target_context, |
79 | | - extra_compile_flags=None, |
80 | | - ) |
| 71 | + with target_override(dpex_kernel_target.target_context.target_name): |
| 72 | + cres = compile_numba_ir_with_dpex( |
| 73 | + pyfunc=func_ir, |
| 74 | + pyfunc_name=kernel_name, |
| 75 | + args=argtypes, |
| 76 | + return_type=None, |
| 77 | + debug=debug, |
| 78 | + is_kernel=True, |
| 79 | + typing_context=dpex_kernel_target.typing_context, |
| 80 | + target_context=dpex_kernel_target.target_context, |
| 81 | + extra_compile_flags=None, |
| 82 | + ) |
81 | 83 | cres.library.inline_threshold = config.INLINE_THRESHOLD |
82 | 84 | cres.library._optimize_final_module() |
83 | 85 | func = cres.library.get_function(cres.fndesc.llvm_func_name) |
@@ -420,6 +422,13 @@ def create_kernel_for_parfor( |
420 | 422 | print("kernel_ir after remove dead") |
421 | 423 | kernel_ir.dump() |
422 | 424 |
|
| 425 | + # The first argument to a range kernel is a kernel_api.Item object. The |
| 426 | + # ``Item`` object is used by the kernel_api.spirv backend to generate the |
| 427 | + # correct SPIR-V indexing instructions. Since the argument is not |
| 428 | + # originally present in kernel_param_types, we add it at this point to |
| 429 | + # make sure the kernel signature matches the actual generated code. |
| 430 | + ty_item = ItemType(parfor_dim) |
| 431 | + kernel_param_types = (ty_item, *kernel_param_types) |
423 | 432 | kernel_sig = signature(types.none, *kernel_param_types) |
424 | 433 |
|
425 | 434 | if config.DEBUG_ARRAY_OPT: |
|
0 commit comments