Skip to content

Commit 70c36ea

Browse files
author
Diptorup Deb
committed
Adds an arg_pack_unpack module to kernel_interface
- Creates a separate module for the unpack and pack functions for kernel arguments. - The new API is intended for use from the Dispatcher class.
1 parent 6b39cd6 commit 70c36ea

File tree

2 files changed

+292
-1
lines changed

2 files changed

+292
-1
lines changed

numba_dpex/core/exceptions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
"""The module defines the custom exception classes used in numba_dpex.
5+
"""The module defines the custom error classes used in numba_dpex.
66
"""
77

88
from warnings import warn
@@ -218,3 +218,18 @@ def __init__(self) -> None:
218218
else:
219219
self.message = "Unreachable code executed."
220220
super().__init__(self.message)
221+
222+
223+
class UnsupportedKernelArgumentError(Exception):
    """Signals that a kernel argument has a type that cannot be passed on.

    The exception accepts arbitrary positional arguments and stores them
    unchanged in ``args``, exactly like the built-in ``Exception``.
    """
226+
227+
228+
class SUAIProtocolError(Exception):
    """Signals a problem with an object's SYCL USM array interface (SUAI).

    The exception accepts arbitrary positional arguments and stores them
    unchanged in ``args``, exactly like the built-in ``Exception``.
    """
231+
232+
233+
class UnsupportedAccessQualifierError(Exception):
    """Signals that an array access qualifier is not a supported value.

    The exception accepts arbitrary positional arguments and stores them
    unchanged in ``args``, exactly like the built-in ``Exception``.
    """
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
# SPDX-FileCopyrightText: 2022 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import ctypes
6+
import logging
7+
from multiprocessing.dummy import Array
8+
9+
import dpctl.memory as dpctl_mem
10+
import numpy as np
11+
from numba.core import types
12+
13+
import numba_dpex.utils as utils
14+
from numba_dpex.core.exceptions import (
15+
SUAIProtocolError,
16+
UnsupportedAccessQualifierError,
17+
UnsupportedKernelArgumentError,
18+
)
19+
from numba_dpex.dpctl_iface import USMNdArrayType
20+
21+
22+
class Packer:
    """Converts kernel arguments to ctypes values and copies results back.

    "Unpacking" turns each Python kernel argument into the ctypes value(s)
    that describe it. "Packing" copies device data back into the host NumPy
    arrays that were recorded while unpacking.
    """

    # TODO: Remove after NumPy support is removed
    _access_types = ("read_only", "write_only", "read_write")

    def _check_for_invalid_access_type(self, access_type):
        """Validate an access qualifier for a NumPy array argument.

        Args:
            access_type: The access qualifier to validate.

        Raises:
            UnsupportedAccessQualifierError: If ``access_type`` is not one of
                the values listed in ``Packer._access_types``.
        """
        if access_type not in Packer._access_types:
            raise UnsupportedAccessQualifierError(
                f"{access_type!r} is not a valid access type. Supported "
                f"access types are: {', '.join(Packer._access_types)}."
            )

    def _get_info_from_suai(self, obj):
        """
        Extracts the metadata of an array-like object that provides a
        __sycl_usm_array_interface__ (SUAI) attribute.

        The ``dpctl.memory.as_usm_memory`` function converts the array-like
        object into a dpctl.memory.USMMemory object. Using ``as_usm_memory``
        is an implicit way to verify that the array-like object is a legal
        SYCL USM memory backed Python object that can be passed to a dpex
        kernel.

        Args:
            obj: array-like object with a SUAI attribute.

        Returns:
            usm_mem: USM memory object.
            total_size: Total number of items in the array.
            shape: Shape of the array.
            ndim: Total number of dimensions.
            itemsize: Size of each item.
            strides: Stride of the array.
            dtype: Dtype of the array.

        Raises:
            SUAIProtocolError: If ``as_usm_memory`` rejects ``obj``.
        """
        try:
            usm_mem = dpctl_mem.as_usm_memory(obj)
        except Exception as err:
            logging.exception(
                "array-like object does not implement the SUAI protocol."
            )
            # Chain the original error so the root cause is not lost.
            raise SUAIProtocolError(
                "array-like object does not implement the SUAI protocol."
            ) from err

        suai = obj.__sycl_usm_array_interface__
        shape = suai["shape"]
        ndim = len(shape)
        # Force an integer accumulator: np.prod(()) is a float, which
        # ctypes.c_longlong rejects, and the platform default integer may be
        # narrower than 64 bits.
        total_size = int(np.prod(shape, dtype=np.int64))
        dtype = np.dtype(suai["typestr"])
        itemsize = dtype.itemsize
        strides = suai.get("strides")

        if strides is None:
            # No strides given: derive row-major (C order) strides from the
            # shape, with the innermost dimension having stride 1.
            # NOTE(review): these strides are counted in elements, while
            # _unpack_array forwards NumPy strides, which are in bytes --
            # confirm which unit the kernel expects.
            strides = [1] * ndim
            for i in reversed(range(1, ndim)):
                strides[i - 1] = strides[i] * shape[i]
            strides = tuple(strides)

        return usm_mem, total_size, shape, ndim, itemsize, strides, dtype

    def _unpack_array_helper(self, size, itemsize, buf, shape, strides, ndim):
        """
        Implements the unpacking logic for array arguments.

        The returned list holds, in order: two zero-valued ``c_size_t``
        placeholders, the element count, the item size, the buffer object,
        one ``c_longlong`` per shape entry and one per stride entry.

        Args:
            size: Total number of elements in the array.
            itemsize: Size in bytes of each element in the array.
            buf: The pointer to the memory.
            shape: The shape of the array.
            strides: The strides of the array.
            ndim: Number of dimensions.

        Returns:
            A list of ctypes values, one for each array attribute argument.
        """
        unpacked_array_attrs = [
            # meminfo (FIXME: should be removed and the USMArrayType modified
            # once NumPy support is removed)
            ctypes.c_size_t(0),
            # Second placeholder slot (presumably the array struct's
            # ``parent`` field -- confirm against the USMArrayType data
            # model). FIXME: Evaluate if the attribute should be removed and
            # the USMArrayType modified once NumPy support is removed.
            ctypes.c_size_t(0),
            ctypes.c_longlong(size),
            ctypes.c_longlong(itemsize),
            buf,
        ]
        unpacked_array_attrs.extend(
            ctypes.c_longlong(shape[ax]) for ax in range(ndim)
        )
        unpacked_array_attrs.extend(
            ctypes.c_longlong(strides[ax]) for ax in range(ndim)
        )

        return unpacked_array_attrs

    def _unpack_usm_array(self, val):
        """Unpack an array-like object that exposes the SUAI protocol.

        Args:
            val: array-like object with a SUAI attribute.

        Returns:
            The list of ctypes values built by ``_unpack_array_helper``.

        Raises:
            SUAIProtocolError: Propagated from ``_get_info_from_suai``.
        """
        (
            usm_mem,
            total_size,
            shape,
            ndim,
            itemsize,
            strides,
            _dtype,
        ) = self._get_info_from_suai(val)

        return self._unpack_array_helper(
            total_size,
            itemsize,
            usm_mem,
            shape,
            strides,
            ndim,
        )

    def _unpack_array(self, val, access_type="read_write"):
        """Unpack a NumPy array, staging it in USM memory when required.

        If the array is not already backed by USM memory, a USM object is
        allocated for it. Data is copied to that object for ``read_only``
        and ``read_write`` access. For ``read_write`` and ``write_only``
        access the array is recorded in the repack map so that
        ``_pack_array`` can copy the device data back.

        Args:
            val: The NumPy array passed as a kernel argument.
            access_type: One of ``Packer._access_types``. Only consulted
                when the array is not USM backed.

        Returns:
            The list of ctypes values built by ``_unpack_array_helper``.

        Raises:
            UnsupportedAccessQualifierError: If the array is not USM backed
                and ``access_type`` is not a supported value.
        """
        packed_val = val
        # Check if the NumPy array is backed by USM memory
        usm_mem = utils.has_usm_memory(val)

        # If the NumPy array is not USM backed, then copy to a USM memory
        # object. Add an entry to the repack_map so that on exit from kernel
        # the USM object can be copied back into the NumPy array.
        if usm_mem is None:
            self._check_for_invalid_access_type(access_type)
            usm_mem = utils.as_usm_obj(val, queue=self._queue, copy=False)

            packed = False
            if not val.flags.c_contiguous:
                # Pack a strided array into a C-contiguous temporary so the
                # data can be treated as C-contiguous from here on. The
                # temporary keeps the original shape and ndim (``flatten``
                # would collapse them), so ``numpy.copyto`` can write the
                # result back into the original strided array when packing.
                packed_val = np.ascontiguousarray(val)
                packed = True

            if access_type in ("read_only", "read_write"):
                utils.copy_from_numpy_to_usm_obj(usm_mem, packed_val)

            if access_type in ("read_write", "write_only"):
                # ndarrays are unhashable, so key the map by identity and
                # keep the array itself inside the entry.
                self._repack_map[id(val)] = (val, usm_mem, packed_val, packed)

        return self._unpack_array_helper(
            packed_val.size,
            packed_val.dtype.itemsize,
            usm_mem,
            packed_val.shape,
            packed_val.strides,
            packed_val.ndim,
        )

    def _unpack_argument(self, ty, val, access_type="read_write"):
        """
        Unpack a Python object into ctypes value(s) based on its Numba type.

        Args:
            ty: The data type of the kernel argument, defined as an instance
                of numba.types.
            val: The value of the kernel argument.
            access_type: Access qualifier forwarded to ``_unpack_array`` for
                NumPy array arguments. Defaults to ``"read_write"`` because
                ``__init__`` receives no per-argument qualifiers.

        Returns:
            A single ctypes value for scalar arguments, or a list of ctypes
            values for array arguments.

        Raises:
            UnsupportedKernelArgumentError: When the argument is of an
                unsupported type (this includes complex64 and complex128).
        """
        if isinstance(ty, USMNdArrayType):
            return self._unpack_usm_array(val)
        # Use numba's types.Array: the ``Array`` imported at the top of the
        # file is multiprocessing.dummy.Array, a function, and passing it to
        # isinstance raises TypeError.
        if isinstance(ty, types.Array):
            return self._unpack_array(val, access_type)

        scalar_converters = (
            (types.int64, ctypes.c_longlong),
            (types.uint64, ctypes.c_ulonglong),
            (types.int32, ctypes.c_int),
            (types.uint32, ctypes.c_uint),
            (types.float64, ctypes.c_double),
            (types.float32, ctypes.c_float),
            (types.boolean, lambda flag: ctypes.c_uint8(int(flag))),
        )
        for numba_ty, convert in scalar_converters:
            if ty == numba_ty:
                return convert(val)

        raise UnsupportedKernelArgumentError(ty, val)

    def _pack_array(self):
        """
        Copy device data back to host.

        For every entry in the repack map the USM object is copied into the
        (possibly packed) host array. If a packed temporary was used, it is
        then copied into the original array.
        """
        for orig, usm_mem, packed_ndarr, packed in self._repack_map.values():
            utils.copy_to_numpy_from_usm_obj(usm_mem, packed_ndarr)
            if packed:
                np.copyto(orig, packed_ndarr)

    def __init__(self, arg_list, argty_list, queue) -> None:
        """Unpack every kernel argument.

        Args:
            arg_list: The kernel argument values.
            argty_list: The Numba type of each argument; ``argty_list[i]``
                describes ``arg_list[i]``.
            queue: Queue forwarded to ``utils.as_usm_obj`` when a NumPy
                array has to be staged in USM memory.

        Raises:
            UnsupportedKernelArgumentError: If an argument has an
                unsupported type.
        """
        self._arg_list = arg_list
        self._argty_list = argty_list
        self._queue = queue

        # Create a map for numpy arrays storing the unpacked information, as
        # these arrays will need to be repacked. It must exist before the
        # loop below because _unpack_array writes to it.
        self._repack_map = {}

        # loop over the arg_list and generate the kernelargs list
        self._unpacked_args = []
        for i, val in enumerate(arg_list):
            self._unpacked_args.append(
                self._unpack_argument(ty=argty_list[i], val=val)
            )

    @property
    def unpacked_args(self):
        """One entry per kernel argument.

        Scalars contribute a single ctypes value; arrays contribute the list
        returned by ``_unpack_array_helper``.
        """
        return self._unpacked_args

    @property
    def repacked_args(self):
        """Copy device data back to the host and return the updated arrays.

        Every access runs ``_pack_array`` again.
        """
        self._pack_array()
        return [entry[0] for entry in self._repack_map.values()]

0 commit comments

Comments
 (0)