From c98a6368dc0a7470030ddf341c9fa5309d34fcd2 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Wed, 21 Feb 2024 19:35:15 -0500 Subject: [PATCH 1/3] Rename (Nd)Item .ndim to .dimensions --- numba_dpex/experimental/typeof.py | 6 +++--- numba_dpex/kernel_api/index_space_ids.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/numba_dpex/experimental/typeof.py b/numba_dpex/experimental/typeof.py index 108e9d3b09..e72c951a0f 100644 --- a/numba_dpex/experimental/typeof.py +++ b/numba_dpex/experimental/typeof.py @@ -68,11 +68,11 @@ def typeof_item(val: Item, c): Returns: A numba_dpex.experimental.core.types.kernel_api.items.ItemType instance. """ - return ItemType(val.ndim) + return ItemType(val.dimensions) @typeof_impl.register(NdItem) -def typeof_nditem(val, c): +def typeof_nditem(val: NdItem, c): """Registers the type inference implementation function for a numba_dpex.kernel_api.NdItem PyObject. @@ -83,4 +83,4 @@ def typeof_nditem(val, c): Returns: A numba_dpex.experimental.core.types.kernel_api.items.NdItemType instance. """ - return NdItemType(val.ndim) + return NdItemType(val.dimensions) diff --git a/numba_dpex/kernel_api/index_space_ids.py b/numba_dpex/kernel_api/index_space_ids.py index 4e10bfc688..f1ccaa4cb4 100644 --- a/numba_dpex/kernel_api/index_space_ids.py +++ b/numba_dpex/kernel_api/index_space_ids.py @@ -147,7 +147,7 @@ def get_range(self, idx): return self._extent[idx] @property - def ndim(self) -> int: + def dimensions(self) -> int: """Returns the rank of a Item object. Returns: @@ -228,10 +228,10 @@ def get_group(self): return self._group @property - def ndim(self) -> int: + def dimensions(self) -> int: """Returns the rank of a NdItem object. Returns: int: Number of dimensions in the NdItem object """ - return self._global_item.ndim + return self._global_item.dimensions From bff5c5213a81364bb0d31cbddec2b6d6f7888293 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Fri, 23 Feb 2024 14:15:52 -0500 Subject: [PATCH 2/3] Add dimensions to group --- numba_dpex/kernel_api/index_space_ids.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/numba_dpex/kernel_api/index_space_ids.py b/numba_dpex/kernel_api/index_space_ids.py index f1ccaa4cb4..faa78ab4cd 100644 --- a/numba_dpex/kernel_api/index_space_ids.py +++ b/numba_dpex/kernel_api/index_space_ids.py @@ -98,6 +98,14 @@ def leader(self): """ return self._leader + @property + def dimensions(self) -> int: + """Returns the rank of a Group object. + Returns: + int: Number of dimensions in the Group object + """ + return self._global_range.ndim + @leader.setter def leader(self, work_item_id): """Sets the leader attribute for the group.""" From 554e52d52f0c5d804ebf33c946e10c6d87334724 Mon Sep 17 00:00:00 2001 From: Yevhenii Havrylko Date: Wed, 21 Feb 2024 19:44:45 -0500 Subject: [PATCH 3/3] Overload generic item's attribute 'dimensions' --- .../_index_space_id_overloads.py | 23 ++++++++++- .../experimental/test_index_space_ids.py | 41 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py index d574f6aced..a99781e53f 100644 --- a/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py +++ b/numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py @@ -9,7 +9,7 @@ import llvmlite.ir as llvmir from numba.core import cgutils, types from numba.core.errors import TypingError -from numba.extending import intrinsic, overload_method +from numba.extending import intrinsic, overload_attribute, overload_method from numba_dpex.core.types.kernel_api.index_space_ids import ( GroupType, @@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item): return _intrinsic_get_group(nd_item) return ol_nd_item_get_group_impl + + +@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME) +@overload_attribute( + NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME +) +def ol_nd_item_dimensions(item): + """ + SPIR-V overload for :meth:`numba_dpex.kernel_api..dimensions`. + + Generates the same LLVM IR instruction as dpcpp for the + `sycl::::dimensions` attribute. + """ + dimensions = item.ndim + + # pylint: disable=unused-argument + def ol_nd_item_get_group_impl(item): + return dimensions + + return ol_nd_item_get_group_impl diff --git a/numba_dpex/tests/experimental/test_index_space_ids.py b/numba_dpex/tests/experimental/test_index_space_ids.py index 2d1edb54f2..887ce6584e 100644 --- a/numba_dpex/tests/experimental/test_index_space_ids.py +++ b/numba_dpex/tests/experimental/test_index_space_ids.py @@ -63,6 +63,24 @@ def set_local_ones_nd_item(nd_item: NdItem, a): a[i] = 1 +@dpex_exp.kernel +def set_dimensions_item(item: Item, a): + i = item.get_id(0) + a[i] = item.dimensions + + +@dpex_exp.kernel +def set_dimensions_nd_item(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.dimensions + + +@dpex_exp.kernel +def set_dimensions_group(nd_item: NdItem, a): + i = nd_item.get_global_id(0) + a[i] = nd_item.get_group().dimensions + + def _get_group_id_driver(nditem: NdItem, a): i = nditem.get_global_id(0) g = nditem.get_group() @@ -149,6 +167,29 @@ def test_nd_item_get_local_id(): ) +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_item_dimensions(dims): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng = [1] * dims + rng[0] = a.size + dpex_exp.call_kernel(set_dimensions_item, dpex.Range(*rng), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize( + "kernel", [set_dimensions_nd_item, set_dimensions_group] +) +def test_nd_item_dimensions(dims, kernel): + a = dpnp.zeros(_SIZE, dtype=dpnp.float32) + rng, grp = [1] * dims, [1] * dims + rng[0], grp[0] = a.size, _GROUP_SIZE + dpex_exp.call_kernel(kernel, dpex.NdRange(rng, grp), a) + + assert np.array_equal(a.asnumpy(), dims * np.ones(a.size, dtype=np.float32)) + + def test_error_item_get_global_id(): a = dpnp.zeros(_SIZE, dtype=dpnp.float32)