Skip to content

Commit 9af4d9f

Browse files
author
Diptorup Deb
authored
Merge pull request #841 from IntelPython/remove/numpy_usm_shared
Removes the numpy_usm_shared module from numba_dpex.
2 parents bf10086 + 24a2cf5 commit 9af4d9f

File tree

9 files changed

+46
-1665
lines changed

9 files changed

+46
-1665
lines changed

.flake8

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,5 +17,4 @@ exclude =
1717
.git,
1818
__pycache__,
1919
_version.py,
20-
numpy_usm_shared.py,
2120
lowerer.py,

numba_dpex/dpctl_iface/kernel_launch_ops.py

Lines changed: 45 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
from numba.core import cgutils, types
66
from numba.core.ir_utils import legalize_names
77

8-
from numba_dpex import numpy_usm_shared as nus
98
from numba_dpex import utils
109
from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
1110
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
@@ -251,79 +250,60 @@ def process_kernel_arg(
251250
context=self.context, type=types.voidptr
252251
)
253252

254-
if isinstance(arg_type, nus.UsmSharedArrayType):
255-
self._form_kernel_arg_and_arg_ty(
253+
malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared(
254+
builder=self.builder, context=self.context
255+
)
256+
memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy(
257+
builder=self.builder, context=self.context
258+
)
259+
event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete(
260+
builder=self.builder, context=self.context
261+
)
262+
event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait(
263+
builder=self.builder, context=self.context
264+
)
265+
266+
# Not known to be USM so we need to copy to USM.
267+
buffer_name = "buffer_ptr" + str(self.cur_arg)
268+
# Create void * to hold new USM buffer.
269+
buffer_ptr = cgutils.alloca_once(
270+
self.builder,
271+
utils.get_llvm_type(context=self.context, type=types.voidptr),
272+
name=buffer_name,
273+
)
274+
# Setup the args to the USM allocator, size and SYCL queue.
275+
args = [
276+
self.builder.load(total_size),
277+
self.builder.load(sycl_queue_val),
278+
]
279+
# Call USM shared allocator and store in buffer_ptr.
280+
self.builder.store(self.builder.call(malloc_fn, args), buffer_ptr)
281+
282+
if legal_names[var] in modified_arrays:
283+
self.write_buffs.append((buffer_ptr, total_size, data_member))
284+
else:
285+
self.read_only_buffs.append(
286+
(buffer_ptr, total_size, data_member)
287+
)
288+
289+
# We really need to detect when an array needs to be copied over
290+
if index < self.num_inputs:
291+
args = [
292+
self.builder.load(sycl_queue_val),
293+
self.builder.load(buffer_ptr),
256294
self.builder.bitcast(
257295
self.builder.load(data_member),
258296
utils.get_llvm_type(
259297
context=self.context, type=types.voidptr
260298
),
261299
),
262-
ty,
263-
)
264-
else:
265-
malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared(
266-
builder=self.builder, context=self.context
267-
)
268-
memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy(
269-
builder=self.builder, context=self.context
270-
)
271-
event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete(
272-
builder=self.builder, context=self.context
273-
)
274-
event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait(
275-
builder=self.builder, context=self.context
276-
)
277-
278-
# Not known to be USM so we need to copy to USM.
279-
buffer_name = "buffer_ptr" + str(self.cur_arg)
280-
# Create void * to hold new USM buffer.
281-
buffer_ptr = cgutils.alloca_once(
282-
self.builder,
283-
utils.get_llvm_type(
284-
context=self.context, type=types.voidptr
285-
),
286-
name=buffer_name,
287-
)
288-
# Setup the args to the USM allocator, size and SYCL queue.
289-
args = [
290300
self.builder.load(total_size),
291-
self.builder.load(sycl_queue_val),
292301
]
293-
# Call USM shared allocator and store in buffer_ptr.
294-
self.builder.store(
295-
self.builder.call(malloc_fn, args), buffer_ptr
296-
)
297-
298-
if legal_names[var] in modified_arrays:
299-
self.write_buffs.append(
300-
(buffer_ptr, total_size, data_member)
301-
)
302-
else:
303-
self.read_only_buffs.append(
304-
(buffer_ptr, total_size, data_member)
305-
)
306-
307-
# We really need to detect when an array needs to be copied over
308-
if index < self.num_inputs:
309-
args = [
310-
self.builder.load(sycl_queue_val),
311-
self.builder.load(buffer_ptr),
312-
self.builder.bitcast(
313-
self.builder.load(data_member),
314-
utils.get_llvm_type(
315-
context=self.context, type=types.voidptr
316-
),
317-
),
318-
self.builder.load(total_size),
319-
]
320-
event_ref = self.builder.call(memcpy_fn, args)
321-
self.builder.call(event_wait_fn, [event_ref])
322-
self.builder.call(event_del_fn, [event_ref])
302+
event_ref = self.builder.call(memcpy_fn, args)
303+
self.builder.call(event_wait_fn, [event_ref])
304+
self.builder.call(event_del_fn, [event_ref])
323305

324-
self._form_kernel_arg_and_arg_ty(
325-
self.builder.load(buffer_ptr), ty
326-
)
306+
self._form_kernel_arg_and_arg_ty(self.builder.load(buffer_ptr), ty)
327307

328308
# Handle shape
329309
shape_member = self.builder.gep(

0 commit comments

Comments
 (0)