Skip to content

Commit 70e93bc

Browse files
author
Diptorup Deb
committed
Convert non-None dtypes to Numba dtype in side dpnp.empty.
1 parent 0c07bf7 commit 70e93bc

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numba import errors, types
99
from numba.core import cgutils
1010
from numba.core.typing import signature
11+
from numba.core.typing.npydecl import parse_dtype as ty_parse_dtype
1112
from numba.core.typing.npydecl import parse_shape
1213
from numba.extending import intrinsic, overload, overload_classmethod
1314
from numba.np.arrayobj import (
@@ -16,6 +17,7 @@
1617
make_array,
1718
populate_array,
1819
)
20+
from numba.np.numpy_support import is_nonelike
1921

2022
from numba_dpex.core.runtime import context as dpexrt
2123
from numba_dpex.core.types import DpnpNdArray
@@ -263,6 +265,15 @@ def ol_dpnp_empty(
263265
if not ndim:
264266
raise errors.TypingError("Could not infer the rank of the ndarray")
265267

268+
# If a dtype value was passed in, then try to convert it to the
269+
# coresponding Numba type. If None was passed, the default, then pass None
270+
# to the DpnpNdArray constructor. The default dtype will be derived based
271+
# on the behavior defined in dpctl.tensor.usm_ndarray.
272+
if not is_nonelike(dtype):
273+
nb_dtype = ty_parse_dtype(dtype)
274+
else:
275+
nb_dtype = None
276+
266277
if usm_type is not None:
267278
usm_type = _parse_usm_type(usm_type)
268279
else:
@@ -275,7 +286,7 @@ def ol_dpnp_empty(
275286

276287
if ndim is not None:
277288
retty = DpnpNdArray(
278-
dtype=dtype,
289+
dtype=nb_dtype,
279290
ndim=ndim,
280291
usm_type=usm_type,
281292
device=device,

0 commit comments

Comments
 (0)