diff --git a/numba_dpex/core/typeconv/array_conversion.py b/numba_dpex/core/typeconv/array_conversion.py index 5096045a90..228265b6ec 100644 --- a/numba_dpex/core/typeconv/array_conversion.py +++ b/numba_dpex/core/typeconv/array_conversion.py @@ -37,7 +37,6 @@ def to_usm_ndarray(suai_attrs, addrspace=address_space.GLOBAL): ndim=suai_attrs.dimensions, layout=layout, usm_type=suai_attrs.usm_type, - device=suai_attrs.device, queue=suai_attrs.queue, readonly=not suai_attrs.is_writable, name=None, diff --git a/numba_dpex/core/types/usm_ndarray_type.py b/numba_dpex/core/types/usm_ndarray_type.py index e3d1101633..9a4790de38 100644 --- a/numba_dpex/core/types/usm_ndarray_type.py +++ b/numba_dpex/core/types/usm_ndarray_type.py @@ -31,46 +31,37 @@ def __init__( aligned=True, addrspace=address_space.GLOBAL, ): + if not isinstance(device, str): + raise TypeError( + "The device keyword arg should be a str object specifying " + "a SYCL filter selector" + ) + + if not isinstance(queue, dpctl.SyclQueue) and queue is not None: + raise TypeError( + "The queue keyword arg should be a dpctl.SyclQueue object or None" + ) + self.usm_type = usm_type self.addrspace = addrspace - if queue is not None and device != "unknown": - if not isinstance(device, str): - raise TypeError( - "The device keyword arg should be a str object specifying " - "a SYCL filter selector" - ) - if not isinstance(queue, dpctl.SyclQueue): - raise TypeError( - "The queue keyword arg should be a dpctl.SyclQueue object" - ) - d1 = queue.sycl_device - d2 = dpctl.SyclDevice(device) - if d1 != d2: - raise TypeError( - "The queue keyword arg and the device keyword arg specify " - "different SYCL devices" - ) - self.queue = queue - self.device = device - elif queue is None and device != "unknown": - if not isinstance(device, str): - raise TypeError( - "The device keyword arg should be a str object specifying " - "a SYCL filter selector" - ) - self.queue = dpctl.SyclQueue(device) - self.device = self.queue.sycl_device.filter_string - elif queue is not None and device == "unknown": - if not isinstance(queue, dpctl.SyclQueue): - raise TypeError( - "The queue keyword arg should be a dpctl.SyclQueue object" - ) - self.device = self.queue.sycl_device.filter_string + if device == "unknown": + device = None + + if queue is not None and device is not None: + raise TypeError( + "'queue' and 'device' keywords can not be both specified" + ) + + if queue is not None: self.queue = queue else: - self.queue = dpctl.SyclQueue() - self.device = self.queue.sycl_device.filter_string + if device is None: + device = dpctl.SyclDevice() + + self.queue = dpctl.get_device_cached_queue(device) + + self.device = self.queue.sycl_device.filter_string if not dtype: dummy_tensor = dpctl.tensor.empty( diff --git a/numba_dpex/core/typing/typeof.py b/numba_dpex/core/typing/typeof.py index 0069d1e9be..a9df706ad0 100644 --- a/numba_dpex/core/typing/typeof.py +++ b/numba_dpex/core/typing/typeof.py @@ -42,10 +42,7 @@ def _typeof_helper(val, array_class_type): "The usm_type for the usm_ndarray could not be inferred" ) - try: - device = val.sycl_device.filter_string - except AttributeError: - raise ValueError("The device for the usm_ndarray could not be inferred") + assert val.sycl_queue is not None return array_class_type( dtype=dtype, @@ -53,7 +50,6 @@ def _typeof_helper(val, array_class_type): layout=layout, readonly=readonly, usm_type=usm_type, - device=device, queue=val.sycl_queue, addrspace=address_space.GLOBAL, )