diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 35239004af4..aac0f8db942 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -35,18 +35,11 @@ from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.nputils import nanfirst, nanlast from xarray.core.pycompat import array_type, is_duck_dask_array -from xarray.core.utils import is_duck_array, module_available +from xarray.core.utils import get_array_namespace, is_duck_array, module_available dask_available = module_available("dask") -def get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - def _dask_or_eager_func( name, eager_module=np, diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 80c988ebd4f..304235d72ca 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] from xarray.core.options import OPTIONS +from xarray.core.utils import get_array_namespace try: import bottleneck as bn @@ -135,7 +136,7 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def _create_bottleneck_method(name, npmodule=np): +def _create_bottleneck_method(name): def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) @@ -154,7 +155,8 @@ def f(values, axis=None, **kwargs): kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) else: - result = getattr(npmodule, name)(values, axis=axis, **kwargs) + xp = get_array_namespace(values) + result = getattr(xp, name)(values, axis=axis, **kwargs) return result diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 86c644de5f0..d533bb6e553 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -265,6 +265,13 @@ def is_duck_array(value: Any) -> bool: ) +def get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + def either_dict_or_kwargs( pos_kwargs: Mapping[Any, T] | None, kw_kwargs: Mapping[str, T], diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index fddaa120970..3c686493a13 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -14,6 +14,8 @@ import numpy.array_api as xp # isort:skip from numpy.array_api._array_object import Array # isort:skip +_STATISTICAL_FUNCTIONS = ("max", "min", "mean", "prod", "sum", "std", "var") + @pytest.fixture def arrays() -> tuple[xr.DataArray, xr.DataArray]: @@ -39,17 +41,20 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) -def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: +@pytest.mark.parametrize("method", _STATISTICAL_FUNCTIONS) +def test_aggregation(method: str, arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays - expected = np_arr.sum() - actual = xp_arr.sum() + expected = getattr(np_arr, method)() + actual = getattr(xp_arr, method)() assert isinstance(actual.data, Array) assert_equal(actual, expected) -def test_aggregation_skipna(arrays) -> None: +@pytest.mark.parametrize("method", _STATISTICAL_FUNCTIONS) +def test_aggregation_skipna(method: str, arrays) -> None: np_arr, xp_arr = arrays - expected = np_arr.sum(skipna=False) + expected = getattr(np_arr, method)(skipna=False) + actual = getattr(xp_arr, method)(skipna=False) actual = xp_arr.sum(skipna=False) assert isinstance(actual.data, Array) assert_equal(actual, expected)