Skip to content

Commit 200bf8d

Browse files
authored
Merge pull request #1377 from IntelPython/fix/private_array
Fix/private array
2 parents 2831862 + 0131838 commit 200bf8d

File tree

9 files changed

+136
-70
lines changed

9 files changed

+136
-70
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _intrinsic_spirv_global_index_const(
5454
sig = types.int64(types.int32)
5555

5656
def _intrinsic_spirv_global_index_const_gen(
57-
context: SPIRVTargetContext,
57+
context: SPIRVTargetContext, # pylint: disable=unused-argument
5858
builder: llvmir.IRBuilder,
5959
sig, # pylint: disable=unused-argument
6060
args,
@@ -79,7 +79,16 @@ def _intrinsic_spirv_global_index_const_gen(
7979
dim,
8080
)
8181

82-
return context.cast(builder, res, types.uintp, types.intp)
82+
# Generate the same range check that SYCL emits. It is unclear why SYCL adds
83+
# it; possibly to avoid a pointer bitcast on the special constant.
84+
max_int32 = llvmir.Constant(res.type, 2147483648)
85+
cmp = builder.icmp_unsigned("<", res, max_int32)
86+
87+
inst = builder.assume(cmp)
88+
# TODO: tail does not always work
89+
inst.tail = "tail"
90+
91+
return res
8392

8493
return sig, _intrinsic_spirv_global_index_const_gen
8594

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
import llvmlite.ir as llvmir
1111
from llvmlite.ir.builder import IRBuilder
12+
from numba.core import cgutils, types
1213
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
1314
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
1415
from numba.core.typing.templates import Signature
15-
from numba.extending import intrinsic, overload
16+
from numba.extending import type_callable
1617

1718
from numba_dpex.core.types import USMNdArray
1819
from numba_dpex.experimental.target import DpexExpKernelTypingContext
@@ -23,55 +24,12 @@
2324
)
2425
from numba_dpex.utils import address_space as AddressSpace
2526

26-
from ..target import DPEX_KERNEL_EXP_TARGET_NAME
27+
from ._registry import lower
2728

2829

29-
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
30-
def _intrinsic_private_array_ctor(
31-
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
32-
):
33-
require_literal(ty_shape)
34-
35-
ty_array = USMNdArray(
36-
dtype=_ty_parse_dtype(ty_dtype),
37-
ndim=_ty_parse_shape(ty_shape),
38-
layout="C",
39-
addrspace=AddressSpace.PRIVATE,
40-
)
41-
42-
sig = ty_array(ty_shape, ty_dtype)
43-
44-
def codegen(
45-
context: DpexExpKernelTypingContext,
46-
builder: IRBuilder,
47-
sig: Signature,
48-
args: list[llvmir.Value],
49-
):
50-
shape = args[0]
51-
ty_shape = sig.args[0]
52-
ty_array = sig.return_type
53-
54-
ary = make_spirv_generic_array_on_stack(
55-
context, builder, ty_array, ty_shape, shape
56-
)
57-
return ary._getvalue() # pylint: disable=protected-access
58-
59-
return (
60-
sig,
61-
codegen,
62-
)
63-
64-
65-
@overload(
66-
PrivateArray,
67-
prefer_literal=True,
68-
target=DPEX_KERNEL_EXP_TARGET_NAME,
69-
)
70-
def ol_private_array_ctor(
71-
shape,
72-
dtype,
73-
):
74-
"""Overload of the constructor for the class
30+
@type_callable(PrivateArray)
31+
def type_interval(context): # pylint: disable=unused-argument
32+
"""Sets type of the constructor for the class
7533
class:`numba_dpex.kernel_api.PrivateArray`.
7634
7735
Raises:
@@ -81,11 +39,48 @@ def ol_private_array_ctor(
8139
type.
8240
"""
8341

84-
def ol_private_array_ctor_impl(
85-
shape,
86-
dtype,
87-
):
88-
# pylint: disable=no-value-for-parameter
89-
return _intrinsic_private_array_ctor(shape, dtype)
42+
def typer(shape, dtype, fill_zeros=types.BooleanLiteral(False)):
43+
require_literal(shape)
44+
require_literal(fill_zeros)
45+
46+
return USMNdArray(
47+
dtype=_ty_parse_dtype(dtype),
48+
ndim=_ty_parse_shape(shape),
49+
layout="C",
50+
addrspace=AddressSpace.PRIVATE,
51+
)
52+
53+
return typer
54+
55+
56+
@lower(PrivateArray, types.IntegerLiteral, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.Tuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.UniTuple, types.Any, types.BooleanLiteral)
@lower(PrivateArray, types.IntegerLiteral, types.Any)
@lower(PrivateArray, types.Tuple, types.Any)
@lower(PrivateArray, types.UniTuple, types.Any)
def dpex_private_array_lower(
    context: DpexExpKernelTypingContext,
    builder: IRBuilder,
    sig: Signature,
    args: list[llvmir.Value],
):
    """Lowers a call to the constructor of the
    class:`numba_dpex.kernel_api.PrivateArray`.

    The array is built by ``make_spirv_generic_array_on_stack`` from the
    shape argument and the signature's return type. The function is
    registered for both the two-argument ``(shape, dtype)`` and the
    three-argument ``(shape, dtype, fill_zeros)`` forms; when the third
    argument is present its literal value decides whether the array data is
    set to zero.

    Returns:
        The LLVM value of the populated array struct.
    """
    # fill_zeros only exists in the three-argument signatures; a missing
    # argument means "do not zero the data".
    zero_init = sig.args[-1].literal_value if len(sig.args) == 3 else False

    private_array = make_spirv_generic_array_on_stack(
        context, builder, sig.return_type, sig.args[0], args[0]
    )

    if zero_init:
        data_ptr = private_array.data
        # Total size in bytes: size of one item times the number of items.
        size_in_bytes = builder.mul(
            private_array.itemsize, private_array.nitems
        )
        cgutils.memset(builder, data_ptr, size_in_bytes, 0)

    return private_array._getvalue()  # pylint: disable=protected-access
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Defines the lowering registry used to register the SPIR-V implementations of
kernel_api classes such as PrivateArray.
7+
"""
8+
9+
from numba.core.imputils import Registry
10+
11+
registry = Registry()
12+
lower = registry.lower

numba_dpex/kernel_api/private_array.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
kernel function.
88
"""
99

10-
from numpy import ndarray
10+
import numpy as np
1111

1212

1313
class PrivateArray:
@@ -16,10 +16,13 @@ class PrivateArray:
1616
inside kernel work item.
1717
"""
1818

19-
def __init__(self, shape, dtype) -> None:
19+
def __init__(self, shape, dtype, fill_zeros=False) -> None:
2020
"""Creates a new PrivateArray instance of the given shape and dtype.
The data is zero-filled if fill_zeros is true, otherwise it is left
uninitialized.
"""
2121

22-
self._data = ndarray(shape=shape, dtype=dtype)
22+
if fill_zeros:
23+
self._data = np.zeros(shape=shape, dtype=dtype)
24+
else:
25+
self._data = np.empty(shape=shape, dtype=dtype)
2326

2427
def __getitem__(self, idx_obj):
2528
"""Returns the value stored at the position represented by idx_obj in

numba_dpex/kernel_api_impl/spirv/arrayobj.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def require_literal(literal_type: types.Type):
4141

4242
for i, _ in enumerate(literal_type):
4343
if not isinstance(literal_type[i], types.Literal):
44-
raise errors.TypingError("requires literal type")
44+
raise errors.TypingError(
45+
"requires each element of tuple literal type"
46+
)
4547

4648

4749
def make_spirv_array( # pylint: disable=too-many-arguments

numba_dpex/kernel_api_impl/spirv/dispatcher.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""Implements a new numba dispatcher class and a compiler class to compile and
66
call numba_dpex.kernel decorated function.
77
"""
8+
import hashlib
89
from collections import namedtuple
910
from contextlib import ExitStack
1011
from typing import Tuple
@@ -181,6 +182,9 @@ def _compile_to_spirv(
181182
# all linking libraries getting linked together and final optimization
182183
# including inlining of functions if an inlining level is specified.
183184
kernel_library.finalize()
185+
186+
if config.DUMP_KERNEL_LLVM:
187+
self._dump_kernel(kernel_fndesc, kernel_library)
184188
# Compiled the LLVM IR to SPIR-V
185189
kernel_spirv_module = spirv_generator.llvm_to_spirv(
186190
kernel_targetctx,
@@ -268,20 +272,26 @@ def _compile_cached(
268272

269273
kcres_attrs.append(kernel_device_ir_module)
270274

271-
if config.DUMP_KERNEL_LLVM:
272-
with open(
273-
cres.fndesc.llvm_func_name + ".ll",
274-
"w",
275-
encoding="UTF-8",
276-
) as fptr:
277-
fptr.write(str(cres.library.final_module))
278-
279275
except errors.TypingError as err:
280276
self._failed_cache[key] = err
281277
return False, err
282278

283279
return True, _SPIRVKernelCompileResult(*kcres_attrs)
284280

281+
def _dump_kernel(self, fndesc, library):
    """Write the final LLVM module of a kernel to a ``.ll`` file.

    The file name is ``fndesc.llvm_func_name`` plus the ``.ll`` suffix and
    is opened as a relative path. A name longer than 200 characters is
    replaced by its first 150 characters, an underscore and the SHA-256 hex
    digest of the full name, so that long names stay distinguishable after
    truncation.

    Args:
        fndesc: Object providing the ``llvm_func_name`` attribute.
        library: Object providing the ``final_module`` attribute; its
            ``str()`` form is what gets written.
    """
    file_stem = fndesc.llvm_func_name
    if len(file_stem) > 200:
        digest = hashlib.sha256(file_stem.encode("utf-8")).hexdigest()
        file_stem = file_stem[:150] + "_" + digest

    with open(file_stem + ".ll", "w", encoding="UTF-8") as ll_file:
        ll_file.write(str(library.final_module))
294+
285295

286296
class SPIRVKernelDispatcher(Dispatcher):
287297
"""Dispatcher class designed to compile kernel decorated functions. The

numba_dpex/kernel_api_impl/spirv/spirv_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def finalize(self):
123123
llvm_spirv_args = [
124124
"--spirv-ext=+SPV_EXT_shader_atomic_float_add",
125125
"--spirv-ext=+SPV_EXT_shader_atomic_float_min_max",
126+
"--spirv-ext=+SPV_INTEL_arbitrary_precision_integers",
126127
]
127128
for key in list(self.context.extra_compile_options.keys()):
128129
if key == LLVM_SPIRV_ARGS:

numba_dpex/kernel_api_impl/spirv/target.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,12 +383,16 @@ def load_additional_registries(self):
383383
# pylint: disable=import-outside-toplevel
384384
from numba_dpex import printimpl
385385
from numba_dpex.dpnp_iface import dpnpimpl
386+
from numba_dpex.experimental._kernel_dpcpp_spirv_overloads._registry import (
387+
registry as spirv_registry,
388+
)
386389
from numba_dpex.ocl import mathimpl, oclimpl
387390

388391
self.insert_func_defn(oclimpl.registry.functions)
389392
self.insert_func_defn(mathimpl.registry.functions)
390393
self.insert_func_defn(dpnpimpl.registry.functions)
391394
self.install_registry(printimpl.registry)
395+
self.install_registry(spirv_registry)
392396
# Replace dpnp math functions with their OpenCL versions.
393397
self.replace_dpnp_ufunc_with_ocl_intrinsics()
394398

numba_dpex/tests/experimental/test_private_array.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a):
2323
a[i] += p[j]
2424

2525

26+
def private_array_kernel_fill_true(item: Item, a):
    """Kernel using a ten-element PrivateArray created with fill_zeros=True.

    Every element of the private array is overwritten with the square of its
    index, then the sum of all elements is stored in ``a`` at the linear id
    of the work item.
    """
    idx = item.get_linear_id()
    scratch = PrivateArray(10, a.dtype, fill_zeros=True)

    for k in range(10):
        scratch[k] = k * k

    a[idx] = 0
    for k in range(10):
        a[idx] += scratch[k]
36+
37+
38+
def private_array_kernel_fill_false(item: Item, a):
    """Kernel using a ten-element PrivateArray created with fill_zeros=False.

    Every element of the private array is overwritten with the square of its
    index, then the sum of all elements is stored in ``a`` at the linear id
    of the work item.
    """
    idx = item.get_linear_id()
    scratch = PrivateArray(10, a.dtype, fill_zeros=False)

    for k in range(10):
        scratch[k] = k * k

    a[idx] = 0
    for k in range(10):
        a[idx] += scratch[k]
48+
49+
2650
def private_2d_array_kernel(item: Item, a):
2751
i = item.get_linear_id()
2852
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
@@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a):
3660

3761

3862
@pytest.mark.parametrize(
39-
"kernel", [private_array_kernel, private_2d_array_kernel]
63+
"kernel",
64+
[
65+
private_array_kernel,
66+
private_array_kernel_fill_true,
67+
private_array_kernel_fill_false,
68+
private_2d_array_kernel,
69+
],
4070
)
4171
@pytest.mark.parametrize(
4272
"call_kernel, decorator",

0 commit comments

Comments
 (0)