|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 |
|
| 6 | +from numba.core import cgutils |
| 7 | +from numba.core.errors import NumbaNotImplementedError |
| 8 | +from numba.core.pythonapi import NativeValue, PythonAPI, box, unbox |
| 9 | +from numba.np import numpy_support |
| 10 | + |
| 11 | +from numba_dpex.core.exceptions import UnreachableError |
| 12 | +from numba_dpex.core.runtime import context as dpexrt |
| 13 | + |
6 | 14 | from .usm_ndarray_type import USMNdArray |
7 | 15 |
|
8 | 16 |
|
9 | 17 | class DpnpNdArray(USMNdArray): |
10 | 18 | """ |
11 | 19 | The Numba type to represent an dpnp.ndarray. The type has the same |
12 | | - structure as USMNdArray used to represnet dpctl.tensor.usm_ndarray. |
| 20 | + structure as USMNdArray used to represent dpctl.tensor.usm_ndarray. |
| 21 | + """ |
| 22 | + |
| 23 | + @property |
| 24 | + def is_internal(self): |
| 25 | + return True |
| 26 | + |
| 27 | + |
| 28 | +# --------------- Boxing/Unboxing logic for dpnp.ndarray ----------------------# |
| 29 | + |
| 30 | + |
| 31 | +@unbox(DpnpNdArray) |
| 32 | +def unbox_dpnp_nd_array(typ, obj, c): |
| 33 | + """Converts a dpnp.ndarray object to a Numba internal array structure. |
| 34 | +
|
| 35 | + Args: |
| 36 | + typ : The Numba type of the PyObject |
| 37 | + obj : The actual PyObject to be unboxed |
| 38 | + c : |
| 39 | +
|
| 40 | + Returns: |
| 41 | + _type_: _description_ |
13 | 42 | """ |
| 43 | + # Reusing the numba.core.base.BaseContext's make_array function to get a |
| 44 | + # struct allocated. The same struct is used for numpy.ndarray |
| 45 | + # and dpnp.ndarray. It is possible to do so, as the extra information |
| 46 | + # specific to dpnp.ndarray such as sycl_queue is inferred statically and |
| 47 | + # stored as part of the DpnpNdArray type. |
| 48 | + |
| 49 | + # --------------- Original Numba comment from @ubox(types.Array) |
| 50 | + # |
| 51 | + # This is necessary because unbox_buffer() does not work on some |
| 52 | + # dtypes, e.g. datetime64 and timedelta64. |
| 53 | + # TODO check matching dtype. |
| 54 | + # currently, mismatching dtype will still work and causes |
| 55 | + # potential memory corruption |
| 56 | + # |
| 57 | + # --------------- End of Numba comment from @ubox(types.Array) |
| 58 | + nativearycls = c.context.make_array(typ) |
| 59 | + nativeary = nativearycls(c.context, c.builder) |
| 60 | + aryptr = nativeary._getpointer() |
| 61 | + |
| 62 | + ptr = c.builder.bitcast(aryptr, c.pyapi.voidptr) |
| 63 | + # FIXME : We need to check if Numba_RT as well as DPEX RT are enabled. |
| 64 | + if c.context.enable_nrt: |
| 65 | + dpexrtCtx = dpexrt.DpexRTContext(c.context) |
| 66 | + errcode = dpexrtCtx.arraystruct_from_python(c.pyapi, obj, ptr) |
| 67 | + else: |
| 68 | + raise UnreachableError |
| 69 | + |
| 70 | + # TODO: here we have minimal typechecking by the itemsize. |
| 71 | + # need to do better |
| 72 | + try: |
| 73 | + expected_itemsize = numpy_support.as_dtype(typ.dtype).itemsize |
| 74 | + except NumbaNotImplementedError: |
| 75 | + # Don't check types that can't be `as_dtype()`-ed |
| 76 | + itemsize_mismatch = cgutils.false_bit |
| 77 | + else: |
| 78 | + expected_itemsize = nativeary.itemsize.type(expected_itemsize) |
| 79 | + itemsize_mismatch = c.builder.icmp_unsigned( |
| 80 | + "!=", |
| 81 | + nativeary.itemsize, |
| 82 | + expected_itemsize, |
| 83 | + ) |
| 84 | + |
| 85 | + failed = c.builder.or_( |
| 86 | + cgutils.is_not_null(c.builder, errcode), |
| 87 | + itemsize_mismatch, |
| 88 | + ) |
| 89 | + # Handle error |
| 90 | + with c.builder.if_then(failed, likely=False): |
| 91 | + c.pyapi.err_set_string( |
| 92 | + "PyExc_TypeError", |
| 93 | + "can't unbox array from PyObject into " |
| 94 | + "native value. The object maybe of a " |
| 95 | + "different type", |
| 96 | + ) |
| 97 | + return NativeValue(c.builder.load(aryptr), is_error=failed) |
| 98 | + |
| 99 | + |
| 100 | +@box(DpnpNdArray) |
| 101 | +def box_array(typ, val, c): |
| 102 | + if c.context.enable_nrt: |
| 103 | + np_dtype = numpy_support.as_dtype(typ.dtype) |
| 104 | + dtypeptr = c.env_manager.read_const(c.env_manager.add_const(np_dtype)) |
| 105 | + dpexrtCtx = dpexrt.DpexRTContext(c.context) |
| 106 | + newary = dpexrtCtx.usm_ndarray_to_python_acqref( |
| 107 | + c.pyapi, typ, val, dtypeptr |
| 108 | + ) |
| 109 | + |
| 110 | + if not newary: |
| 111 | + c.pyapi.err_set_string( |
| 112 | + "PyExc_TypeError", |
| 113 | + "could not box native array into a dpnp.ndarray PyObject.", |
| 114 | + ) |
| 115 | + |
| 116 | + # Steals NRT ref |
| 117 | + # Refer: |
| 118 | + # numba.core.base.nrt -> numba.core.runtime.context -> decref |
| 119 | + # The `NRT_decref` function is generated directly as LLVM IR inside |
| 120 | + # numba.core.runtime.nrtdynmod.py |
| 121 | + c.context.nrt.decref(c.builder, typ, val) |
14 | 122 |
|
15 | | - pass |
| 123 | + return newary |
| 124 | + else: |
| 125 | + raise UnreachableError |
0 commit comments