|
26 | 26 | from numba.parfors import parfor |
27 | 27 |
|
28 | 28 | from numba_dpex.core import config |
| 29 | +from numba_dpex.core.decorators import kernel |
| 30 | +from numba_dpex.core.parfors.parfor_sentinel_replace_pass import ( |
| 31 | + ParforBodyArguments, |
| 32 | +) |
29 | 33 | from numba_dpex.core.types.kernel_api.index_space_ids import ItemType |
| 34 | +from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule |
30 | 35 | from numba_dpex.kernel_api_impl.spirv import spirv_generator |
| 36 | +from numba_dpex.kernel_api_impl.spirv.dispatcher import ( |
| 37 | + SPIRVKernelDispatcher, |
| 38 | + _SPIRVKernelCompileResult, |
| 39 | +) |
31 | 40 |
|
32 | 41 | from ..descriptor import dpex_kernel_target |
33 | 42 | from ..types import DpnpNdArray |
|
38 | 47 | class ParforKernel: |
39 | 48 | def __init__( |
40 | 49 | self, |
41 | | - name, |
42 | | - kernel, |
43 | 50 | signature, |
44 | 51 | kernel_args, |
45 | 52 | kernel_arg_types, |
46 | | - queue: dpctl.SyclQueue, |
47 | 53 | local_accessors=None, |
48 | 54 | work_group_size=None, |
| 55 | + kernel_module=None, |
49 | 56 | ): |
50 | | - self.name = name |
51 | | - self.kernel = kernel |
52 | 57 | self.signature = signature |
53 | 58 | self.kernel_args = kernel_args |
54 | 59 | self.kernel_arg_types = kernel_arg_types |
55 | | - self.queue = queue |
56 | 60 | self.local_accessors = local_accessors |
57 | 61 | self.work_group_size = work_group_size |
58 | | - |
59 | | - |
60 | | -def _print_block(block): |
61 | | - for i, inst in enumerate(block.body): |
62 | | - print(" ", i, inst) |
63 | | - |
64 | | - |
65 | | -def _print_body(body_dict): |
66 | | - """Pretty-print a set of IR blocks.""" |
67 | | - for label, block in body_dict.items(): |
68 | | - print("label: ", label) |
69 | | - _print_block(block) |
70 | | - |
71 | | - |
72 | | -def _compile_kernel_parfor( |
73 | | - sycl_queue, kernel_name, func_ir, argtypes, debug=False |
74 | | -): |
75 | | - with target_override(dpex_kernel_target.target_context.target_name): |
76 | | - cres = compile_numba_ir_with_dpex( |
77 | | - pyfunc=func_ir, |
78 | | - pyfunc_name=kernel_name, |
79 | | - args=argtypes, |
80 | | - return_type=None, |
81 | | - debug=debug, |
82 | | - is_kernel=True, |
83 | | - typing_context=dpex_kernel_target.typing_context, |
84 | | - target_context=dpex_kernel_target.target_context, |
85 | | - extra_compile_flags=None, |
86 | | - ) |
87 | | - cres.library.inline_threshold = config.INLINE_THRESHOLD |
88 | | - cres.library._optimize_final_module() |
89 | | - func = cres.library.get_function(cres.fndesc.llvm_func_name) |
90 | | - kernel = dpex_kernel_target.target_context.prepare_spir_kernel( |
91 | | - func, cres.signature.args |
92 | | - ) |
93 | | - spirv_module = spirv_generator.llvm_to_spirv( |
94 | | - dpex_kernel_target.target_context, |
95 | | - kernel.module.__str__(), |
96 | | - kernel.module.as_bitcode(), |
97 | | - ) |
98 | | - |
99 | | - dpctl_create_program_from_spirv_flags = [] |
100 | | - if debug or config.DPEX_OPT == 0: |
101 | | - # if debug is ON we need to pass additional flags to igc. |
102 | | - dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"] |
103 | | - |
104 | | - # create a sycl::kernel_bundle |
105 | | - kernel_bundle = dpctl_prog.create_program_from_spirv( |
106 | | - sycl_queue, |
107 | | - spirv_module, |
108 | | - " ".join(dpctl_create_program_from_spirv_flags), |
109 | | - ) |
110 | | - # create a sycl::kernel |
111 | | - sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name) |
112 | | - |
113 | | - return sycl_kernel |
| 62 | + self.kernel_module = kernel_module |
114 | 63 |
|
115 | 64 |
|
116 | 65 | def _legalize_names_with_typemap(names, typemap): |
@@ -189,76 +138,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes): |
189 | 138 | typemap[v] = types.npytypes.Array(el_typ, 1, "C") |
190 | 139 |
|
191 | 140 |
|
192 | | -def _find_setitems_block(setitems, block, typemap): |
193 | | - for inst in block.body: |
194 | | - if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem): |
195 | | - setitems.add(inst.target.name) |
196 | | - elif isinstance(inst, parfor.Parfor): |
197 | | - _find_setitems_block(setitems, inst.init_block, typemap) |
198 | | - _find_setitems_body(setitems, inst.loop_body, typemap) |
199 | | - |
200 | | - |
201 | | -def _find_setitems_body(setitems, loop_body, typemap): |
202 | | - """ |
203 | | - Find the arrays that are written into (goes into setitems) |
204 | | - """ |
205 | | - for label, block in loop_body.items(): |
206 | | - _find_setitems_block(setitems, block, typemap) |
207 | | - |
208 | | - |
209 | | -def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body): |
210 | | - # new label for splitting sentinel block |
211 | | - new_label = max(loop_body.keys()) + 1 |
212 | | - |
213 | | - # Search all the block in the kernel function for the sentinel assignment. |
214 | | - for label, block in kernel_ir.blocks.items(): |
215 | | - for i, inst in enumerate(block.body): |
216 | | - if ( |
217 | | - isinstance(inst, ir.Assign) |
218 | | - and inst.target.name == sentinel_name |
219 | | - ): |
220 | | - # We found the sentinel assignment. |
221 | | - loc = inst.loc |
222 | | - scope = block.scope |
223 | | - # split block across __sentinel__ |
224 | | - # A new block is allocated for the statements prior to the |
225 | | - # sentinel but the new block maintains the current block label. |
226 | | - prev_block = ir.Block(scope, loc) |
227 | | - prev_block.body = block.body[:i] |
228 | | - |
229 | | - # The current block is used for statements after the sentinel. |
230 | | - block.body = block.body[i + 1 :] # noqa: E203 |
231 | | - # But the current block gets a new label. |
232 | | - body_first_label = min(loop_body.keys()) |
233 | | - |
234 | | - # The previous block jumps to the minimum labelled block of the |
235 | | - # parfor body. |
236 | | - prev_block.append(ir.Jump(body_first_label, loc)) |
237 | | - # Add all the parfor loop body blocks to the kernel function's |
238 | | - # IR. |
239 | | - for loop, b in loop_body.items(): |
240 | | - kernel_ir.blocks[loop] = b |
241 | | - body_last_label = max(loop_body.keys()) |
242 | | - kernel_ir.blocks[new_label] = block |
243 | | - kernel_ir.blocks[label] = prev_block |
244 | | - # Add a jump from the last parfor body block to the block |
245 | | - # containing statements after the sentinel. |
246 | | - kernel_ir.blocks[body_last_label].append( |
247 | | - ir.Jump(new_label, loc) |
248 | | - ) |
249 | | - break |
250 | | - else: |
251 | | - continue |
252 | | - break |
253 | | - |
254 | | - |
255 | 141 | def create_kernel_for_parfor( |
256 | 142 | lowerer, |
257 | 143 | parfor_node, |
258 | 144 | typemap, |
259 | | - flags, |
260 | 145 | loop_ranges, |
261 | | - has_aliases, |
262 | 146 | races, |
263 | 147 | parfor_outputs, |
264 | 148 | ) -> ParforKernel: |
@@ -367,120 +251,38 @@ def create_kernel_for_parfor( |
367 | 251 | loop_ranges=loop_ranges, |
368 | 252 | param_dict=param_dict, |
369 | 253 | ) |
370 | | - kernel_ir = kernel_template.kernel_ir |
371 | 254 |
|
372 | | - if config.DEBUG_ARRAY_OPT: |
373 | | - print("kernel_ir dump ", type(kernel_ir)) |
374 | | - kernel_ir.dump() |
375 | | - print("loop_body dump ", type(loop_body)) |
376 | | - _print_body(loop_body) |
377 | | - |
378 | | - # rename all variables in kernel_ir afresh |
379 | | - var_table = get_name_var_table(kernel_ir.blocks) |
380 | | - new_var_dict = {} |
381 | | - reserved_names = ( |
382 | | - [sentinel_name] + list(param_dict.values()) + legal_loop_indices |
| 255 | + kernel_dispatcher: SPIRVKernelDispatcher = kernel( |
| 256 | + kernel_template.py_func, |
| 257 | + _parfor_body_args=ParforBodyArguments( |
| 258 | + loop_body=loop_body, |
| 259 | + param_dict=param_dict, |
| 260 | + legal_loop_indices=legal_loop_indices, |
| 261 | + ), |
383 | 262 | ) |
384 | | - for name, var in var_table.items(): |
385 | | - if not (name in reserved_names): |
386 | | - new_var_dict[name] = mk_unique_var(name) |
387 | | - replace_var_names(kernel_ir.blocks, new_var_dict) |
388 | | - if config.DEBUG_ARRAY_OPT: |
389 | | - print("kernel_ir dump after renaming ") |
390 | | - kernel_ir.dump() |
391 | | - |
392 | | - kernel_param_types = param_types |
393 | 263 |
|
394 | | - if config.DEBUG_ARRAY_OPT: |
395 | | - print( |
396 | | - "kernel_param_types = ", |
397 | | - type(kernel_param_types), |
398 | | - "\n", |
399 | | - kernel_param_types, |
400 | | - ) |
401 | | - |
402 | | - kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1 |
403 | | - |
404 | | - # Add kernel stub last label to each parfor.loop_body label to prevent |
405 | | - # label conflicts. |
406 | | - loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label) |
407 | | - |
408 | | - _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body) |
409 | | - |
410 | | - if config.DEBUG_ARRAY_OPT: |
411 | | - print("kernel_ir last dump before renaming") |
412 | | - kernel_ir.dump() |
413 | | - |
414 | | - kernel_ir.blocks = rename_labels(kernel_ir.blocks) |
415 | | - remove_dels(kernel_ir.blocks) |
416 | | - |
417 | | - old_alias = flags.noalias |
418 | | - if not has_aliases: |
419 | | - if config.DEBUG_ARRAY_OPT: |
420 | | - print("No aliases found so adding noalias flag.") |
421 | | - flags.noalias = True |
422 | | - |
423 | | - remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap) |
424 | | - |
425 | | - if config.DEBUG_ARRAY_OPT: |
426 | | - print("kernel_ir after remove dead") |
427 | | - kernel_ir.dump() |
428 | | - |
429 | | - # The first argument to a range kernel is a kernel_api.Item object. The |
430 | | - # ``Item`` object is used by the kernel_api.spirv backend to generate the |
| 264 | + # The first argument to a range kernel is a kernel_api.Item object. The |
| 265 | + # ``Item`` object is used by the kernel_api.spirv backend to generate the |
431 | 266 | # correct SPIR-V indexing instructions. Since, the argument is not something |
432 | 267 | # available originally in the kernel_param_types, we add it at this point to |
433 | 268 | # make sure the kernel signature matches the actual generated code. |
434 | 269 | ty_item = ItemType(parfor_dim) |
435 | | - kernel_param_types = (ty_item, *kernel_param_types) |
| 270 | + kernel_param_types = (ty_item, *param_types) |
436 | 271 | kernel_sig = signature(types.none, *kernel_param_types) |
437 | 272 |
|
438 | | - if config.DEBUG_ARRAY_OPT: |
439 | | - sys.stdout.flush() |
440 | | - |
441 | | - if config.DEBUG_ARRAY_OPT: |
442 | | - print("after DUFunc inline".center(80, "-")) |
443 | | - kernel_ir.dump() |
444 | | - |
445 | | - # The ParforLegalizeCFD pass has already ensured that the LHS and RHS |
446 | | - # arrays are on same device. We can take the queue from the first input |
447 | | - # array and use that to compile the kernel. |
448 | | - |
449 | | - exec_queue: dpctl.SyclQueue = None |
450 | | - |
451 | | - for arg in parfor_args: |
452 | | - obj = typemap[arg] |
453 | | - if isinstance(obj, DpnpNdArray): |
454 | | - filter_string = obj.queue.sycl_device |
455 | | - # FIXME: A better design is required so that we do not have to |
456 | | - # create a queue every time. |
457 | | - exec_queue = dpctl.get_device_cached_queue(filter_string) |
458 | | - |
459 | | - if not exec_queue: |
460 | | - raise AssertionError( |
461 | | - "No execution found for parfor. No way to compile the kernel!" |
462 | | - ) |
463 | | - |
464 | | - sycl_kernel = _compile_kernel_parfor( |
465 | | - exec_queue, |
466 | | - kernel_name, |
467 | | - kernel_ir, |
468 | | - kernel_param_types, |
469 | | - debug=flags.debuginfo, |
| 273 | + kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result( |
| 274 | + types.void(*kernel_param_types) # kernel signature |
470 | 275 | ) |
471 | | - |
472 | | - flags.noalias = old_alias |
| 276 | + kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module |
473 | 277 |
|
474 | 278 | if config.DEBUG_ARRAY_OPT: |
475 | 279 | print("kernel_sig = ", kernel_sig) |
476 | 280 |
|
477 | 281 | return ParforKernel( |
478 | | - name=kernel_name, |
479 | | - kernel=sycl_kernel, |
480 | 282 | signature=kernel_sig, |
481 | 283 | kernel_args=parfor_args, |
482 | 284 | kernel_arg_types=func_arg_types, |
483 | | - queue=exec_queue, |
| 285 | + kernel_module=kernel_module, |
484 | 286 | ) |
485 | 287 |
|
486 | 288 |
|
|
0 commit comments