Skip to content

Commit da1775c

Browse files
committed
Overload generic item's attribute 'dimensions'
1 parent aa319c3 commit da1775c

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_index_space_id_overloads.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import llvmlite.ir as llvmir
1010
from numba.core import cgutils, types
1111
from numba.core.errors import TypingError
12-
from numba.extending import intrinsic, overload_method
12+
from numba.extending import intrinsic, overload_attribute, overload_method
1313

1414
from numba_dpex.core.types.kernel_api.index_space_ids import (
1515
GroupType,
@@ -248,3 +248,24 @@ def ol_nd_item_get_group_impl(nd_item):
248248
return _intrinsic_get_group(nd_item)
249249

250250
return ol_nd_item_get_group_impl
251+
252+
253+
@overload_attribute(GroupType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
254+
@overload_attribute(ItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME)
255+
@overload_attribute(
256+
NdItemType, "dimensions", target=DPEX_KERNEL_EXP_TARGET_NAME
257+
)
258+
def ol_nd_item_dimensions(item):
259+
"""
260+
SPIR-V overload for :meth:`numba_dpex.kernel_api.<generic_item>.dimensions`.
261+
262+
Generates the same LLVM IR instruction as dpcpp for the
263+
`sycl::<generic_item>::dimensions` attribute.
264+
"""
265+
dimensions = item.ndim
266+
267+
# pylint: disable=unused-argument
268+
def ol_nd_item_get_group_impl(item):
269+
return dimensions
270+
271+
return ol_nd_item_get_group_impl

0 commit comments

Comments
 (0)