Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from numba_dpex import config

from ..descriptor import dpex_kernel_target
from ..types.dpnp_ndarray_type import DpnpNdArray
from ..types import DpnpNdArray, USMNdArray
from ..utils.kernel_templates import RangeKernelTemplate


Expand Down Expand Up @@ -70,6 +70,30 @@ def _compile_kernel_parfor(
func_ir, kernel_name
)

# A cast from DpnpNdArray type to USMNdArray is needed for all arguments of
# DpnpNdArray type. Although, DpnpNdArray derives from USMNdArray the two
# types use different data models. USMNdArray uses the
# numba_dpex.core.datamodel.models.ArrayModel data model that defines all
# CPointer type members in the GLOBAL address space. The DpnpNdArray uses
# Numba's default ArrayModel that does not define pointers in any specific
# address space. For OpenCL HD Graphics devices, defining a kernel function
# (spir_kernel calling convention) with pointer arguments that have no
# address space qualifier causes a run time crash. By casting the argument
# type for parfor arguments from DpnpNdArray type to the USMNdArray type the
# generated kernel always has an address space qualifier, avoiding the issue
# on OpenCL HD graphics devices.

for i, argty in enumerate(argtypes):
if isinstance(argty, DpnpNdArray):
new_argty = USMNdArray(
ndim=argty.ndim,
layout=argty.layout,
dtype=argty.dtype,
usm_type=argty.usm_type,
queue=argty.queue,
)
argtypes[i] = new_argty

# compile the kernel
kernel.compile(
args=argtypes,
Expand Down
6 changes: 6 additions & 0 deletions numba_dpex/core/types/dpnp_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
else:
return

def __str__(self):
return self.name.replace("USMNdArray", "DpnpNdarray")

def __repr__(self):
return self.__str__()

def __allocate__(
self,
typingctx,
Expand Down
5 changes: 4 additions & 1 deletion numba_dpex/core/types/usm_ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self.dtype = dtype

if name is None:
type_name = "usm_ndarray"
type_name = "USMNdArray"
if readonly:
type_name = "readonly " + type_name
if not aligned:
Expand Down Expand Up @@ -116,6 +116,9 @@ def __init__(
aligned=aligned,
)

def __repr__(self):
return self.name

def copy(
self,
dtype=None,
Expand Down