|
5 | 5 | from numba.core import cgutils, types |
6 | 6 | from numba.core.ir_utils import legalize_names |
7 | 7 |
|
8 | | -from numba_dpex import numpy_usm_shared as nus |
9 | 8 | from numba_dpex import utils |
10 | 9 | from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder |
11 | 10 | from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum |
@@ -251,79 +250,60 @@ def process_kernel_arg( |
251 | 250 | context=self.context, type=types.voidptr |
252 | 251 | ) |
253 | 252 |
|
254 | | - if isinstance(arg_type, nus.UsmSharedArrayType): |
255 | | - self._form_kernel_arg_and_arg_ty( |
| 253 | + malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared( |
| 254 | + builder=self.builder, context=self.context |
| 255 | + ) |
| 256 | + memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy( |
| 257 | + builder=self.builder, context=self.context |
| 258 | + ) |
| 259 | + event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete( |
| 260 | + builder=self.builder, context=self.context |
| 261 | + ) |
| 262 | + event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait( |
| 263 | + builder=self.builder, context=self.context |
| 264 | + ) |
| 265 | + |
| 266 | + # Not known to be USM so we need to copy to USM. |
| 267 | + buffer_name = "buffer_ptr" + str(self.cur_arg) |
| 268 | + # Create void * to hold new USM buffer. |
| 269 | + buffer_ptr = cgutils.alloca_once( |
| 270 | + self.builder, |
| 271 | + utils.get_llvm_type(context=self.context, type=types.voidptr), |
| 272 | + name=buffer_name, |
| 273 | + ) |
| 274 | + # Setup the args to the USM allocator, size and SYCL queue. |
| 275 | + args = [ |
| 276 | + self.builder.load(total_size), |
| 277 | + self.builder.load(sycl_queue_val), |
| 278 | + ] |
| 279 | + # Call USM shared allocator and store in buffer_ptr. |
| 280 | + self.builder.store(self.builder.call(malloc_fn, args), buffer_ptr) |
| 281 | + |
| 282 | + if legal_names[var] in modified_arrays: |
| 283 | + self.write_buffs.append((buffer_ptr, total_size, data_member)) |
| 284 | + else: |
| 285 | + self.read_only_buffs.append( |
| 286 | + (buffer_ptr, total_size, data_member) |
| 287 | + ) |
| 288 | + |
| 289 | + # We really need to detect when an array needs to be copied over |
| 290 | + if index < self.num_inputs: |
| 291 | + args = [ |
| 292 | + self.builder.load(sycl_queue_val), |
| 293 | + self.builder.load(buffer_ptr), |
256 | 294 | self.builder.bitcast( |
257 | 295 | self.builder.load(data_member), |
258 | 296 | utils.get_llvm_type( |
259 | 297 | context=self.context, type=types.voidptr |
260 | 298 | ), |
261 | 299 | ), |
262 | | - ty, |
263 | | - ) |
264 | | - else: |
265 | | - malloc_fn = DpctlCAPIFnBuilder.get_dpctl_malloc_shared( |
266 | | - builder=self.builder, context=self.context |
267 | | - ) |
268 | | - memcpy_fn = DpctlCAPIFnBuilder.get_dpctl_queue_memcpy( |
269 | | - builder=self.builder, context=self.context |
270 | | - ) |
271 | | - event_del_fn = DpctlCAPIFnBuilder.get_dpctl_event_delete( |
272 | | - builder=self.builder, context=self.context |
273 | | - ) |
274 | | - event_wait_fn = DpctlCAPIFnBuilder.get_dpctl_event_wait( |
275 | | - builder=self.builder, context=self.context |
276 | | - ) |
277 | | - |
278 | | - # Not known to be USM so we need to copy to USM. |
279 | | - buffer_name = "buffer_ptr" + str(self.cur_arg) |
280 | | - # Create void * to hold new USM buffer. |
281 | | - buffer_ptr = cgutils.alloca_once( |
282 | | - self.builder, |
283 | | - utils.get_llvm_type( |
284 | | - context=self.context, type=types.voidptr |
285 | | - ), |
286 | | - name=buffer_name, |
287 | | - ) |
288 | | - # Setup the args to the USM allocator, size and SYCL queue. |
289 | | - args = [ |
290 | 300 | self.builder.load(total_size), |
291 | | - self.builder.load(sycl_queue_val), |
292 | 301 | ] |
293 | | - # Call USM shared allocator and store in buffer_ptr. |
294 | | - self.builder.store( |
295 | | - self.builder.call(malloc_fn, args), buffer_ptr |
296 | | - ) |
297 | | - |
298 | | - if legal_names[var] in modified_arrays: |
299 | | - self.write_buffs.append( |
300 | | - (buffer_ptr, total_size, data_member) |
301 | | - ) |
302 | | - else: |
303 | | - self.read_only_buffs.append( |
304 | | - (buffer_ptr, total_size, data_member) |
305 | | - ) |
306 | | - |
307 | | - # We really need to detect when an array needs to be copied over |
308 | | - if index < self.num_inputs: |
309 | | - args = [ |
310 | | - self.builder.load(sycl_queue_val), |
311 | | - self.builder.load(buffer_ptr), |
312 | | - self.builder.bitcast( |
313 | | - self.builder.load(data_member), |
314 | | - utils.get_llvm_type( |
315 | | - context=self.context, type=types.voidptr |
316 | | - ), |
317 | | - ), |
318 | | - self.builder.load(total_size), |
319 | | - ] |
320 | | - event_ref = self.builder.call(memcpy_fn, args) |
321 | | - self.builder.call(event_wait_fn, [event_ref]) |
322 | | - self.builder.call(event_del_fn, [event_ref]) |
| 302 | + event_ref = self.builder.call(memcpy_fn, args) |
| 303 | + self.builder.call(event_wait_fn, [event_ref]) |
| 304 | + self.builder.call(event_del_fn, [event_ref]) |
323 | 305 |
|
324 | | - self._form_kernel_arg_and_arg_ty( |
325 | | - self.builder.load(buffer_ptr), ty |
326 | | - ) |
| 306 | + self._form_kernel_arg_and_arg_ty(self.builder.load(buffer_ptr), ty) |
327 | 307 |
|
328 | 308 | # Handle shape |
329 | 309 | shape_member = self.builder.gep( |
|
0 commit comments