Skip to content

Commit 481eafb

Browse files
committed
Test cases added
1 parent 4486cf5 commit 481eafb

File tree

5 files changed

+210
-5
lines changed

5 files changed

+210
-5
lines changed

numba_dpex/dpnp_iface/arrayobj.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ def ol_dpnp_empty_like(
431431
x,
432432
dtype=None,
433433
order="C",
434+
subok=False,
435+
shape=None,
434436
device=None,
435437
usm_type=None,
436438
sycl_queue=None,
@@ -471,6 +473,8 @@ def ol_dpnp_empty_like(
471473
_ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0
472474
_dtype = _parse_dtype(dtype, data=x)
473475
_order = x.layout if order is None else order
476+
_subok = subok
477+
_shape = shape
474478
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
475479
_device = (
476480
_parse_device_filter_string(device) if device is not None else "unknown"
@@ -489,6 +493,8 @@ def impl(
489493
x,
490494
dtype=None,
491495
order="C",
496+
subok=False,
497+
shape=None,
492498
device=None,
493499
usm_type=None,
494500
sycl_queue=None,
@@ -497,6 +503,8 @@ def impl(
497503
x,
498504
_dtype,
499505
_order,
506+
_subok,
507+
_shape,
500508
_device,
501509
_usm_type,
502510
sycl_queue,
@@ -516,6 +524,8 @@ def ol_dpnp_zeros_like(
516524
x,
517525
dtype=None,
518526
order="C",
527+
subok=False,
528+
shape=None,
519529
device=None,
520530
usm_type=None,
521531
sycl_queue=None,
@@ -555,6 +565,8 @@ def ol_dpnp_zeros_like(
555565
_ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0
556566
_dtype = _parse_dtype(dtype, data=x)
557567
_order = x.layout if order is None else order
568+
_subok = subok
569+
_shape = shape
558570
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
559571
_device = (
560572
_parse_device_filter_string(device) if device is not None else "unknown"
@@ -573,6 +585,8 @@ def impl(
573585
x,
574586
dtype=None,
575587
order="C",
588+
subok=False,
589+
shape=None,
576590
device=None,
577591
usm_type=None,
578592
sycl_queue=None,
@@ -581,6 +595,8 @@ def impl(
581595
x,
582596
_dtype,
583597
_order,
598+
_subok,
599+
_shape,
584600
_device,
585601
_usm_type,
586602
sycl_queue,
@@ -600,6 +616,8 @@ def ol_dpnp_ones_like(
600616
x,
601617
dtype=None,
602618
order="C",
619+
subok=False,
620+
shape=None,
603621
device=None,
604622
usm_type=None,
605623
sycl_queue=None,
@@ -639,6 +657,8 @@ def ol_dpnp_ones_like(
639657
_ndim = x.ndim if hasattr(x, "ndim") and x.ndim is not None else 0
640658
_dtype = _parse_dtype(dtype, data=x)
641659
_order = x.layout if order is None else order
660+
_subok = subok
661+
_shape = shape
642662
_usm_type = _parse_usm_type(usm_type) if usm_type is not None else "device"
643663
_device = (
644664
_parse_device_filter_string(device) if device is not None else "unknown"
@@ -657,6 +677,8 @@ def impl(
657677
x,
658678
dtype=None,
659679
order="C",
680+
subok=False,
681+
shape=None,
660682
device=None,
661683
usm_type=None,
662684
sycl_queue=None,
@@ -665,6 +687,8 @@ def impl(
665687
x,
666688
_dtype,
667689
_order,
690+
_subok,
691+
_shape,
668692
_device,
669693
_usm_type,
670694
sycl_queue,

numba_dpex/dpnp_iface/intrinsic.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ def alloc_empty_arrayobj(context, builder, sig, llargs, is_like=False):
173173
in DpnpNdArray.
174174
"""
175175

176-
arrtype = (
177-
_parse_empty_like_args(context, builder, sig, llargs)
178-
if is_like
179-
else _parse_empty_args(context, builder, sig, llargs)
180-
)
176+
if is_like:
177+
arrtype = _parse_empty_like_args(context, builder, sig, llargs)
178+
else:
179+
arrtype = _parse_empty_args(context, builder, sig, llargs)
181180
ary = _empty_nd_impl(context, builder, *arrtype)
181+
182182
return ary, arrtype
183183

184184

@@ -406,6 +406,8 @@ def impl_dpnp_empty_like(
406406
ty_x,
407407
ty_dtype,
408408
ty_order,
409+
ty_subok,
410+
ty_shape,
409411
ty_device,
410412
ty_usm_type,
411413
ty_sycl_queue,
@@ -440,6 +442,8 @@ def impl_dpnp_empty_like(
440442
ty_x,
441443
ty_dtype,
442444
ty_order,
445+
ty_subok,
446+
ty_shape,
443447
ty_device,
444448
ty_usm_type,
445449
ty_sycl_queue,
@@ -461,6 +465,8 @@ def impl_dpnp_zeros_like(
461465
ty_x,
462466
ty_dtype,
463467
ty_order,
468+
ty_subok,
469+
ty_shape,
464470
ty_device,
465471
ty_usm_type,
466472
ty_sycl_queue,
@@ -495,6 +501,8 @@ def impl_dpnp_zeros_like(
495501
ty_x,
496502
ty_dtype,
497503
ty_order,
504+
ty_subok,
505+
ty_shape,
498506
ty_device,
499507
ty_usm_type,
500508
ty_sycl_queue,
@@ -516,6 +524,8 @@ def impl_dpnp_ones_like(
516524
ty_x,
517525
ty_dtype,
518526
ty_order,
527+
ty_subok,
528+
ty_shape,
519529
ty_device,
520530
ty_usm_type,
521531
ty_sycl_queue,
@@ -550,6 +560,8 @@ def impl_dpnp_ones_like(
550560
ty_x,
551561
ty_dtype,
552562
ty_order,
563+
ty_subok,
564+
ty_shape,
553565
ty_device,
554566
ty_usm_type,
555567
ty_sycl_queue,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Tests for dpnp ndarray constructors."""
6+
7+
import dpctl
8+
import dpnp
9+
import numpy
10+
import pytest
11+
12+
from numba_dpex import dpjit
13+
14+
shapes = [10, (2, 5)]
15+
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
16+
usm_types = ["device", "shared", "host"]
17+
devices = ["cpu", "unknown"]
18+
19+
20+
@pytest.mark.parametrize("shape", shapes)
21+
@pytest.mark.parametrize("dtype", dtypes)
22+
@pytest.mark.parametrize("usm_type", usm_types)
23+
@pytest.mark.parametrize("device", devices)
24+
def test_dpnp_empty_like(shape, dtype, usm_type, device):
25+
@dpjit
26+
def func1(a):
27+
c = dpnp.empty_like(a, dtype=dtype, usm_type=usm_type, device=device)
28+
return c
29+
30+
if isinstance(shape, int):
31+
NZ = numpy.random.rand(shape)
32+
else:
33+
NZ = numpy.random.rand(*shape)
34+
35+
try:
36+
c = func1(NZ)
37+
except Exception:
38+
pytest.fail("Calling dpnp.empty_like inside dpjit failed")
39+
40+
if len(c.shape) == 1:
41+
assert c.shape[0] == NZ.shape[0]
42+
else:
43+
assert c.shape == NZ.shape
44+
45+
assert c.dtype == dtype
46+
assert c.usm_type == usm_type
47+
if device != "unknown":
48+
assert (
49+
c.sycl_device.filter_string
50+
== dpctl.SyclDevice(device).filter_string
51+
)
52+
else:
53+
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Tests for dpnp ndarray constructors."""
6+
7+
import dpctl
8+
import dpctl.tensor as dpt
9+
import dpnp
10+
import numpy
11+
import pytest
12+
13+
from numba_dpex import dpjit
14+
15+
shapes = [11, (3, 7)]
16+
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
17+
usm_types = ["device", "shared", "host"]
18+
devices = ["cpu", "unknown"]
19+
20+
21+
@pytest.mark.parametrize("shape", shapes)
22+
@pytest.mark.parametrize("dtype", dtypes)
23+
@pytest.mark.parametrize("usm_type", usm_types)
24+
@pytest.mark.parametrize("device", devices)
25+
def test_dpnp_ones(shape, dtype, usm_type, device):
26+
@dpjit
27+
def func1(a):
28+
c = dpnp.ones(a, dtype=dtype, usm_type=usm_type, device=device)
29+
return c
30+
31+
if isinstance(shape, int):
32+
NZ = numpy.random.rand(shape)
33+
else:
34+
NZ = numpy.random.rand(*shape)
35+
36+
try:
37+
c = func1(shape)
38+
except Exception:
39+
pytest.fail("Calling dpnp.empty inside dpjit failed")
40+
41+
if len(c.shape) == 1:
42+
assert c.shape[0] == NZ.shape[0]
43+
else:
44+
assert c.shape == NZ.shape
45+
46+
assert c.dtype == dtype
47+
assert c.usm_type == usm_type
48+
if device != "unknown":
49+
assert (
50+
c.sycl_device.filter_string
51+
== dpctl.SyclDevice(device).filter_string
52+
)
53+
else:
54+
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
55+
56+
assert numpy.array_equal(
57+
dpt.asnumpy(c._array_obj), numpy.ones_like(c._array_obj)
58+
)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Tests for dpnp ndarray constructors."""
6+
7+
import dpctl
8+
import dpctl.tensor as dpt
9+
import dpnp
10+
import numpy
11+
import pytest
12+
13+
from numba_dpex import dpjit
14+
15+
shapes = [11, (3, 7)]
16+
dtypes = [dpnp.int32, dpnp.int64, dpnp.float32, dpnp.float64]
17+
usm_types = ["device", "shared", "host"]
18+
devices = ["cpu", "unknown"]
19+
20+
21+
@pytest.mark.parametrize("shape", shapes)
22+
@pytest.mark.parametrize("dtype", dtypes)
23+
@pytest.mark.parametrize("usm_type", usm_types)
24+
@pytest.mark.parametrize("device", devices)
25+
def test_dpnp_zeros(shape, dtype, usm_type, device):
26+
@dpjit
27+
def func1(a):
28+
c = dpnp.zeros(a, dtype=dtype, usm_type=usm_type, device=device)
29+
return c
30+
31+
if isinstance(shape, int):
32+
NZ = numpy.random.rand(shape)
33+
else:
34+
NZ = numpy.random.rand(*shape)
35+
36+
try:
37+
c = func1(shape)
38+
except Exception:
39+
pytest.fail("Calling dpnp.empty inside dpjit failed")
40+
41+
if len(c.shape) == 1:
42+
assert c.shape[0] == NZ.shape[0]
43+
else:
44+
assert c.shape == NZ.shape
45+
46+
assert c.dtype == dtype
47+
assert c.usm_type == usm_type
48+
if device != "unknown":
49+
assert (
50+
c.sycl_device.filter_string
51+
== dpctl.SyclDevice(device).filter_string
52+
)
53+
else:
54+
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
55+
56+
assert numpy.array_equal(
57+
dpt.asnumpy(c._array_obj), numpy.zeros_like(c._array_obj)
58+
)

0 commit comments

Comments
 (0)