From e0a8d7a8a193fe9b7c60f20dc0fcf0bb470ed8b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 1 Feb 2025 15:01:55 +0100 Subject: [PATCH] Require that one of x1,x2 arguments to where is an array --- array_api_strict/_searching_functions.py | 17 ++++++++++++++--- .../tests/test_searching_functions.py | 6 +++++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 046b569..ad32aaa 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -90,18 +90,29 @@ def searchsorted( # x1 must be 1-D, but NumPy already requires this. return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device) -def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array: +def where( + condition: Array, + x1: bool | int | float | complex | Array, + x2: bool | int | float | complex | Array, / +) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. See its docstring for more information. """ if get_array_api_strict_flags()['api_version'] > '2023.12': - if isinstance(x1, (bool, float, int)): + num_scalars = 0 + + if isinstance(x1, (bool, float, complex, int)): x1 = Array._new(np.asarray(x1), device=condition.device) + num_scalars += 1 - if isinstance(x2, (bool, float, int)): + if isinstance(x2, (bool, float, complex, int)): x2 = Array._new(np.asarray(x2), device=condition.device) + num_scalars += 1 + + if num_scalars == 2: + raise ValueError("One of x1, x2 arguments must be an array.") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) diff --git a/array_api_strict/tests/test_searching_functions.py b/array_api_strict/tests/test_searching_functions.py index 8ccd1dd..dfb3fe7 100644 --- a/array_api_strict/tests/test_searching_functions.py +++ b/array_api_strict/tests/test_searching_functions.py @@ -20,7 +20,11 @@ def test_where_with_scalars(): ), ArrayAPIStrictFlags(api_version=draft_version), ): - x_where = xp.where(x == 1, 42, 44) + x_where = xp.where(x == 1, xp.asarray(42), 44) expected = xp.asarray([42, 44, 44, 42]) assert xp.all(x_where == expected) + + # The spec does not allow both x1 and x2 to be scalars + with pytest.raises(ValueError, match="One of"): + xp.where(x == 1, 42, 44)