diff --git a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py index fe993034f3..f6c93fb692 100644 --- a/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py +++ b/numba_dpex/tests/experimental/kernel_api_overloads/spv_overloads/test_barriers.py @@ -29,3 +29,29 @@ def _kernel(nd_item: NdItem, a): dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a) assert a[0] == N * 2 + + +def test_group_barrier_device_func(): + """A test for group_barrier function.""" + + @dpex_exp.device_func + def _increment_value(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + + a[i] += 1 + group_barrier(nd_item.get_group(), MemoryScope.DEVICE) + + if i == 0: + for idx in range(1, a.size): + a[0] += a[idx] + + @dpex_exp.kernel + def _kernel(nd_item: NdItem, a): + _increment_value(nd_item, a) + + N = 16 + a = dpnp.ones(N, dtype=dpnp.int32) + + dpex_exp.call_kernel(_kernel, dpex.NdRange((N,), (N,)), a) + + assert a[0] == N * 2 diff --git a/numba_dpex/tests/experimental/test_private_array.py b/numba_dpex/tests/experimental/test_private_array.py index fa6af6f58b..44370f30da 100644 --- a/numba_dpex/tests/experimental/test_private_array.py +++ b/numba_dpex/tests/experimental/test_private_array.py @@ -82,3 +82,29 @@ def test_private_array(call_kernel, decorator, kernel): want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32) assert np.array_equal(want, a.asnumpy()) + + +@pytest.mark.parametrize( + "func", + [ + private_array_kernel, + private_array_kernel_fill_true, + private_array_kernel_fill_false, + private_2d_array_kernel, + ], +) +def test_private_array_in_device_func(func): + + _df = dpex_exp.device_func(func) + + @dpex_exp.kernel + def _kernel(item: Item, a): + _df(item, a) + + a = dpnp.empty(10, dtype=dpnp.float32) + dpex_exp.call_kernel(_kernel, Range(a.size), a) + + # sum of squares from 1 to n: n*(n+1)*(2*n+1)/6 + want = np.full(a.size, (9) * (9 + 1) * (2 * 9 + 1) / 6, dtype=np.float32) + + assert np.array_equal(want, a.asnumpy())