From 3e7c51d724b08999c4801b8ed43df13ad250bec4 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 16 Feb 2023 12:27:15 -0600 Subject: [PATCH] Improve dpnp.empty unit tests. --- ...dpnp_empty_dpjit.py => test_dpnp_empty.py} | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) rename numba_dpex/tests/dpjit_tests/dpnp/{test_dpnp_empty_dpjit.py => test_dpnp_empty.py} (58%) diff --git a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_dpjit.py b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py similarity index 58% rename from numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_dpjit.py rename to numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py index ab83fe8b16..f8867bfa9d 100644 --- a/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty_dpjit.py +++ b/numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py @@ -4,6 +4,7 @@ """Tests for dpnp ndarray constructors.""" +import dpctl import dpnp import pytest @@ -20,10 +21,6 @@ @pytest.mark.parametrize("usm_type", usm_types) @pytest.mark.parametrize("device", devices) def test_dpnp_empty(shape, dtype, usm_type, device): - @dpjit - def func(shape): - dpnp.empty(shape=shape, dtype=dtype, usm_type=usm_type, device=device) - @dpjit def func1(shape): c = dpnp.empty( @@ -31,6 +28,22 @@ def func1(shape): ) return c - func(shape) - - func1(shape) + try: + c = func1(shape) + except Exception: + pytest.fail("Calling dpnp.empty inside dpjit failed") + + if len(c.shape) == 1: + assert c.shape[0] == shape + else: + assert c.shape == shape + + assert c.dtype == dtype + assert c.usm_type == usm_type + if device != "unknown": + assert ( + c.sycl_device.filter_string + == dpctl.SyclDevice(device).filter_string + ) + else: + c.sycl_device.filter_string == dpctl.SyclDevice().filter_string