File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change 88from numba import errors , types
99from numba .core import cgutils
1010from numba .core .typing import signature
11+ from numba .core .typing .npydecl import parse_dtype as ty_parse_dtype
1112from numba .core .typing .npydecl import parse_shape
1213from numba .extending import intrinsic , overload , overload_classmethod
1314from numba .np .arrayobj import (
1617 make_array ,
1718 populate_array ,
1819)
20+ from numba .np .numpy_support import is_nonelike
1921
2022from numba_dpex .core .runtime import context as dpexrt
2123from numba_dpex .core .types import DpnpNdArray
@@ -263,6 +265,15 @@ def ol_dpnp_empty(
263265 if not ndim :
264266 raise errors .TypingError ("Could not infer the rank of the ndarray" )
265267
268+ # If a dtype value was passed in, then try to convert it to the
269+ # coresponding Numba type. If None was passed, the default, then pass None
270+ # to the DpnpNdArray constructor. The default dtype will be derived based
271+ # on the behavior defined in dpctl.tensor.usm_ndarray.
272+ if not is_nonelike (dtype ):
273+ nb_dtype = ty_parse_dtype (dtype )
274+ else :
275+ nb_dtype = None
276+
266277 if usm_type is not None :
267278 usm_type = _parse_usm_type (usm_type )
268279 else :
@@ -275,7 +286,7 @@ def ol_dpnp_empty(
275286
276287 if ndim is not None :
277288 retty = DpnpNdArray (
278- dtype = dtype ,
289+ dtype = nb_dtype ,
279290 ndim = ndim ,
280291 usm_type = usm_type ,
281292 device = device ,
You can’t perform that action at this time.
0 commit comments