|
2 | 2 |
|
3 | 3 | from ._array_object import Array |
4 | 4 | from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool |
5 | | -from ._flags import requires_data_dependent_shapes, requires_api_version |
| 5 | +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags |
6 | 6 |
|
7 | 7 | from typing import TYPE_CHECKING |
8 | 8 | if TYPE_CHECKING: |
@@ -90,20 +90,38 @@ def searchsorted( |
90 | 90 | # x1 must be 1-D, but NumPy already requires this. |
91 | 91 | return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) |
92 | 92 |
|
93 | | -def where(condition: Array, x1: Array, x2: Array, /) -> Array: |
| 93 | +def where( |
| 94 | + condition: Array, |
| 95 | + x1: bool | int | float | complex | Array, |
| 96 | + x2: bool | int | float | complex | Array, / |
| 97 | +) -> Array: |
94 | 98 | """ |
95 | 99 | Array API compatible wrapper for :py:func:`np.where <numpy.where>`. |
96 | 100 |
|
97 | 101 | See its docstring for more information. |
98 | 102 | """ |
| 103 | + if get_array_api_strict_flags()['api_version'] > '2023.12': |
| 104 | + num_scalars = 0 |
| 105 | + |
| 106 | + if isinstance(x1, (bool, float, complex, int)): |
| 107 | + x1 = Array._new(np.asarray(x1), device=condition.device) |
| 108 | + num_scalars += 1 |
| 109 | + |
| 110 | + if isinstance(x2, (bool, float, complex, int)): |
| 111 | + x2 = Array._new(np.asarray(x2), device=condition.device) |
| 112 | + num_scalars += 1 |
| 113 | + |
| 114 | + if num_scalars == 2: |
| 115 | + raise ValueError("One of x1, x2 arguments must be an array.") |
| 116 | + |
99 | 117 | # Call result type here just to raise on disallowed type combinations |
100 | 118 | _result_type(x1.dtype, x2.dtype) |
101 | 119 |
|
102 | 120 | if condition.dtype != _bool: |
103 | 121 | raise TypeError("`condition` must be have a boolean data type") |
104 | 122 |
|
105 | 123 | if len({a.device for a in (condition, x1, x2)}) > 1: |
106 | | - raise ValueError("where inputs must all be on the same device") |
| 124 | + raise ValueError("Inputs to `where` must all use the same device") |
107 | 125 |
|
108 | 126 | x1, x2 = Array._normalize_two_args(x1, x2) |
109 | 127 | return Array._new(np.where(condition._array, x1._array, x2._array), device=x1.device) |
0 commit comments