diff --git a/numba_dpex/tests/kernel_tests/test_atomic_op.py b/numba_dpex/tests/kernel_tests/test_atomic_op.py index b34f0ed315..d4aacb9ce1 100644 --- a/numba_dpex/tests/kernel_tests/test_atomic_op.py +++ b/numba_dpex/tests/kernel_tests/test_atomic_op.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import dpctl -import numpy as np +import dpnp as np import pytest import numba_dpex as dpex @@ -38,8 +38,11 @@ def fdtype(request): @pytest.fixture(params=list_of_i_dtypes + list_of_f_dtypes) def input_arrays(request): - a = np.array([0], request.param) - return a, request.param + def _inpute_arrays(filter_str): + a = np.array([0], request.param, device=filter_str) + return a, request.param + + return _inpute_arrays list_of_op = [ @@ -72,11 +75,9 @@ def f(a): @pytest.mark.parametrize("filter_str", filter_strings) @skip_no_atomic_support def test_kernel_atomic_simple(filter_str, input_arrays, kernel_result_pair): - a, dtype = input_arrays + a, dtype = input_arrays(filter_str) kernel, expected = kernel_result_pair - device = dpctl.SyclDevice(filter_str) - with dpctl.device_context(device): - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) + kernel[dpex.Range(global_size)](a) assert a[0] == expected @@ -114,15 +115,11 @@ def f(a): @pytest.mark.parametrize("filter_str", filter_strings) @skip_no_atomic_support def test_kernel_atomic_local(filter_str, input_arrays, return_list_of_op): - a, dtype = input_arrays + a, dtype = input_arrays(filter_str) op_type, expected = return_list_of_op f = get_func_local(op_type, dtype) kernel = dpex.kernel(f) - device = dpctl.SyclDevice(filter_str) - with dpctl.device_context(device): - gs = (N,) - ls = (N,) - kernel[gs, ls](a) + kernel[dpex.Range(N), dpex.Range(N)](a) assert a[0] == expected @@ -161,10 +158,8 @@ def test_kernel_atomic_multi_dim( op_type, expected = return_list_of_op dim = return_list_of_dim kernel = get_kernel_multi_dim(op_type, len(dim)) - a = np.zeros(dim, return_dtype) - device = dpctl.SyclDevice(filter_str) - with dpctl.device_context(device): - kernel[global_size, dpex.DEFAULT_LOCAL_SIZE](a) + a = np.zeros(dim, dtype=return_dtype, device=filter_str) + kernel[dpex.Range(global_size)](a) assert a[0] == expected