diff --git a/numba_dpex/dpnp_iface/_intrinsic.py b/numba_dpex/dpnp_iface/_intrinsic.py new file mode 100644 index 0000000000..5cb70ecb95 --- /dev/null +++ b/numba_dpex/dpnp_iface/_intrinsic.py @@ -0,0 +1,566 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +from llvmlite import ir +from llvmlite.ir import Constant +from numba import types +from numba.core import cgutils +from numba.core.typing import signature +from numba.extending import intrinsic +from numba.np.arrayobj import ( + _parse_empty_args, + _parse_empty_like_args, + get_itemsize, + make_array, + populate_array, +) + +from numba_dpex.core.runtime import context as dpexrt + +from ..decorators import dpjit + + +@dpjit +# TODO: rename this to _call_allocator and see below +def _call_usm_allocator(arrtype, size, usm_type, device): + """Trampoline to call the intrinsic used for allocation""" + + return arrtype._usm_allocate(size, usm_type, device) + + +def _empty_nd_impl(context, builder, arrtype, shapes): + """Utility function used for allocating a new array during LLVM + code generation (lowering). + + Given a target context, builder, array type, and a tuple or list + of lowered dimension sizes, returns a LLVM value pointing at a + Numba runtime allocated array. + + Args: + context (numba.core.base.BaseContext): One of the class derived + from numba's BaseContext, e.g. CPUContext + builder (llvmlite.ir.builder.IRBuilder): IR builder object from + llvmlite. + arrtype (numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray): + An array type info to construct the actual array. + shapes (list): The dimension of the array. + + Raises: + NotImplementedError: If the layout of the array is not known. + + Returns: + numba.np.arrayobj.make_array..ArrayStruct: The constructed + array. 
+ """ + + arycls = make_array(arrtype) + ary = arycls(context, builder) + + datatype = context.get_data_type(arrtype.dtype) + itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) + + # compute array length + arrlen = context.get_constant(types.intp, 1) + overflow = Constant(ir.IntType(1), 0) + for s in shapes: + arrlen_mult = builder.smul_with_overflow(arrlen, s) + arrlen = builder.extract_value(arrlen_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) + + if arrtype.ndim == 0: + strides = () + elif arrtype.layout == "C": + strides = [itemsize] + for dimension_size in reversed(shapes[1:]): + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(reversed(strides)) + elif arrtype.layout == "F": + strides = [itemsize] + for dimension_size in shapes[:-1]: + strides.append(builder.mul(strides[-1], dimension_size)) + strides = tuple(strides) + else: + raise NotImplementedError( + "Don't know how to allocate array with layout '{0}'.".format( + arrtype.layout + ) + ) + + # Check overflow, numpy also does this after checking order + allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) + allocsize = builder.extract_value(allocsize_mult, 0) + overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) + + with builder.if_then(overflow, likely=False): + # Raise same error as numpy, see: + # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "array is too big; `arr.size * arr.dtype.itemsize` is larger " + "than the maximum possible size.", + ), + ) + + usm_ty = arrtype.usm_type + usm_ty_val = 0 + if usm_ty == "device": + usm_ty_val = 1 + elif usm_ty == "shared": + usm_ty_val = 2 + elif usm_ty == "host": + usm_ty_val = 3 + usm_type = context.get_constant(types.uint64, usm_ty_val) + device = 
context.insert_const_string(builder.module, arrtype.device) + + args = ( + context.get_dummy_value(), + allocsize, + usm_type, + device, + ) + mip = types.MemInfoPointer(types.voidptr) + arytypeclass = types.TypeRef(type(arrtype)) + sig = signature( + mip, + arytypeclass, + types.intp, + types.uint64, + types.voidptr, + ) + + op = _call_usm_allocator + fnop = context.typing_context.resolve_value_type(op) + # The _call_usm_allocator function will be compiled and added to registry + # when the get_call_type function is invoked. + fnop.get_call_type(context.typing_context, sig.args, {}) + eqfn = context.get_function(fnop, sig) + meminfo = eqfn(builder, args) + data = context.nrt.meminfo_data(builder, meminfo) + intp_t = context.get_value_type(types.intp) + shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) + strides_array = cgutils.pack_array(builder, strides, ty=intp_t) + + populate_array( + ary, + data=builder.bitcast(data, datatype.as_pointer()), + shape=shape_array, + strides=strides_array, + itemsize=itemsize, + meminfo=meminfo, + ) + + return ary + + +def alloc_empty_arrayobj(context, builder, sig, llargs, is_like=False): + """Construct an empty numba.np.arrayobj.make_array..ArrayStruct + + Args: + context (numba.core.base.BaseContext): One of the class derived + from numba's BaseContext, e.g. CPUContext + builder (llvmlite.ir.builder.IRBuilder): IR builder object from + llvmlite. + sig (numba.core.typing.templates.Signature): A numba's function + signature object. + llargs (tuple): A tuple of args to be parsed as the arguments of + an np.empty(), np.zeros() or np.ones() call. + is_like (bool, optional): Decides on how to parse the args. + Defaults to False. + + Returns: + tuple(numba.np.arrayobj.make_array..ArrayStruct, + numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray): + A tuple of allocated array and constructed array type info + in DpnpNdArray. 
+ """ + + arrtype = ( + _parse_empty_like_args(context, builder, sig, llargs) + if is_like + else _parse_empty_args(context, builder, sig, llargs) + ) + ary = _empty_nd_impl(context, builder, *arrtype) + + return ary, arrtype + + +def fill_arrayobj(context, builder, sig, llargs, value, is_like=False): + """Fill a numba.np.arrayobj.make_array..ArrayStruct + with a specified value. + + Args: + context (numba.core.base.BaseContext): One of the class derived + from numba's BaseContext, e.g. CPUContext + builder (llvmlite.ir.builder.IRBuilder): IR builder object from + llvmlite. + sig (numba.core.typing.templates.Signature): A numba's function + signature object. + llargs (tuple): A tuple of args to be parsed as the arguments of + an np.empty(), np.zeros() or np.ones() call. + value (int): The value to be set. + is_like (bool, optional): Decides on how to parse the args. + Defaults to False. + + Returns: + tuple(numba.np.arrayobj.make_array..ArrayStruct, + numba_dpex.core.types.dpnp_ndarray_type.DpnpNdArray): + A tuple of allocated array and constructed array type info + in DpnpNdArray. 
+ """ + + ary, arrtype = alloc_empty_arrayobj(context, builder, sig, llargs, is_like) + itemsize = context.get_constant( + types.intp, get_itemsize(context, arrtype[0]) + ) + device = context.insert_const_string(builder.module, arrtype[0].device) + value = context.get_constant(types.int8, value) + if isinstance(arrtype[0].dtype, types.scalars.Float): + is_float = context.get_constant(types.boolean, 1) + else: + is_float = context.get_constant(types.boolean, 0) + dpexrtCtx = dpexrt.DpexRTContext(context) + dpexrtCtx.meminfo_fill( + builder, ary.meminfo, itemsize, is_float, value, device + ) + return ary, arrtype + + +@intrinsic +def intrin_usm_alloc(typingctx, allocsize, usm_type, device): + """Intrinsic to call into the allocator for Array""" + + def codegen(context, builder, signature, args): + [allocsize, usm_type, device] = args + dpexrtCtx = dpexrt.DpexRTContext(context) + meminfo = dpexrtCtx.meminfo_alloc(builder, allocsize, usm_type, device) + return meminfo + + mip = types.MemInfoPointer(types.voidptr) # return untyped pointer + sig = signature(mip, allocsize, usm_type, device) + return sig, codegen + + +@intrinsic +def impl_dpnp_empty( + ty_context, + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.empty(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_shape (numba.core.types.abstract): One of the numba defined + abstract types. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. 
+ ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. + """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = alloc_empty_arrayobj(context, builder, sig, llargs) + return ary._getvalue() + + return sig, codegen + + +@intrinsic +def impl_dpnp_zeros( + ty_context, + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.zeros(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_shape (numba.core.types.abstract): One of the numba defined + abstract types. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. 
+ """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = fill_arrayobj(context, builder, sig, llargs, 0) + return ary._getvalue() + + return sig, codegen + + +@intrinsic +def impl_dpnp_ones( + ty_context, + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.ones(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_shape (numba.core.types.abstract): One of the numba defined + abstract types. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. + """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_shape, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = fill_arrayobj(context, builder, sig, llargs, 1) + return ary._getvalue() + + return sig, codegen + + +@intrinsic +def impl_dpnp_empty_like( + ty_context, + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.empty_like(). 
+ + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. + """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = alloc_empty_arrayobj( + context, builder, sig, llargs, is_like=True + ) + return ary._getvalue() + + return sig, codegen + + +@intrinsic +def impl_dpnp_zeros_like( + ty_context, + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.zeros_like(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. 
+ ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. + """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = fill_arrayobj(context, builder, sig, llargs, 0, is_like=True) + return ary._getvalue() + + return sig, codegen + + +@intrinsic +def impl_dpnp_ones_like( + ty_context, + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, +): + """A numba "intrinsic" function to inject code for dpnp.ones_like(). + + Args: + ty_context (numba.core.typing.context.Context): The typing context + for the codegen. + ty_x (numba.core.types.npytypes.Array): Numba type class for ndarray. + ty_dtype (numba.core.types.functions.NumberClass): Type class for + number classes (e.g. "np.float64"). + ty_order (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_device (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_usm_type (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_sycl_queue (numba.core.types.misc.UnicodeType): UnicodeType + from numba for strings. + ty_retty_ref (numba.core.types.abstract.TypeRef): Reference to + a type from numba, used when a type is passed as a value. + + Returns: + tuple(numba.core.typing.templates.Signature, function): A tuple of + numba function signature type and a function object. 
+ """ + + ty_retty = ty_retty_ref.instance_type + sig = ty_retty( + ty_x, + ty_dtype, + ty_order, + ty_device, + ty_usm_type, + ty_sycl_queue, + ty_retty_ref, + ) + + def codegen(context, builder, sig, llargs): + ary, _ = fill_arrayobj(context, builder, sig, llargs, 1, is_like=True) + return ary._getvalue() + + return sig, codegen diff --git a/numba_dpex/dpnp_iface/arrayobj.py b/numba_dpex/dpnp_iface/arrayobj.py index 5aa019637b..e906176d6a 100644 --- a/numba_dpex/dpnp_iface/arrayobj.py +++ b/numba_dpex/dpnp_iface/arrayobj.py @@ -3,176 +3,179 @@ # SPDX-License-Identifier: Apache-2.0 import dpnp -from llvmlite import ir -from llvmlite.ir import Constant from numba import errors, types -from numba.core import cgutils -from numba.core.types.scalars import Float -from numba.core.typing import signature -from numba.core.typing.npydecl import parse_dtype as ty_parse_dtype -from numba.core.typing.npydecl import parse_shape -from numba.extending import intrinsic, overload, overload_classmethod -from numba.np.arrayobj import ( - _parse_empty_args, - get_itemsize, - make_array, - populate_array, -) +from numba.core.typing.npydecl import parse_dtype as _ty_parse_dtype +from numba.core.typing.npydecl import parse_shape as _ty_parse_shape +from numba.extending import overload, overload_classmethod from numba.np.numpy_support import is_nonelike -from numba_dpex.core.runtime import context as dpexrt from numba_dpex.core.types import DpnpNdArray -from ..decorators import dpjit +from ._intrinsic import ( + impl_dpnp_empty, + impl_dpnp_empty_like, + impl_dpnp_ones, + impl_dpnp_ones_like, + impl_dpnp_zeros, + impl_dpnp_zeros_like, + intrin_usm_alloc, +) -# ------------------------------------------------------------------------------ -# Helps to parse dpnp constructor arguments +# ========================================================================= +# Helps to parse dpnp constructor arguments +# ========================================================================= -def 
_parse_usm_type(usm_type): +def _parse_dtype(dtype, data=None): + """Resolve dtype parameter. + + Resolves the dtype parameter based on the given value + or the dtype of the given array. + + Args: + dtype (numba.core.types.functions.NumberClass): Numba type + class for number classes (e.g. "np.float64"). + data (numba.core.types.npytypes.Array, optional): Numba type + class for nd-arrays. Defaults to None. + + Returns: + numba.core.types.functions.NumberClass: Resolved numba type + class for number classes. """ - Returns the usm_type, if it is a string literal. + _dtype = None + if data and isinstance(data, types.Array): + _dtype = data.dtype + if not is_nonelike(dtype): + _dtype = _ty_parse_dtype(dtype) + return _dtype + + +def _parse_usm_type(usm_type): + """Parse usm_type parameter. + + Resolves the usm_type parameter based on the type + of the parameter. + + Args: + usm_type (str, numba.core.types.misc.StringLiteral): + The type class for the string to specify the usm_type. + + Raises: + errors.NumbaValueError: If an invalid usm_type is specified. + TypeError: If the parameter is neither a 'str' + nor a 'types.StringLiteral' + + Returns: + str: The stringized usm_type. """ - from numba.core.errors import TypingError if isinstance(usm_type, types.StringLiteral): usm_type_str = usm_type.literal_value if usm_type_str not in ["shared", "device", "host"]: msg = f"Invalid usm_type specified: '{usm_type_str}'" - raise TypingError(msg) + raise errors.NumbaValueError(msg) return usm_type_str + elif isinstance(usm_type, str): + return usm_type else: - raise TypeError + raise TypeError( + "The parameter 'usm_type' is neither of " + + "'str' nor 'types.StringLiteral'" + ) def _parse_device_filter_string(device): + """Parse the device type parameter. + + Returns the device filter string, + if it is a string literal. + + Args: + device (str, numba.core.types.misc.StringLiteral): + The type class for the string to specify the device. 
+ + Raises: + TypeError: If the parameter is neither a 'str' + nor a 'types.StringLiteral' + + Returns: + str: The stringized device. """ - Returns the device filter string, if it is a string literal. - """ - from numba.core.errors import TypingError if isinstance(device, types.StringLiteral): device_filter_str = device.literal_value return device_filter_str + elif isinstance(device, str): + return device else: - raise TypeError + raise TypeError( + "The parameter 'device' is neither of " + + "'str' nor 'types.StringLiteral'" + ) -# ------------------------------------------------------------------------------ -# Helper functions to support dpnp array constructors +def build_dpnp_ndarray( + ndim, + layout="C", + dtype=None, + usm_type="device", + device="unknown", + queue=None, +): + """Constructs `DpnpNdArray` from the parameters provided. -# FIXME: The _empty_nd_impl was copied over *as it is* from numba.np.arrayobj. -# However, we cannot use it yet as the `_call_allocator` function needs to be -# tailored to our needs. Specifically, we need to pass the device string so that -# a correct type of external allocator may be created for the NRT_MemInfo -# object. + Args: + ndim (int): The dimension of the array. + layout ("C" or "F"): memory layout for the array. Default: "C" + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + usm_type (numba.core.types.misc.StringLiteral, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + device (optional): array API concept of device where the + output array is created. 
`device` can be `None`, a oneAPI + filter selector string, an instance of :class:`dpctl.SyclDevice` + corresponding to a non-partitioned SYCL device, an instance of + :class:`dpctl.SyclQueue`, or a `Device` object returned by + `dpctl.tensor.usm_array.device`. Default: `"unknown"`. + queue (:class:`dpctl.SyclQueue`, optional): Not supported. + Default: `None`. + Raises: + errors.TypingError: If `sycl_queue` is provided for some reason. -def _empty_nd_impl(context, builder, arrtype, shapes): - """Utility function used for allocating a new array during LLVM code - generation (lowering). Given a target context, builder, array - type, and a tuple or list of lowered dimension sizes, returns a - LLVM value pointing at a Numba runtime allocated array. + Returns: + DpnpNdArray: The Numba type to represent a dpnp.ndarray. + The type has the same structure as USMNdArray used to + represent dpctl.tensor.usm_ndarray. """ - arycls = make_array(arrtype) - ary = arycls(context, builder) - - datatype = context.get_data_type(arrtype.dtype) - itemsize = context.get_constant(types.intp, get_itemsize(context, arrtype)) - - # compute array length - arrlen = context.get_constant(types.intp, 1) - overflow = Constant(ir.IntType(1), 0) - for s in shapes: - arrlen_mult = builder.smul_with_overflow(arrlen, s) - arrlen = builder.extract_value(arrlen_mult, 0) - overflow = builder.or_(overflow, builder.extract_value(arrlen_mult, 1)) - - if arrtype.ndim == 0: - strides = () - elif arrtype.layout == "C": - strides = [itemsize] - for dimension_size in reversed(shapes[1:]): - strides.append(builder.mul(strides[-1], dimension_size)) - strides = tuple(reversed(strides)) - elif arrtype.layout == "F": - strides = [itemsize] - for dimension_size in shapes[:-1]: - strides.append(builder.mul(strides[-1], dimension_size)) - strides = tuple(strides) - else: - raise NotImplementedError( - "Don't know how to allocate array with layout '{0}'.".format( - arrtype.layout - ) - ) - # Check overflow, numpy also 
does this after checking order - allocsize_mult = builder.smul_with_overflow(arrlen, itemsize) - allocsize = builder.extract_value(allocsize_mult, 0) - overflow = builder.or_(overflow, builder.extract_value(allocsize_mult, 1)) - - with builder.if_then(overflow, likely=False): - # Raise same error as numpy, see: - # https://github.com/numpy/numpy/blob/2a488fe76a0f732dc418d03b452caace161673da/numpy/core/src/multiarray/ctors.c#L1095-L1101 # noqa: E501 - context.call_conv.return_user_exc( - builder, - ValueError, - ( - "array is too big; `arr.size * arr.dtype.itemsize` is larger " - "than the maximum possible size.", - ), + if queue: + raise errors.TypingError( + "The sycl_queue keyword is not yet supported by " + "dpnp.empty(), dpnp.zeros(), dpnp.ones(), dpnp.empty_like(), " + "dpnp.zeros_like() and dpnp.ones_like() inside " + "a dpjit decorated function." ) - usm_ty = arrtype.usm_type - usm_ty_val = 0 - if usm_ty == "device": - usm_ty_val = 1 - elif usm_ty == "shared": - usm_ty_val = 2 - elif usm_ty == "host": - usm_ty_val = 3 - usm_type = context.get_constant(types.uint64, usm_ty_val) - device = context.insert_const_string(builder.module, arrtype.device) - - args = ( - context.get_dummy_value(), - allocsize, - usm_type, - device, - ) - mip = types.MemInfoPointer(types.voidptr) - arytypeclass = types.TypeRef(type(arrtype)) - sig = signature( - mip, - arytypeclass, - types.intp, - types.uint64, - types.voidptr, - ) + # If a dtype value was passed in, then try to convert it to the + # corresponding Numba type. If None was passed, the default, then pass None + # to the DpnpNdArray constructor. The default dtype will be derived based + # on the behavior defined in dpctl.tensor.usm_ndarray. - op = _call_usm_allocator - fnop = context.typing_context.resolve_value_type(op) - # The _call_usm_allocator function will be compiled and added to registry - # when the get_call_type function is invoked. 
- fnop.get_call_type(context.typing_context, sig.args, {}) - eqfn = context.get_function(fnop, sig) - meminfo = eqfn(builder, args) - data = context.nrt.meminfo_data(builder, meminfo) - intp_t = context.get_value_type(types.intp) - shape_array = cgutils.pack_array(builder, shapes, ty=intp_t) - strides_array = cgutils.pack_array(builder, strides, ty=intp_t) - - populate_array( - ary, - data=builder.bitcast(data, datatype.as_pointer()), - shape=shape_array, - strides=strides_array, - itemsize=itemsize, - meminfo=meminfo, + ret_ty = DpnpNdArray( + ndim=ndim, layout=layout, dtype=dtype, usm_type=usm_type, device=device ) - return ary + return ret_ty + + +# ========================================================================= +# Dpnp array constructor overloads +# ========================================================================= @overload_classmethod(DpnpNdArray, "_usm_allocate") @@ -185,277 +188,528 @@ def impl(cls, allocsize, usm_type, device): return impl -@dpjit -def _call_usm_allocator(arrtype, size, usm_type, device): - """Trampoline to call the intrinsic used for allocation""" - return arrtype._usm_allocate(size, usm_type, device) - - -@intrinsic -def intrin_usm_alloc(typingctx, allocsize, usm_type, device): - """Intrinsic to call into the allocator for Array""" - - def codegen(context, builder, signature, args): - [allocsize, usm_type, device] = args - dpexrtCtx = dpexrt.DpexRTContext(context) - meminfo = dpexrtCtx.meminfo_alloc(builder, allocsize, usm_type, device) - return meminfo - - mip = types.MemInfoPointer(types.voidptr) # return untyped pointer - sig = signature(mip, allocsize, usm_type, device) - return sig, codegen - - -@intrinsic -def impl_dpnp_empty( - tyctx, - ty_shape, - ty_dtype, - ty_usm_type, - ty_device, - ty_retty_ref, +@overload(dpnp.empty, prefer_literal=True) +def ol_dpnp_empty( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): - ty_retty = ty_retty_ref.instance_type - - sig = 
ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) + """Implementation of an overload to support dpnp.empty() inside + a jit function. - def codegen(context, builder, sig, llargs): - arrtype = _parse_empty_args(context, builder, sig, llargs) - ary = _empty_nd_impl(context, builder, *arrtype) - return ary._getvalue() + Args: + shape (tuple): Dimensions of the array to be created. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". + Default: "C" + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returned by `dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - return sig, codegen + Raises: + errors.TypingError: If rank of the ndarray couldn't be inferred. + errors.TypingError: If couldn't parse input types to dpnp.empty(). 
+ Returns: + function: Local function `impl_dpnp_empty()` + """ -def aryobj_fill(context, builder, sig, llargs, value): - arrtype = _parse_empty_args(context, builder, sig, llargs) - ary = _empty_nd_impl(context, builder, *arrtype) - itemsize = context.get_constant( - types.intp, get_itemsize(context, arrtype[0]) + _ndim = _ty_parse_shape(shape) + _dtype = _parse_dtype(dtype) + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" ) - device = context.insert_const_string(builder.module, arrtype[0].device) - value = context.get_constant(types.int8, value) - if isinstance(arrtype[0].dtype, Float): - is_float = context.get_constant(types.boolean, 1) + if _ndim: + ret_ty = build_dpnp_ndarray( + _ndim, + layout=order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: + + def impl( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, + ): + return impl_dpnp_empty( + shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + ) + + return impl + else: + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.empty({shape}, {dtype}, ...)." 
+ ) else: - is_float = context.get_constant(types.boolean, 0) - dpexrtCtx = dpexrt.DpexRTContext(context) - dpexrtCtx.meminfo_fill( - builder, ary.meminfo, itemsize, is_float, value, device - ) - return ary._getvalue() + raise errors.TypingError("Could not infer the rank of the ndarray.") -@intrinsic -def impl_dpnp_zeros( - tyctx, ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref -): - ty_retty = ty_retty_ref.instance_type - - sig = ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) - - def codegen(context, builder, sig, llargs): - return aryobj_fill(context, builder, sig, llargs, 0) - - return sig, codegen - - -@intrinsic -def impl_dpnp_ones( - tyctx, ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref +@overload(dpnp.zeros, prefer_literal=True) +def ol_dpnp_zeros( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): - ty_retty = ty_retty_ref.instance_type - - sig = ty_retty(ty_shape, ty_dtype, ty_usm_type, ty_device, ty_retty_ref) + """Implementation of an overload to support dpnp.zeros() inside + a jit function. - def codegen(context, builder, sig, llargs): - return aryobj_fill(context, builder, sig, llargs, 1) + Args: + shape (tuple): Dimensions of the array to be created. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". + Default: "C" + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returned by `dpctl.tensor.usm_array.device`. + Default: `None`. 
+ usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - return sig, codegen + Raises: + errors.TypingError: If rank of the ndarray couldn't be inferred. + errors.TypingError: If couldn't parse input types to dpnp.zeros(). + Returns: + function: Local function `impl_dpnp_zeros()` + """ -# ------------------------------------------------------------------------------ -# Dpnp array constructor overloads + _ndim = _ty_parse_shape(shape) + _dtype = _parse_dtype(dtype) + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + if _ndim: + ret_ty = build_dpnp_ndarray( + _ndim, + layout=order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: + + def impl( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, + ): + return impl_dpnp_zeros( + shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + ) + + return impl + else: + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.zeros({shape}, {dtype}, ...)." + ) + else: + raise errors.TypingError("Could not infer the rank of the ndarray.") -@overload(dpnp.empty, prefer_literal=True) -def ol_dpnp_empty( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None +@overload(dpnp.ones, prefer_literal=True) +def ol_dpnp_ones( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, ): - """Implementation of an overload to support dpnp.empty inside a jit - function. + """Implementation of an overload to support dpnp.ones() inside + a jit function. Args: shape (tuple): Dimensions of the array to be created. - dtype optional): Data type of the array. 
Can be typestring, - a `numpy.dtype` object, `numpy` char string, or a numpy - scalar type. Default: None - usm_type ("device"|"shared"|"host", optional): The type of SYCL USM - allocation for the output array. Default: `"device"`. - device (optional): array API concept of device where the output array - is created. `device` can be `None`, a oneAPI filter selector string, - an instance of :class:`dpctl.SyclDevice` corresponding to a - non-partitioned SYCL device, an instance of - :class:`dpctl.SyclQueue`, or a `Device` object returnedby - `dpctl.tensor.usm_array.device`. Default: `None`. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". + Default: "C" + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returnedby`dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - Returns: Numba implementation of the dpnp.empty + Raises: + errors.TypingError: If rank of the ndarray couldn't be inferred. + errors.TypingError: If couldn't parse input types to dpnp.ones(). + + Returns: + function: Local function `impl_dpnp_ones()` """ - if sycl_queue: - raise errors.TypingError( - "The sycl_queue keyword is not yet supported by dpnp.empty inside " - "a dpjit decorated function." 
+ _ndim = _ty_parse_shape(shape) + _dtype = _parse_dtype(dtype) + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + if _ndim: + ret_ty = build_dpnp_ndarray( + _ndim, + layout=order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, ) + if ret_ty: + + def impl( + shape, + dtype=None, + order="C", + device=None, + usm_type="device", + sycl_queue=None, + ): + return impl_dpnp_ones( + shape, _dtype, order, _device, _usm_type, sycl_queue, ret_ty + ) + + return impl + else: + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.ones({shape}, {dtype}, ...)." + ) + else: + raise errors.TypingError("Could not infer the rank of the ndarray.") + + +@overload(dpnp.empty_like, prefer_literal=True) +def ol_dpnp_empty_like( + x, + dtype=None, + order="C", + shape=None, + device=None, + usm_type=None, + sycl_queue=None, +): + """Creates `usm_ndarray` from uninitialized USM allocation. - ndim = parse_shape(shape) - if not ndim: - raise errors.TypingError("Could not infer the rank of the ndarray") + This is an overloaded function implementation for dpnp.empty_like(). - # If a dtype value was passed in, then try to convert it to the - # coresponding Numba type. If None was passed, the default, then pass None - # to the DpnpNdArray constructor. The default dtype will be derived based - # on the behavior defined in dpctl.tensor.usm_ndarray. - if not is_nonelike(dtype): - nb_dtype = ty_parse_dtype(dtype) - else: - nb_dtype = None + Args: + x (numba.core.types.npytypes.Array): Input array from which to + derive the output array shape. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". 
+ Default: "C" + shape (numba.core.types.containers.UniTuple, optional): The shape + to override the shape of the given array. Not supported. + Default: `None` + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returnedby`dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - if usm_type is not None: - usm_type = _parse_usm_type(usm_type) - else: - usm_type = "device" + Raises: + errors.TypingError: If couldn't parse input types to dpnp.empty_like(). + errors.TypingError: If shape is provided. - if device is not None: - device = _parse_device_filter_string(device) - else: - device = "unknown" - - if ndim is not None: - retty = DpnpNdArray( - dtype=nb_dtype, - ndim=ndim, - usm_type=usm_type, - device=device, + Returns: + function: Local function `impl_dpnp_empty_like()` + """ + + if shape: + raise errors.TypingError( + "The parameter shape is not supported " + + "inside overloaded dpnp.empty_like() function." 
) + _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x) + _order = x.layout if order is None else order + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + ret_ty = build_dpnp_ndarray( + _ndim, + layout=_order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: def impl( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None + x, + dtype=None, + order="C", + shape=None, + device=None, + usm_type=None, + sycl_queue=None, ): - return impl_dpnp_empty(shape, dtype, usm_type, device, retty) + return impl_dpnp_empty_like( + x, + _dtype, + _order, + _device, + _usm_type, + sycl_queue, + ret_ty, + ) return impl else: - msg = ( - f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.empty_like({x}, {dtype}, ...)." ) - raise errors.TypingError(msg) -@overload(dpnp.zeros, prefer_literal=True) -def ol_dpnp_zeros( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None +@overload(dpnp.zeros_like, prefer_literal=True) +def ol_dpnp_zeros_like( + x, + dtype=None, + order="C", + shape=None, + device=None, + usm_type=None, + sycl_queue=None, ): - if sycl_queue: - raise errors.TypingError( - "The sycl_queue keyword is not yet supported by dpnp.empty inside " - "a dpjit decorated function." - ) + """Creates `usm_ndarray` from USM allocation initialized with zeros. - ndim = parse_shape(shape) - if not ndim: - raise errors.TypingError("Could not infer the rank of the ndarray") + This is an overloaded function implementation for dpnp.zeros_like(). - # If a dtype value was passed in, then try to convert it to the - # coresponding Numba type. If None was passed, the default, then pass None - # to the DpnpNdArray constructor. 
The default dtype will be derived based - # on the behavior defined in dpctl.tensor.usm_ndarray. - if not is_nonelike(dtype): - nb_dtype = ty_parse_dtype(dtype) - else: - nb_dtype = None + Args: + x (numba.core.types.npytypes.Array): Input array from which to + derive the output array shape. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". + Default: "C" + shape (numba.core.types.containers.UniTuple, optional): The shape + to override the shape of the given array. Not supported. + Default: `None` + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returnedby`dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - if usm_type is not None: - usm_type = _parse_usm_type(usm_type) - else: - usm_type = "device" + Raises: + errors.TypingError: If couldn't parse input types to dpnp.zeros_like(). + errors.TypingError: If shape is provided. 
- if device is not None: - device = _parse_device_filter_string(device) - else: - device = "unknown" - - if ndim is not None: - retty = DpnpNdArray( - dtype=nb_dtype, - ndim=ndim, - usm_type=usm_type, - device=device, + Returns: + function: Local function `impl_dpnp_zeros_like()` + """ + + if shape: + raise errors.TypingError( + "The parameter shape is not supported " + + "inside overloaded dpnp.zeros_like() function." ) + _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x) + _order = x.layout if order is None else order + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + ret_ty = build_dpnp_ndarray( + _ndim, + layout=_order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: def impl( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None + x, + dtype=None, + order="C", + shape=None, + device=None, + usm_type=None, + sycl_queue=None, ): - return impl_dpnp_zeros(shape, dtype, usm_type, device, retty) + return impl_dpnp_zeros_like( + x, + _dtype, + _order, + _device, + _usm_type, + sycl_queue, + ret_ty, + ) return impl else: - msg = ( - f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.empty_like({x}, {dtype}, ...)." ) - raise errors.TypingError(msg) -@overload(dpnp.ones, prefer_literal=True) -def ol_dpnp_ones( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None +@overload(dpnp.ones_like, prefer_literal=True) +def ol_dpnp_ones_like( + x, + dtype=None, + order="C", + shape=None, + device=None, + usm_type=None, + sycl_queue=None, ): - if sycl_queue: - raise errors.TypingError( - "The sycl_queue keyword is not yet supported by dpnp.empty inside " - "a dpjit decorated function." 
- ) + """Creates `usm_ndarray` from USM allocation initialized with ones. - ndim = parse_shape(shape) - if not ndim: - raise errors.TypingError("Could not infer the rank of the ndarray") + This is an overloaded function implementation for dpnp.ones_like(). - # If a dtype value was passed in, then try to convert it to the - # coresponding Numba type. If None was passed, the default, then pass None - # to the DpnpNdArray constructor. The default dtype will be derived based - # on the behavior defined in dpctl.tensor.usm_ndarray. - if not is_nonelike(dtype): - nb_dtype = ty_parse_dtype(dtype) - else: - nb_dtype = None + Args: + x (numba.core.types.npytypes.Array): Input array from which to + derive the output array shape. + dtype (numba.core.types.functions.NumberClass, optional): + Data type of the array. Can be typestring, a `numpy.dtype` + object, `numpy` char string, or a numpy scalar type. + Default: None + order (str, optional): memory layout for the array "C" or "F". + Default: "C" + shape (numba.core.types.containers.UniTuple, optional): The shape + to override the shape of the given array. Not supported. + Default: `None` + device (numba.core.types.misc.StringLiteral, optional): array API + concept of device where the output array is created. `device` + can be `None`, a oneAPI filter selector string, an instance of + :class:`dpctl.SyclDevice` corresponding to a non-partitioned + SYCL device, an instance of :class:`dpctl.SyclQueue`, or a + `Device` object returnedby`dpctl.tensor.usm_array.device`. + Default: `None`. + usm_type (numba.core.types.misc.StringLiteral or str, optional): + The type of SYCL USM allocation for the output array. + Allowed values are "device"|"shared"|"host". + Default: `"device"`. + sycl_queue (:class:`dpctl.SyclQueue`, optional): Not supported. - if usm_type is not None: - usm_type = _parse_usm_type(usm_type) - else: - usm_type = "device" + Raises: + errors.TypingError: If couldn't parse input types to dpnp.ones_like(). 
+ errors.TypingError: If shape is provided. - if device is not None: - device = _parse_device_filter_string(device) - else: - device = "unknown" - - if ndim is not None: - retty = DpnpNdArray( - dtype=nb_dtype, - ndim=ndim, - usm_type=usm_type, - device=device, + Returns: + function: Local function `impl_dpnp_ones_like()` + """ + + if shape: + raise errors.TypingError( + "The parameter shape is not supported " + + "inside overloaded dpnp.ones_like() function." ) + _ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0 + _dtype = _parse_dtype(dtype, data=x) + _order = x.layout if order is None else order + _usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device" + _device = ( + _parse_device_filter_string(device) if device is not None else "unknown" + ) + ret_ty = build_dpnp_ndarray( + _ndim, + layout=_order, + dtype=_dtype, + usm_type=_usm_type, + device=_device, + queue=sycl_queue, + ) + if ret_ty: def impl( - shape, dtype=None, usm_type=None, device=None, sycl_queue=None + x, + dtype=None, + order="C", + device=None, + usm_type=None, + sycl_queue=None, ): - return impl_dpnp_ones(shape, dtype, usm_type, device, retty) + return impl_dpnp_ones_like( + x, + _dtype, + _order, + _device, + _usm_type, + sycl_queue, + ret_ty, + ) return impl else: - msg = ( - f"Cannot parse input types to function dpnp.empty({shape}, {dtype})" + raise errors.TypingError( + "Cannot parse input types to " + + f"function dpnp.empty_like({x}, {dtype}, ...)." 
) - raise errors.TypingError(msg) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_like.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_like.py new file mode 100644 index 0000000000..91099bbf73 --- /dev/null +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_like.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for dpnp ndarray constructors.""" + + +import dpctl +import dpnp +import numpy +import pytest +from numba import errors + +from numba_dpex import dpjit + +shapes = [10, (2, 5)] +dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64] +usm_types = ["device", "shared", "host"] +devices = ["cpu", "unknown"] + + +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("usm_type", usm_types) +@pytest.mark.parametrize("device", devices) +def test_dpnp_empty_like(shape, dtype, usm_type, device): + @dpjit + def func(a): + c = dpnp.empty_like(a, dtype=dtype, usm_type=usm_type, device=device) + return c + + if isinstance(shape, int): + NZ = numpy.random.rand(shape) + else: + NZ = numpy.random.rand(*shape) + + try: + c = func(NZ) + except Exception: + pytest.fail("Calling dpnp.empty_like inside dpjit failed") + + if len(c.shape) == 1: + assert c.shape[0] == NZ.shape[0] + else: + assert c.shape == NZ.shape + + assert c.dtype == dtype + assert c.usm_type == usm_type + if device != "unknown": + assert ( + c.sycl_device.filter_string + == dpctl.SyclDevice(device).filter_string + ) + else: + c.sycl_device.filter_string == dpctl.SyclDevice().filter_string + + +def test_dpnp_empty_like_exceptions(): + @dpjit + def func1(a): + c = dpnp.empty_like(a, shape=(3, 3)) + return c + + try: + func1(numpy.random.rand(5, 5)) + except Exception as e: + assert isinstance(e, errors.TypingError) + assert ( + "No implementation of function Function(