Skip to content

Commit 48595d4

Browse files
committed
Add fill_zeros to private array
1 parent fa4b04d commit 48595d4

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_private_array_overloads.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import llvmlite.ir as llvmir
1111
from llvmlite.ir.builder import IRBuilder
12+
from numba.core import cgutils
1213
from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype
1314
from numba.core.typing.npydecl import parse_shape as _ty_parse_shape
1415
from numba.core.typing.templates import Signature
@@ -31,9 +32,13 @@
3132
inline="always",
3233
)
3334
def _intrinsic_private_array_ctor(
34-
ty_context, ty_shape, ty_dtype # pylint: disable=unused-argument
35+
ty_context, # pylint: disable=unused-argument
36+
ty_shape,
37+
ty_dtype,
38+
ty_fill_zeros,
3539
):
3640
require_literal(ty_shape)
41+
require_literal(ty_fill_zeros)
3742

3843
ty_array = USMNdArray(
3944
dtype=_ty_parse_dtype(ty_dtype),
@@ -42,7 +47,7 @@ def _intrinsic_private_array_ctor(
4247
addrspace=AddressSpace.PRIVATE,
4348
)
4449

45-
sig = ty_array(ty_shape, ty_dtype)
50+
sig = ty_array(ty_shape, ty_dtype, ty_fill_zeros)
4651

4752
def codegen(
4853
context: DpexExpKernelTypingContext,
@@ -52,11 +57,18 @@ def codegen(
5257
):
5358
shape = args[0]
5459
ty_shape = sig.args[0]
60+
ty_fill_zeros = sig.args[-1]
5561
ty_array = sig.return_type
5662

5763
ary = make_spirv_generic_array_on_stack(
5864
context, builder, ty_array, ty_shape, shape
5965
)
66+
67+
if ty_fill_zeros.literal_value:
68+
cgutils.memset(
69+
builder, ary.data, builder.mul(ary.itemsize, ary.nitems), 0
70+
)
71+
6072
return ary._getvalue() # pylint: disable=protected-access
6173

6274
return (
@@ -74,6 +86,7 @@ def codegen(
7486
def ol_private_array_ctor(
7587
shape,
7688
dtype,
89+
fill_zeros=False,
7790
):
7891
"""Overload of the constructor for the class
7992
class:`numba_dpex.kernel_api.PrivateArray`.
@@ -88,8 +101,9 @@ def ol_private_array_ctor(
88101
def ol_private_array_ctor_impl(
89102
shape,
90103
dtype,
104+
fill_zeros=False,
91105
):
92106
# pylint: disable=no-value-for-parameter
93-
return _intrinsic_private_array_ctor(shape, dtype)
107+
return _intrinsic_private_array_ctor(shape, dtype, fill_zeros)
94108

95109
return ol_private_array_ctor_impl

numba_dpex/kernel_api/private_array.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
kernel function.
88
"""
99

10-
from numpy import ndarray
10+
import numpy as np
1111

1212

1313
class PrivateArray:
@@ -16,10 +16,13 @@ class PrivateArray:
1616
inside kernel work item.
1717
"""
1818

19-
def __init__(self, shape, dtype) -> None:
19+
def __init__(self, shape, dtype, fill_zeros=False) -> None:
2020
"""Creates a new PrivateArray instance of the given shape and dtype."""
2121

22-
self._data = ndarray(shape=shape, dtype=dtype)
22+
if fill_zeros:
23+
self._data = np.zeros(shape=shape, dtype=dtype)
24+
else:
25+
self._data = np.empty(shape=shape, dtype=dtype)
2326

2427
def __getitem__(self, idx_obj):
2528
"""Returns the value stored at the position represented by idx_obj in

numba_dpex/tests/experimental/test_private_array.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@ def private_array_kernel(item: Item, a):
2323
a[i] += p[j]
2424

2525

26+
def private_array_kernel_fill_true(item: Item, a):
27+
i = item.get_linear_id()
28+
p = PrivateArray(10, a.dtype, fill_zeros=True)
29+
30+
for j in range(10):
31+
p[j] = j * j
32+
33+
a[i] = 0
34+
for j in range(10):
35+
a[i] += p[j]
36+
37+
38+
def private_array_kernel_fill_false(item: Item, a):
39+
i = item.get_linear_id()
40+
p = PrivateArray(10, a.dtype, fill_zeros=False)
41+
42+
for j in range(10):
43+
p[j] = j * j
44+
45+
a[i] = 0
46+
for j in range(10):
47+
a[i] += p[j]
48+
49+
2650
def private_2d_array_kernel(item: Item, a):
2751
i = item.get_linear_id()
2852
p = PrivateArray(shape=(5, 2), dtype=a.dtype)
@@ -36,7 +60,13 @@ def private_2d_array_kernel(item: Item, a):
3660

3761

3862
@pytest.mark.parametrize(
39-
"kernel", [private_array_kernel, private_2d_array_kernel]
63+
"kernel",
64+
[
65+
private_array_kernel,
66+
private_array_kernel_fill_true,
67+
private_array_kernel_fill_false,
68+
private_2d_array_kernel,
69+
],
4070
)
4171
@pytest.mark.parametrize(
4272
"call_kernel, decorator",

0 commit comments

Comments
 (0)