Skip to content

Commit 378abc9

Browse files
author
Diptorup Deb
committed
Overload implementation for dpnp.empty.
- Removes the numba_dpex/core/dpnp_ndarray modules - Adds an arrayobj submodule to numba_dpex/dpnp_iface - Implements boxing and unboxing for dpnp.ndarrays using _dpexrt_python extension. - Adds an overload for dpnp.empty to dono_iface/arrayobj.
1 parent 8a7b496 commit 378abc9

File tree

6 files changed

+409
-220
lines changed

6 files changed

+409
-220
lines changed

numba_dpex/core/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44

55

66
from .datamodel import *
7-
from .dpnp_ndarray import dpnp_empty
87
from .types import *
98
from .typing import *

numba_dpex/core/dpnp_ndarray/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

numba_dpex/core/dpnp_ndarray/dpnp_empty.py

Lines changed: 0 additions & 214 deletions
This file was deleted.

numba_dpex/core/types/dpnp_ndarray_type.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,123 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55

6+
from numba.core import cgutils
7+
from numba.core.errors import NumbaNotImplementedError
8+
from numba.core.pythonapi import NativeValue, PythonAPI, box, unbox
9+
from numba.np import numpy_support
10+
11+
from numba_dpex.core.exceptions import UnreachableError
12+
from numba_dpex.core.runtime import context as dpexrt
13+
614
from .usm_ndarray_type import USMNdArray
715

816

917
class DpnpNdArray(USMNdArray):
1018
"""
1119
The Numba type to represent an dpnp.ndarray. The type has the same
12-
structure as USMNdArray used to represnet dpctl.tensor.usm_ndarray.
20+
structure as USMNdArray used to represent dpctl.tensor.usm_ndarray.
21+
"""
22+
23+
@property
24+
def is_internal(self):
25+
return True
26+
27+
28+
# --------------- Boxing/Unboxing logic for dpnp.ndarray ----------------------#
29+
30+
31+
@unbox(DpnpNdArray)
32+
def unbox_dpnp_nd_array(typ, obj, c):
33+
"""Converts a dpnp.ndarray object to a Numba internal array structure.
34+
35+
Args:
36+
typ : The Numba type of the PyObject
37+
obj : The actual PyObject to be unboxed
38+
c :
39+
40+
Returns:
41+
_type_: _description_
1342
"""
43+
# Reusing the numba.core.base.BaseContext's make_array function to get a
44+
# struct allocated. The same struct is used for numpy.ndarray
45+
# and dpnp.ndarray. It is possible to do so, as the extra information
46+
# specific to dpnp.ndarray such as sycl_queue is inferred statically and
47+
# stored as part of the DpnpNdArray type.
48+
49+
# --------------- Original Numba comment from @ubox(types.Array)
50+
#
51+
# This is necessary because unbox_buffer() does not work on some
52+
# dtypes, e.g. datetime64 and timedelta64.
53+
# TODO check matching dtype.
54+
# currently, mismatching dtype will still work and causes
55+
# potential memory corruption
56+
#
57+
# --------------- End of Numba comment from @ubox(types.Array)
58+
nativearycls = c.context.make_array(typ)
59+
nativeary = nativearycls(c.context, c.builder)
60+
aryptr = nativeary._getpointer()
61+
62+
ptr = c.builder.bitcast(aryptr, c.pyapi.voidptr)
63+
# FIXME : We need to check if Numba_RT as well as DPEX RT are enabled.
64+
if c.context.enable_nrt:
65+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
66+
errcode = dpexrtCtx.arraystruct_from_python(c.pyapi, obj, ptr)
67+
else:
68+
raise UnreachableError
69+
70+
# TODO: here we have minimal typechecking by the itemsize.
71+
# need to do better
72+
try:
73+
expected_itemsize = numpy_support.as_dtype(typ.dtype).itemsize
74+
except NumbaNotImplementedError:
75+
# Don't check types that can't be `as_dtype()`-ed
76+
itemsize_mismatch = cgutils.false_bit
77+
else:
78+
expected_itemsize = nativeary.itemsize.type(expected_itemsize)
79+
itemsize_mismatch = c.builder.icmp_unsigned(
80+
"!=",
81+
nativeary.itemsize,
82+
expected_itemsize,
83+
)
84+
85+
failed = c.builder.or_(
86+
cgutils.is_not_null(c.builder, errcode),
87+
itemsize_mismatch,
88+
)
89+
# Handle error
90+
with c.builder.if_then(failed, likely=False):
91+
c.pyapi.err_set_string(
92+
"PyExc_TypeError",
93+
"can't unbox array from PyObject into "
94+
"native value. The object maybe of a "
95+
"different type",
96+
)
97+
return NativeValue(c.builder.load(aryptr), is_error=failed)
98+
99+
100+
@box(DpnpNdArray)
101+
def box_array(typ, val, c):
102+
if c.context.enable_nrt:
103+
np_dtype = numpy_support.as_dtype(typ.dtype)
104+
dtypeptr = c.env_manager.read_const(c.env_manager.add_const(np_dtype))
105+
dpexrtCtx = dpexrt.DpexRTContext(c.context)
106+
newary = dpexrtCtx.usm_ndarray_to_python_acqref(
107+
c.pyapi, typ, val, dtypeptr
108+
)
109+
110+
if not newary:
111+
c.pyapi.err_set_string(
112+
"PyExc_TypeError",
113+
"could not box native array into a dpnp.ndarray PyObject.",
114+
)
115+
116+
# Steals NRT ref
117+
# Refer:
118+
# numba.core.base.nrt -> numba.core.runtime.context -> decref
119+
# The `NRT_decref` function is generated directly as LLVM IR inside
120+
# numba.core.runtime.nrtdynmod.py
121+
c.context.nrt.decref(c.builder, typ, val)
14122

15-
pass
123+
return newary
124+
else:
125+
raise UnreachableError

numba_dpex/dpnp_iface/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from . import arrayobj
6+
57

68
def ensure_dpnp(name):
79
try:
@@ -24,4 +26,5 @@ def _init_dpnp():
2426

2527
_init_dpnp()
2628

29+
2730
DEBUG = None

0 commit comments

Comments
 (0)