From 604c86952491dc7acebad0b9d76fff04283d607e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 17:09:06 +0100 Subject: [PATCH 1/7] Add tests --- xarray/tests/test_array_api.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 7940c979249..2adcfb8685a 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -39,17 +39,20 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: +@pytest.mark.parametrize("method", ["max", "min", "mean", "prod", "sum", "std", "var"]) +def test_aggregation(method: str, arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays - expected = np_arr.sum() - actual = xp_arr.sum() + expected = getattr(np_arr, method)() + actual = getattr(xp_arr, method)() assert isinstance(actual.data, Array) assert_equal(actual, expected) -def test_aggregation_skipna(arrays) -> None: +@pytest.mark.parametrize("method", ["max", "min", "mean", "prod", "sum", "std", "var"]) +def test_aggregation_skipna(method: str, arrays) -> None: np_arr, xp_arr = arrays - expected = np_arr.sum(skipna=False) + expected = getattr(np_arr, method)(skipna=False) + actual = getattr(xp_arr, method)(skipna=False) actual = xp_arr.sum(skipna=False) assert isinstance(actual.data, Array) assert_equal(actual, expected) From 03cf79ae51737e4e3463b3fc601ea1a255363c43 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:10:02 +0100 Subject: [PATCH 2/7] move get_array_namespace to nputils --- xarray/core/duck_array_ops.py | 9 +-------- xarray/core/nputils.py | 7 +++++++ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 35239004af4..0ea79039a7b 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -33,20 +33,13 @@ from numpy.lib.stride_tricks import sliding_window_view # noqa from xarray.core import dask_array_ops, dtypes, nputils -from xarray.core.nputils import nanfirst, nanlast +from xarray.core.nputils import nanfirst, nanlast, get_array_namespace from xarray.core.pycompat import array_type, is_duck_dask_array from xarray.core.utils import is_duck_array, module_available dask_available = module_available("dask") -def get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - def _dask_or_eager_func( name, eager_module=np, diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 80c988ebd4f..f29486e3435 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -18,6 +18,13 @@ _USE_BOTTLENECK = False +def get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] From 864ab02484524a84deaf14599860d4d782b22718 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:19:20 +0100 Subject: [PATCH 3/7] move get_array_namespace to utils, --- xarray/core/nputils.py | 8 +------- xarray/core/utils.py | 7 +++++++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index f29486e3435..d9586176883 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] from xarray.core.options import OPTIONS +from xarray.core.utils import get_array_namespace try: import bottleneck as bn @@ -18,13 +19,6 @@ _USE_BOTTLENECK = False -def get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 86c644de5f0..d533bb6e553 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -265,6 +265,13 @@ def is_duck_array(value: Any) -> bool: ) +def get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + def either_dict_or_kwargs( pos_kwargs: Mapping[Any, T] | None, kw_kwargs: Mapping[str, T], From c91559220ccb0c911575143186b845b87f795c4c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:54:47 +0100 Subject: [PATCH 4/7] Update duck_array_ops.py --- xarray/core/duck_array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0ea79039a7b..43b7e1d9ef5 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -33,9 +33,9 @@ from numpy.lib.stride_tricks import sliding_window_view # noqa from xarray.core import dask_array_ops, dtypes, nputils -from xarray.core.nputils import nanfirst, nanlast, get_array_namespace +from xarray.core.nputils import nanfirst, nanlast from xarray.core.pycompat import array_type, is_duck_dask_array -from xarray.core.utils import is_duck_array, module_available +from xarray.core.utils import is_duck_array, module_available, get_array_namespace dask_available = module_available("dask") From 44d291fd0825805cf5bf534df4efae4c79ef072e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:54:55 +0100 Subject: [PATCH 5/7] Update test_array_api.py --- xarray/tests/test_array_api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 7a8b671099e..3c686493a13 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -14,6 +14,8 @@ import numpy.array_api as xp # isort:skip from numpy.array_api._array_object import Array # isort:skip +_STATISTICAL_FUNCTIONS = ("max", "min", "mean", "prod", "sum", "std", "var") + @pytest.fixture def arrays() -> tuple[xr.DataArray, xr.DataArray]: @@ -39,7 +41,7 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -@pytest.mark.parametrize("method", ["max", "min", "mean", "prod", "sum", "std", "var"]) +@pytest.mark.parametrize("method", _STATISTICAL_FUNCTIONS) def test_aggregation(method: str, arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = getattr(np_arr, method)() @@ -48,7 +50,7 @@ def test_aggregation(method: str, arrays: tuple[xr.DataArray, xr.DataArray]) -> assert_equal(actual, expected) -@pytest.mark.parametrize("method", ["max", "min", "mean", "prod", "sum", "std", "var"]) +@pytest.mark.parametrize("method", _STATISTICAL_FUNCTIONS) def test_aggregation_skipna(method: str, arrays) -> None: np_arr, xp_arr = arrays expected = getattr(np_arr, method)(skipna=False) From c8621f8f579cd4da94789dc690f31039e65db9b1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:57:55 +0100 Subject: [PATCH 6/7] use array_namespace instead of only numpy when creating bottleneck method --- xarray/core/nputils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index d9586176883..304235d72ca 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -136,7 +136,7 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def _create_bottleneck_method(name, npmodule=np): +def _create_bottleneck_method(name): def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) @@ -155,7 +155,8 @@ def f(values, axis=None, **kwargs): kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) else: - result = getattr(npmodule, name)(values, axis=axis, **kwargs) + xp = get_array_namespace(values) + result = getattr(xp, name)(values, axis=axis, **kwargs) return result From 0537073957525eb2fd0483952618ff5105f84f20 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Jan 2023 17:59:24 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 43b7e1d9ef5..aac0f8db942 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -35,7 +35,7 @@ from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.nputils import nanfirst, nanlast from xarray.core.pycompat import array_type, is_duck_dask_array -from xarray.core.utils import is_duck_array, module_available, get_array_namespace +from xarray.core.utils import get_array_namespace, is_duck_array, module_available dask_available = module_available("dask")