From abcc472ba7c9785894f93bf8e6c6b5a51c2d9006 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 28 Jan 2025 14:47:41 +0100 Subject: [PATCH 01/11] WIP sum --- pyproject.toml | 1 + src/fast_array_utils/stats/__init__.py | 9 +++ src/fast_array_utils/stats/_sum.py | 85 ++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 src/fast_array_utils/stats/__init__.py create mode 100644 src/fast_array_utils/stats/_sum.py diff --git a/pyproject.toml b/pyproject.toml index 67fe21d..d40c592 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ lint.ignore = [ "TID252", # relative imports are fine ] lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs +lint.per-file-ignores."src/**/stats/*.py" = [ "A001" ] # Shadows builtins like `sum` lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions lint.per-file-ignores."tests/**/test_*.py" = [ "D100", # tests need no module docstrings diff --git a/src/fast_array_utils/stats/__init__.py b/src/fast_array_utils/stats/__init__.py new file mode 100644 index 0000000..530c562 --- /dev/null +++ b/src/fast_array_utils/stats/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: MPL-2.0 +"""Statistics utilities.""" + +from __future__ import annotations + +from ._sum import sum + + +__all__ = ["sum"] diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py new file mode 100644 index 0000000..f6943cd --- /dev/null +++ b/src/fast_array_utils/stats/_sum.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from functools import singledispatch +from typing import TYPE_CHECKING, Any, overload + +import numpy as np + +from ..types import DaskArray + + +if TYPE_CHECKING: + from typing import Literal, TypeVar + + from numpy.typing import NDArray + + from ..types import CSBase, CSMatrix + + DT_co = TypeVar("DT_co", covariant=True, bound=np.generic) + + _Ax = tuple[Literal[0, 1], ...] | Literal[0, 1] + + +@overload +def sum( + x: NDArray[DT_co] | CSBase[DT_co], *, axis: _Ax | None = None, dtype: None = None +) -> NDArray[DT_co]: ... + + +@overload +def sum( + x: NDArray[Any] | CSBase[Any], *, axis: _Ax | None = None, dtype: np.dtype[DT_co] +) -> NDArray[DT_co]: ... + + +@singledispatch +def sum( + x: NDArray[DT_co] | CSBase[DT_co], + *, + axis: _Ax | None = None, + dtype: np.dtype[DT_co] | None = None, +) -> NDArray[DT_co]: + return np.sum(np.asarray(x), axis=axis, dtype=dtype) + + +@sum.register(DaskArray) +def _(x: DaskArray, *, axis: _Ax | None = None, dtype: np.dtype[DT_co] | None = None) -> DaskArray: + import dask.array as da + + # TODO(@ilan-gold): why is this so complicated? 
+ # https://github.com/scverse/scanpy/pull/2856/commits/feac6bc7bea69e4cc343a35855307145854a9bc8 + if dtype is None: + dtype = getattr(np.zeros(1, dtype=x.dtype).sum(), "dtype", object) + + if isinstance(x._meta, np.ndarray) and not isinstance(x._meta, np.matrix): + return x.sum(axis=axis, dtype=dtype) + + def sum_drop_keepdims(*args, **kwargs): + kwargs.pop("computing_meta", None) + # masked operations on sparse matrices produce np.matrix results, which have the same API issues handled here + if isinstance(x._meta, CSMatrix | np.matrix) or isinstance(args[0], CSMatrix | np.matrix): + kwargs.pop("keepdims", None) + axis = kwargs["axis"] + if isinstance(axis, tuple): + if len(axis) != 1: + msg = ( + "`sum` can only sum over one axis " + f"when `axis` arg is provided but got {axis} instead" + ) + raise ValueError(msg) + kwargs["axis"] = axis[0] + # returns a np.matrix normally, which is undesirable + return np.array(np.sum(*args, dtype=dtype, **kwargs)) + + def aggregate_sum(*args, **kwargs): + return np.sum(args[0], dtype=dtype, **kwargs) + + return da.reduction( + x, + sum_drop_keepdims, + aggregate_sum, + axis=axis, + dtype=dtype, + meta=np.array([], dtype=dtype), + ) From e1e3572f36a3a7bc975329cc76eab3b496b74c2c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 14 Feb 2025 15:10:18 +0100 Subject: [PATCH 02/11] Works for all but h5py --- .pre-commit-config.yaml | 3 + pyproject.toml | 4 +- src/fast_array_utils/conv/_asarray.py | 4 +- src/fast_array_utils/stats/_sum.py | 109 ++++++------ src/fast_array_utils/types.py | 6 +- src/testing/fast_array_utils/__init__.pyi | 26 +++ src/testing/fast_array_utils/pytest.py | 194 ++++++++++++++++++++++ tests/test_asarray.py | 106 +----------- tests/test_stats.py | 50 ++++++ tests/test_test_utils.py | 31 ++++ 10 files changed, 376 insertions(+), 157 deletions(-) create mode 100644 src/testing/fast_array_utils/__init__.pyi create mode 100644 src/testing/fast_array_utils/pytest.py create mode 100644 tests/test_stats.py create mode 100644 tests/test_test_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 651940f..11f3909 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,5 +32,8 @@ repos: - numpy - scipy - types-scipy-sparse + - dask + - zarr + - h5py ci: skip: [mypy] # too big diff --git a/pyproject.toml b/pyproject.toml index d40c592..0f6af6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ extras = [ "min", "full" ] [tool.ruff] line-length = 100 +namespace-packages = [ "src/testing" ] lint.select = [ "ALL" ] lint.ignore = [ "A005", # submodules never shadow builtins. 
@@ -74,7 +75,7 @@ lint.ignore = [ "TID252", # relative imports are fine ] lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs -lint.per-file-ignores."src/**/stats/*.py" = [ "A001" ] # Shadows builtins like `sum` +lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum` lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions lint.per-file-ignores."tests/**/test_*.py" = [ "D100", # tests need no module docstrings @@ -96,6 +97,7 @@ addopts = [ "--import-mode=importlib", "--strict-markers", "--pyargs", + "-ptesting.fast_array_utils.pytest", ] filterwarnings = [ "error", diff --git a/src/fast_array_utils/conv/_asarray.py b/src/fast_array_utils/conv/_asarray.py index 1c121bd..9b25669 100644 --- a/src/fast_array_utils/conv/_asarray.py +++ b/src/fast_array_utils/conv/_asarray.py @@ -46,8 +46,8 @@ def _(x: CSBase[DT_co]) -> NDArray[DT_co]: @asarray.register(DaskArray) -def _(x: DaskArray[DT_co]) -> NDArray[DT_co]: - return asarray(x.compute()) +def _(x: DaskArray) -> NDArray[DT_co]: + return asarray(x.compute()) # type: ignore[no-untyped-call] @asarray.register(OutOfCoreDataset) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index f6943cd..5a78efe 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -1,84 +1,87 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from functools import singledispatch -from typing import TYPE_CHECKING, Any, overload +from functools import partial, singledispatch +from typing import TYPE_CHECKING import numpy as np -from ..types import DaskArray +from ..types import CSBase, CSMatrix, DaskArray if TYPE_CHECKING: from typing import Literal, TypeVar - from numpy.typing import NDArray - - from ..types import CSBase, CSMatrix + from numpy.typing import ArrayLike, DTypeLike, NDArray DT_co = TypeVar("DT_co", covariant=True, bound=np.generic) - _Ax = tuple[Literal[0, 1], ...] | Literal[0, 1] - - -@overload -def sum( - x: NDArray[DT_co] | CSBase[DT_co], *, axis: _Ax | None = None, dtype: None = None -) -> NDArray[DT_co]: ... - -@overload -def sum( - x: NDArray[Any] | CSBase[Any], *, axis: _Ax | None = None, dtype: np.dtype[DT_co] -) -> NDArray[DT_co]: ... 
+# TODO(flying-sheep): overload so axis=None returns np.floating # noqa: TD003 @singledispatch def sum( - x: NDArray[DT_co] | CSBase[DT_co], + x: ArrayLike, *, - axis: _Ax | None = None, - dtype: np.dtype[DT_co] | None = None, + axis: Literal[0, 1, None] = None, + dtype: DTypeLike | np.dtype[DT_co] | None = None, ) -> NDArray[DT_co]: - return np.sum(np.asarray(x), axis=axis, dtype=dtype) + return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] + + +@sum.register(CSBase) # type: ignore[misc,call-overload] +def _( + x: CSBase[DT_co], *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None +) -> NDArray[DT_co]: + import scipy.sparse as sp + + if isinstance(x, CSMatrix): + x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) + return np.sum(x, axis=axis, dtype=dtype) # type: ignore[call-overload,no-any-return] @sum.register(DaskArray) -def _(x: DaskArray, *, axis: _Ax | None = None, dtype: np.dtype[DT_co] | None = None) -> DaskArray: - import dask.array as da +def _( + x: DaskArray, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None +) -> DaskArray: + if TYPE_CHECKING: + from dask.array.reductions import reduction + else: + from dask.array import reduction + + if isinstance(x._meta, np.matrix): # noqa: SLF001 + msg = "sum does not support numpy matrices" + raise TypeError(msg) + + def sum_drop_keepdims( + a: NDArray[DT_co] | CSBase[DT_co], + *, + axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None, + dtype: np.dtype[DT_co] | None = None, + keepdims: bool = False, + ) -> NDArray[DT_co]: + del keepdims + match axis: + case (0 | 1 as n,): + axis = n + case (0, 1) | (1, 0): + axis = None + case tuple(): + msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead" + raise ValueError(msg) + rv: NDArray[DT_co] | DT_co = sum(a, axis=axis, dtype=dtype) # type: ignore[arg-type] + rv = np.array(rv, ndmin=1) # make sure rv is at least 1D + return rv.reshape((1, len(rv))) - # TODO(@ilan-gold): why is this so complicated? - # https://github.com/scverse/scanpy/pull/2856/commits/feac6bc7bea69e4cc343a35855307145854a9bc8 if dtype is None: - dtype = getattr(np.zeros(1, dtype=x.dtype).sum(), "dtype", object) - - if isinstance(x._meta, np.ndarray) and not isinstance(x._meta, np.matrix): - return x.sum(axis=axis, dtype=dtype) - - def sum_drop_keepdims(*args, **kwargs): - kwargs.pop("computing_meta", None) - # masked operations on sparse matrices produce np.matrix results, which have the same API issues handled here - if isinstance(x._meta, CSMatrix | np.matrix) or isinstance(args[0], CSMatrix | np.matrix): - kwargs.pop("keepdims", None) - axis = kwargs["axis"] - if isinstance(axis, tuple): - if len(axis) != 1: - msg = ( - "`sum` can only sum over one axis " - f"when `axis` arg is provided but got {axis} instead" - ) - raise ValueError(msg) - kwargs["axis"] = axis[0] - # returns a np.matrix normally, which is undesirable - return np.array(np.sum(*args, dtype=dtype, **kwargs)) - - def aggregate_sum(*args, **kwargs): - return np.sum(args[0], dtype=dtype, **kwargs) - - return da.reduction( + # Explicitly use numpy result dtype (e.g. 
`NDArray[bool].sum().dtype == int64`) + dtype = np.zeros(1, dtype=x.dtype).sum().dtype + + return reduction( # type: ignore[no-any-return,no-untyped-call] x, sum_drop_keepdims, - aggregate_sum, + partial(np.sum, dtype=dtype), axis=axis, dtype=dtype, meta=np.array([], dtype=dtype), diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index c1fe15a..0b82d4b 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -65,7 +65,9 @@ CupySparseMatrix = type("spmatrix", (), {}) -if find_spec("dask") or TYPE_CHECKING: +if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853 + from dask.array.core import Array as DaskArray +elif find_spec("dask"): from dask.array import Array as DaskArray else: DaskArray = type("array", (), {}) @@ -80,7 +82,7 @@ if find_spec("zarr") or TYPE_CHECKING: from zarr import Array as ZarrArray else: - ZarrArray = type("Array", (), {}) + ZarrArray = type("Array", (), {}) # type: ignore[misc] @runtime_checkable diff --git a/src/testing/fast_array_utils/__init__.pyi b/src/testing/fast_array_utils/__init__.pyi new file mode 100644 index 0000000..f54511a --- /dev/null +++ b/src/testing/fast_array_utils/__init__.pyi @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: MPL-2.0 +from typing import Generic, Protocol, TypeAlias, TypeVar + +import numpy as np +from numpy.typing import ArrayLike, NDArray + +from fast_array_utils import types + +_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) +_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic) + +_Array: TypeAlias = ( + NDArray[_SCT_co] + | types.CSBase[_SCT_co] + | types.CupyArray[_SCT_co] + | types.DaskArray + | types.H5Dataset + | types.ZarrArray +) + +class _ToArray(Protocol, Generic[_SCT_contra]): + def __call__( + self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None + ) -> _Array[_SCT_contra]: ... + +__all__ = ["_Array", "_ToArray"] diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py new file mode 100644 index 0000000..ec55e15 --- /dev/null +++ b/src/testing/fast_array_utils/pytest.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: MPL-2.0 +"""Testing utilities.""" + +from __future__ import annotations + +import os +import re +from importlib.util import find_spec +from typing import TYPE_CHECKING, cast + +import numpy as np +import pytest + +from fast_array_utils import types + + +if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, TypeVar + + from numpy.typing import ArrayLike, DTypeLike + + from testing.fast_array_utils import _ToArray + + from . 
import _Array + + _SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) + + +def _skip_if_no(dist: str) -> pytest.MarkDecorator: + return pytest.mark.skipif(not find_spec(dist), reason=f"{dist} not installed") + + +@pytest.fixture( + scope="session", + params=[ + pytest.param("numpy.ndarray"), + pytest.param("scipy.sparse.csr_array", marks=_skip_if_no("scipy")), + pytest.param("scipy.sparse.csc_array", marks=_skip_if_no("scipy")), + pytest.param("scipy.sparse.csr_matrix", marks=_skip_if_no("scipy")), + pytest.param("scipy.sparse.csc_matrix", marks=_skip_if_no("scipy")), + pytest.param("dask.array.Array[numpy.ndarray]", marks=_skip_if_no("dask")), + pytest.param("dask.array.Array[scipy.sparse.csr_array]"), + pytest.param("dask.array.Array[scipy.sparse.csc_array]"), + pytest.param("dask.array.Array[scipy.sparse.csr_matrix]", marks=_skip_if_no("dask")), + pytest.param("dask.array.Array[scipy.sparse.csc_matrix]", marks=_skip_if_no("dask")), + pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")), + pytest.param("zarr.Array", marks=_skip_if_no("zarr")), + pytest.param("cupy.ndarray", marks=_skip_if_no("cupy")), + pytest.param("cupyx.scipy.sparse.csr_matrix", marks=_skip_if_no("cupy")), + pytest.param("cupyx.scipy.sparse.csc_matrix", marks=_skip_if_no("cupy")), + ], +) +def array_cls_name(request: pytest.FixtureRequest) -> str: + """Fixture for a supported array class.""" + return cast(str, request.param) + + +@pytest.fixture(scope="session") +def array_cls(array_cls_name: str) -> type[_Array[Any]]: + """Fixture for a supported array class.""" + return get_array_cls(array_cls_name) + + +def get_array_cls(qualname: str) -> type[_Array[Any]]: # noqa: PLR0911 + """Get a supported array class by qualname.""" + m = re.fullmatch( + r"(?P(?:\w+\.)*\w+)\.(?P[^\[]+)(?:\[(?P[\w.]+)\])?", qualname + ) + assert m + match m["mod"], m["name"], m["inner"]: + case "numpy", "ndarray", None: + return np.ndarray + case "scipy.sparse", ( + "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix" + ) as cls_name, None: + import scipy.sparse + + return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return] + case "cupy", "ndarray", None: + import cupy as cp + + return cp.ndarray # type: ignore[no-any-return] + case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None: + import cupyx.scipy.sparse as cu_sparse + + return getattr(cu_sparse, cls_name) # type: ignore[no-any-return] + case "dask.array", cls_name, _: + if TYPE_CHECKING: + from dask.array.core import Array as DaskArray + else: + from dask.array import Array as DaskArray + + return DaskArray + case "h5py", "Dataset", _: + import h5py + + return h5py.Dataset # type: ignore[no-any-return] + case "zarr", "Array", _: + import zarr + + return zarr.Array + case _: + pytest.fail(f"Unknown array class: {qualname}") + + +@pytest.fixture(scope="session") +def to_array( + request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str +) -> _ToArray[_SCT_co]: + """Fixture for conversion into a supported array.""" + return get_to_array(array_cls, array_cls_name, request) + + +def get_to_array( + array_cls: type[_Array[_SCT_co]], + array_cls_name: str | None = None, + request: pytest.FixtureRequest | None = None, +) -> _ToArray[_SCT_co]: + """Create a function to convert to a supported array.""" + if array_cls is np.ndarray: + return np.asarray # type: ignore[return-value] + if array_cls is types.DaskArray: + assert array_cls_name is not None + return to_dask_array(array_cls_name) + if array_cls is types.H5Dataset: + 
assert request is not None + return request.getfixturevalue("to_h5py_dataset") # type: ignore[no-any-return] + if array_cls is types.ZarrArray: + return to_zarr_array + if array_cls is types.CupyArray: + import cupy as cu + + return cu.asarray # type: ignore[no-any-return] + + return array_cls # type: ignore[return-value] + + +def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]: + def half_rounded_up(x: int) -> int: + div, mod = divmod(x, 2) + return div + (mod > 0) + + return tuple(half_rounded_up(x) for x in a) + + +def to_dask_array(array_cls_name: str) -> _ToArray[Any]: + """Convert to a dask array.""" + if TYPE_CHECKING: + import dask.array.core as da + else: + import dask.array as da + + inner_cls_name = array_cls_name.removeprefix("dask.array.Array[").removesuffix("]") + inner_cls = get_array_cls(inner_cls_name) + to_array_fn: _ToArray[Any] = get_to_array(array_cls=inner_cls) + + def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray: + x = np.asarray(x, dtype=dtype) + return da.from_array(to_array_fn(x), _half_chunk_size(x.shape)) # type: ignore[no-untyped-call,no-any-return] + + return to_dask_array + + +@pytest.fixture(scope="session") +# worker_id for xdist since we don't want to override open files +def to_h5py_dataset( + tmp_path_factory: pytest.TempPathFactory, + worker_id: str = "serial", +) -> Generator[_ToArray[Any], None, None]: + """Convert to a h5py dataset.""" + import h5py + + tmp_path = tmp_path_factory.mktemp("backed_adata") + tmp_path = tmp_path / f"test_{worker_id}.h5ad" + + with h5py.File(tmp_path, "x") as f: + + def to_h5py_dataset(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset: + arr = np.asarray(x, dtype=dtype) + test_name = os.environ["PYTEST_CURRENT_TEST"].rsplit(":", 1)[-1].split(" ", 1)[0] + return f.create_dataset(test_name, arr.shape, arr.dtype) + + yield to_h5py_dataset + + +def to_zarr_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.ZarrArray: + """Convert to a zarr array.""" + import zarr + + arr = np.asarray(x, dtype=dtype) + za = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype) + za[...] 
= arr + return za diff --git a/tests/test_asarray.py b/tests/test_asarray.py index f9863fc..1c855ef 100644 --- a/tests/test_asarray.py +++ b/tests/test_asarray.py @@ -1,115 +1,23 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations -from importlib.util import find_spec -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import numpy as np -import pytest from fast_array_utils.conv import asarray if TYPE_CHECKING: - from collections.abc import Callable, Generator - from typing import Any, TypeAlias + from typing import Any - from numpy.typing import ArrayLike, NDArray + from numpy.typing import NDArray - from fast_array_utils import types + from testing.fast_array_utils import _ToArray - SupportedArray: TypeAlias = ( - NDArray[Any] - | types.DaskArray - | types.H5Dataset - | types.ZarrArray - | types.CupyArray - | types.CupySparseMatrix - ) - -def skip_if_no(dist: str) -> pytest.MarkDecorator: - return pytest.mark.skipif(not find_spec(dist), reason=f"{dist} not installed") - - -@pytest.fixture(scope="session") -# worker_id for xdist since we don't want to override open files -def to_h5py_dataset( - tmp_path_factory: pytest.TempPathFactory, worker_id: str = "serial" -) -> Generator[Callable[[ArrayLike], types.H5Dataset], None, None]: - import h5py - - tmp_path = tmp_path_factory.mktemp("backed_adata") - tmp_path = tmp_path / f"test_{worker_id}.h5ad" - - with h5py.File(tmp_path, "x") as f: - - def to_h5py_dataset(x: ArrayLike) -> types.H5Dataset: - arr = np.asarray(x) - return f.create_dataset("data", arr.shape, arr.dtype) - - yield to_h5py_dataset - - -def to_zarr_array(x: ArrayLike) -> types.ZarrArray: - import zarr - - arr = np.asarray(x) - za = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype) - za[...] 
= arr - return za - - -@pytest.fixture( - scope="session", - params=[ - pytest.param("numpy.ndarray"), - pytest.param("scipy.sparse.csr_array", marks=skip_if_no("scipy")), - pytest.param("scipy.sparse.csc_array", marks=skip_if_no("scipy")), - pytest.param("scipy.sparse.csr_matrix", marks=skip_if_no("scipy")), - pytest.param("scipy.sparse.csc_matrix", marks=skip_if_no("scipy")), - pytest.param("dask.array.Array", marks=skip_if_no("dask")), - pytest.param("h5py.Dataset", marks=skip_if_no("h5py")), - pytest.param("zarr.Array", marks=skip_if_no("zarr")), - pytest.param("cupy.ndarray", marks=skip_if_no("cupy")), - pytest.param("cupyx.scipy.sparse.csr_matrix", marks=skip_if_no("cupy")), - pytest.param("cupyx.scipy.sparse.csc_matrix", marks=skip_if_no("cupy")), - ], -) -def array_cls( # noqa: PLR0911 - request: pytest.FixtureRequest, -) -> Callable[[ArrayLike], SupportedArray]: - qualname = cast(str, request.param) - match qualname.split("."): - case "numpy", "ndarray": - return np.asarray - case "scipy", "sparse", ("csr_array" | "csc_array" | "csr_matrix" | "csc_matrix") as n: - import scipy.sparse - - return getattr(scipy.sparse, n) # type: ignore[no-any-return] - case "dask", "array", "Array": - import dask.array as da - - return da.asarray # type: ignore[no-any-return] - case "h5py", "Dataset": - return request.getfixturevalue("to_h5py_dataset") # type: ignore[no-any-return] - case "zarr", "Array": - return to_zarr_array - case "cupy", "ndarray": - import cupy - - return cupy.asarray # type: ignore[no-any-return] - case "cupyx", "scipy", "sparse", ("csr_matrix" | "csc_matrix") as n: - import cupyx.scipy.sparse - - return getattr(cupyx.scipy.sparse, n) # type: ignore[no-any-return] - case _: - msg = f"Unknown array type: {qualname}" - raise AssertionError(msg) - - -def test_asarray(array_cls: Callable[[ArrayLike], SupportedArray]) -> None: - x = array_cls([[1, 2, 3], [4, 5, 6]]) - arr: NDArray[Any] = asarray(x) +def test_asarray(to_array: _ToArray[Any]) -> None: + x = to_array([[1, 2, 3], [4, 5, 6]]) + arr: NDArray[Any] = asarray(x) # type: ignore[arg-type] assert isinstance(arr, np.ndarray) assert arr.shape == (2, 3) diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 0000000..368a69b --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +import numpy as np +import pytest +from scipy.sparse import sparray, spmatrix + +from fast_array_utils import stats, types + + +if TYPE_CHECKING: + from typing import Any, Literal + + from testing.fast_array_utils import _Array, _ToArray + + +DType_float = TypeVar("DType_float", np.float32, np.float64) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("axis", [0, 1, None]) +def test_sum( + array_cls: type[_Array[Any]], + to_array: _ToArray[Any], + dtype: DType_float, + axis: Literal[0, 1, None], +) -> None: + np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + arr = to_array(np_arr.copy()) + + sum_: _Array[Any] | np.floating = stats.sum(arr, axis=axis) # type: ignore[type-arg,arg-type] + match axis, arr: + case _, types.DaskArray(): + assert isinstance(sum_, types.DaskArray), type(sum_) + sum_ = sum_.compute() # type: ignore[no-untyped-call] + case None, _: + assert isinstance(sum_, np.floating), type(sum_) + case 0 | 1, spmatrix() | sparray() | types.ZarrArray(): + assert isinstance(sum_, np.ndarray), type(sum_) + case 0 | 1, _: + assert isinstance(sum_, 
array_cls), type(sum_) + case _: + pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(sum_)}") + + assert sum_.shape == () if axis is None else arr.shape[axis] + assert sum_.dtype == dtype + + np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis)) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py new file mode 100644 index 0000000..7c799c0 --- /dev/null +++ b/tests/test_test_utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from fast_array_utils import types + + +if TYPE_CHECKING: + from typing import TypeVar + + from testing.fast_array_utils import _Array, _ToArray + + DType_float = TypeVar("DType_float", np.float32, np.float64) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_conv( + array_cls: type[_Array[DType_float]], to_array: _ToArray[DType_float], dtype: DType_float +) -> None: + arr = to_array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + assert isinstance(arr, array_cls) + if isinstance(arr, types.DaskArray): + arr = arr.compute() # type: ignore[no-untyped-call] + elif isinstance(arr, types.CupyArray): + arr = arr.get() + assert arr.shape == (2, 3) + assert arr.dtype == dtype From 8444b8cc4a58a9a588908960763930d4acc1b7c0 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 14 Feb 2025 15:15:18 +0100 Subject: [PATCH 03/11] fix dataset --- src/testing/fast_array_utils/pytest.py | 2 +- tests/test_stats.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index ec55e15..31fef91 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -179,7 +179,7 @@ def to_h5py_dataset( def to_h5py_dataset(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset: arr = np.asarray(x, dtype=dtype) test_name = os.environ["PYTEST_CURRENT_TEST"].rsplit(":", 1)[-1].split(" ", 1)[0] - return f.create_dataset(test_name, arr.shape, arr.dtype) + return f.create_dataset(test_name, arr.shape, arr.dtype, data=arr) yield to_h5py_dataset diff --git a/tests/test_stats.py b/tests/test_stats.py index 368a69b..790c10b 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -37,7 +37,7 @@ def test_sum( sum_ = sum_.compute() # type: ignore[no-untyped-call] case None, _: assert isinstance(sum_, np.floating), type(sum_) - case 0 | 1, spmatrix() | sparray() | types.ZarrArray(): + case 0 | 1, spmatrix() | sparray() | types.ZarrArray() | types.H5Dataset(): assert isinstance(sum_, np.ndarray), type(sum_) case 0 | 1, _: assert isinstance(sum_, array_cls), type(sum_) From a00edbfba988385e2bb37ce11b813ab2f2939cbd Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 14 Feb 2025 15:58:05 +0100 Subject: [PATCH 04/11] patch dask for sparrays --- src/fast_array_utils/__init__.py | 4 +++- src/fast_array_utils/_patches.py | 26 ++++++++++++++++++++++++++ tests/test_test_utils.py | 4 ++-- 3 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 src/fast_array_utils/_patches.py diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py index 8d55e48..c203223 100644 --- a/src/fast_array_utils/__init__.py +++ b/src/fast_array_utils/__init__.py @@ -3,7 +3,9 @@ from __future__ import annotations -from . import conv, types +from . 
import _patches, conv, types __all__ = ["conv", "types"] + +_patches.patch_dask() diff --git a/src/fast_array_utils/_patches.py b/src/fast_array_utils/_patches.py new file mode 100644 index 0000000..19b6ceb --- /dev/null +++ b/src/fast_array_utils/_patches.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: MPL-2.0 +from __future__ import annotations + +import numpy as np + + +# TODO(flying-sheep): upstream +# https://github.com/dask/dask/issues/11749 +def patch_dask() -> None: + """Patch dask to support sparse arrays. + + See + """ + try: + from dask.array import dispatch + from scipy.sparse import sparray, spmatrix + except ImportError: + return + + if dispatch.concatenate_lookup.dispatch(sparray) is not np.concatenate: # type: ignore[no-untyped-call] + return + + concatenate = dispatch.concatenate_lookup.dispatch(spmatrix) # type: ignore[no-untyped-call] + dispatch.concatenate_lookup.register(sparray, concatenate) # type: ignore[no-untyped-call] + + # Other lookup candidates: tensordot_lookup and take_lookup diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 7c799c0..74ceab8 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -21,11 +21,11 @@ def test_conv( array_cls: type[_Array[DType_float]], to_array: _ToArray[DType_float], dtype: DType_float ) -> None: - arr = to_array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + arr = to_array(np.arange(12).reshape(3, 4), dtype=dtype) assert isinstance(arr, array_cls) if isinstance(arr, types.DaskArray): arr = arr.compute() # type: ignore[no-untyped-call] elif isinstance(arr, types.CupyArray): arr = arr.get() - assert arr.shape == (2, 3) + assert arr.shape == (3, 4) assert arr.dtype == dtype From 65b1529fc6f27136c5b421e557eee9ca7f089a28 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 14 Feb 2025 16:01:06 +0100 Subject: [PATCH 05/11] better docs --- src/fast_array_utils/_patches.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/fast_array_utils/_patches.py b/src/fast_array_utils/_patches.py index 19b6ceb..1e7fc71 100644 --- a/src/fast_array_utils/_patches.py +++ b/src/fast_array_utils/_patches.py @@ -12,15 +12,15 @@ def patch_dask() -> None: See """ try: - from dask.array import dispatch + # Other lookup candidates: tensordot_lookup and take_lookup + from dask.array.dispatch import concatenate_lookup from scipy.sparse import sparray, spmatrix except ImportError: - return + return # No need to patch if dask or scipy is not installed - if dispatch.concatenate_lookup.dispatch(sparray) is not np.concatenate: # type: ignore[no-untyped-call] + # Avoid patch if already patched or upstream support has been added + if concatenate_lookup.dispatch(sparray) is not np.concatenate: # type: ignore[no-untyped-call] return - concatenate = dispatch.concatenate_lookup.dispatch(spmatrix) # type: ignore[no-untyped-call] - dispatch.concatenate_lookup.register(sparray, concatenate) # type: ignore[no-untyped-call] - - # Other lookup candidates: tensordot_lookup and take_lookup + concatenate = concatenate_lookup.dispatch(spmatrix) # type: ignore[no-untyped-call] + concatenate_lookup.register(sparray, concatenate) # type: ignore[no-untyped-call] From 9fc94e1c6d8a500e07e991b173f51403b5cfa9a3 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 14 Feb 2025 16:57:50 +0100 Subject: [PATCH 06/11] fix min tests --- src/fast_array_utils/types.py | 10 +++++----- src/testing/fast_array_utils/pytest.py | 4 ++-- tests/test_stats.py | 8 +++++++- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 0b82d4b..3898f56 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -53,13 +53,13 @@ CSBase = CSMatrix | CSArray -if find_spec("cupy") or TYPE_CHECKING: +if TYPE_CHECKING or find_spec("cupy"): from cupy import ndarray as CupyArray else: CupyArray = type("ndarray", (), {}) -if find_spec("cupyx") or TYPE_CHECKING: +if TYPE_CHECKING or find_spec("cupyx"): from cupyx.scipy.sparse import spmatrix as CupySparseMatrix else: CupySparseMatrix = type("spmatrix", (), {}) @@ -73,16 +73,16 @@ DaskArray = type("array", (), {}) -if find_spec("h5py") or TYPE_CHECKING: +if TYPE_CHECKING or find_spec("h5py"): from h5py import Dataset as H5Dataset else: H5Dataset = type("Dataset", (), {}) -if find_spec("zarr") or TYPE_CHECKING: +if TYPE_CHECKING or find_spec("zarr"): from zarr import Array as ZarrArray else: - ZarrArray = type("Array", (), {}) # type: ignore[misc] + ZarrArray = type("Array", (), {}) @runtime_checkable diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 31fef91..9720d87 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -40,8 +40,8 @@ def _skip_if_no(dist: str) -> pytest.MarkDecorator: pytest.param("scipy.sparse.csr_matrix", marks=_skip_if_no("scipy")), pytest.param("scipy.sparse.csc_matrix", marks=_skip_if_no("scipy")), pytest.param("dask.array.Array[numpy.ndarray]", marks=_skip_if_no("dask")), - pytest.param("dask.array.Array[scipy.sparse.csr_array]"), - pytest.param("dask.array.Array[scipy.sparse.csc_array]"), + pytest.param("dask.array.Array[scipy.sparse.csr_array]", marks=_skip_if_no("dask")), + pytest.param("dask.array.Array[scipy.sparse.csc_array]", marks=_skip_if_no("dask")), pytest.param("dask.array.Array[scipy.sparse.csr_matrix]", marks=_skip_if_no("dask")), pytest.param("dask.array.Array[scipy.sparse.csc_matrix]", marks=_skip_if_no("dask")), pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")), diff --git a/tests/test_stats.py b/tests/test_stats.py index 790c10b..7816621 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -1,11 +1,17 @@ # SPDX-License-Identifier: MPL-2.0 from __future__ import annotations +from importlib.util import find_spec from typing import TYPE_CHECKING, TypeVar import numpy as np import pytest -from scipy.sparse import sparray, spmatrix + + +if TYPE_CHECKING or find_spec("scipy"): + from scipy.sparse import sparray, spmatrix +else: + spmatrix = sparray = type("spmatrix", (), {}) from fast_array_utils import stats, types From ad074f34f61d912c2252fafb1a0a80650ef1ff2d Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 14 Feb 2025 17:39:11 +0100 Subject: [PATCH 07/11] test more dtypes --- tests/test_stats.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 7816621..23a4ebb 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -2,7 +2,7 @@ from __future__ import annotations from importlib.util import find_spec -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING import numpy as np import pytest @@ -22,27 +22,28 @@ from testing.fast_array_utils import _Array, _ToArray -DType_float = TypeVar("DType_float", np.float32, np.float64) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dtype_in", [np.float32, np.float64, np.int32, np.bool_]) +@pytest.mark.parametrize("dtype_arg", [np.float32, np.float64, None]) @pytest.mark.parametrize("axis", [0, 1, None]) def test_sum( array_cls: type[_Array[Any]], to_array: _ToArray[Any], - dtype: DType_float, + dtype_in: type[np.generic], + dtype_arg: type[np.generic] | None, axis: Literal[0, 1, None], ) -> None: - np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) + np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) arr = to_array(np_arr.copy()) + assert arr.dtype == dtype_in + + sum_: _Array[Any] | np.floating = stats.sum(arr, axis=axis, dtype=dtype_arg) # type: ignore[type-arg,arg-type] - sum_: _Array[Any] | np.floating = stats.sum(arr, axis=axis) # type: ignore[type-arg,arg-type] match axis, arr: case _, types.DaskArray(): assert isinstance(sum_, types.DaskArray), type(sum_) sum_ = sum_.compute() # type: ignore[no-untyped-call] case None, _: - assert isinstance(sum_, np.floating), type(sum_) + assert isinstance(sum_, np.floating | np.integer), type(sum_) case 0 | 1, spmatrix() | sparray() | types.ZarrArray() | types.H5Dataset(): assert isinstance(sum_, np.ndarray), type(sum_) case 0 | 1, _: @@ -50,7 +51,13 @@ def test_sum( case _: pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(sum_)}") - assert sum_.shape == () if axis is None else arr.shape[axis] - assert sum_.dtype == dtype + assert sum_.shape == () if axis is None else arr.shape[axis], (sum_.shape, arr.shape) + + if dtype_arg is not None: + assert sum_.dtype == dtype_arg, (sum_.dtype, dtype_arg) + elif dtype_in in {np.bool_, np.int32}: + assert sum_.dtype == np.int64 + else: + assert sum_.dtype == dtype_in - np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis)) + np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis, dtype=dtype_arg)) # type: ignore[arg-type] From 0fbb87dd04bcce3b6861f88714d56cc48f3eca05 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 14 Feb 2025 17:47:47 +0100 Subject: [PATCH 08/11] docs --- docs/conf.py | 7 +++++-- docs/index.rst | 7 +++++++ src/fast_array_utils/__init__.py | 4 ++-- src/fast_array_utils/stats/_sum.py | 1 + 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f4c4764..80d7f70 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -51,20 +51,23 @@ ) # Try overriding type paths qualname_overrides = autodoc_type_aliases = { + "np.dtype": "numpy.dtype", "ArrayLike": "numpy.typing.ArrayLike", + "DTypeLike": "numpy.typing.DTypeLike", + "NDArray": "numpy.typing.NDArray", "CSBase": "scipy.sparse.spmatrix", "CupyArray": "cupy.ndarray", "CupySparseMatrix": "cupyx.scipy.sparse.spmatrix", "DaskArray": "dask.array.Array", "H5Dataset": "h5py.Dataset", - "NDArray": "numpy.typing.NDArray", } # If that doesn’t work, ignore them nitpick_ignore = { ("py:class", "DT_co"), ("py:class", "fast_array_utils.types.T_co"), - # sphinx bugs, should be covered by `autodoc_type_aliases` below + # sphinx bugs, should be covered by `autodoc_type_aliases` above ("py:class", "ArrayLike"), + ("py:class", "DTypeLike"), ("py:class", "NDArray"), } diff --git a/docs/index.rst b/docs/index.rst index 57c5790..697b5c6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,13 @@ :members: +``fast_array_utils.stats`` +-------------------------- + +.. automodule:: fast_array_utils.stats + :members: + + ``fast_array_utils.types`` -------------------------- diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py index c203223..135e10e 100644 --- a/src/fast_array_utils/__init__.py +++ b/src/fast_array_utils/__init__.py @@ -3,9 +3,9 @@ from __future__ import annotations -from . import _patches, conv, types +from . import _patches, conv, stats, types -__all__ = ["conv", "types"] +__all__ = ["conv", "stats", "types"] _patches.patch_dask() diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 5a78efe..e4aea57 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -27,6 +27,7 @@ def sum( axis: Literal[0, 1, None] = None, dtype: DTypeLike | np.dtype[DT_co] | None = None, ) -> NDArray[DT_co]: + """Sum over both or one axis.""" return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return] From 9ae63aeb66058bb5c6d3cb2a9282721f74e87cba Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 18 Feb 2025 13:01:31 +0100 Subject: [PATCH 09/11] add benchmarks --- src/testing/fast_array_utils/__init__.py | 138 ++++++++++++++++++++++ src/testing/fast_array_utils/__init__.pyi | 26 ---- src/testing/fast_array_utils/pytest.py | 65 ++-------- tests/test_asarray.py | 4 +- tests/test_sparse.py | 23 +--- tests/test_stats.py | 55 +++++++-- tests/test_test_utils.py | 4 +- 7 files changed, 202 insertions(+), 113 deletions(-) create mode 100644 src/testing/fast_array_utils/__init__.py delete mode 100644 src/testing/fast_array_utils/__init__.pyi diff --git a/src/testing/fast_array_utils/__init__.py b/src/testing/fast_array_utils/__init__.py new file mode 100644 index 0000000..90d9adf --- /dev/null +++ b/src/testing/fast_array_utils/__init__.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: MPL-2.0 +"""Testing utilities.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import numpy as np + + +if TYPE_CHECKING: + from typing import Any, Generic, Literal, Protocol, SupportsFloat, TypeAlias, TypeVar + + from numpy.typing import ArrayLike, NDArray + + from fast_array_utils import types + from fast_array_utils.types import CSBase + + _SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) + _SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic) + _SCT_float = TypeVar("_SCT_float", np.float32, np.float64) + + Array: TypeAlias = ( + NDArray[_SCT_co] + | types.CSBase[_SCT_co] + | types.CupyArray[_SCT_co] + | types.DaskArray + | types.H5Dataset + | types.ZarrArray + ) + + class ToArray(Protocol, Generic[_SCT_contra]): + """Convert to a supported array.""" + + def __call__( # noqa: D102 + self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None + ) -> Array[_SCT_contra]: ... + + +RE_ARRAY_QUAL = re.compile(r"(?P(?:\w+\.)*\w+)\.(?P[^\[]+)(?:\[(?P[\w.]+)\])?") + + +def get_array_cls(qualname: str) -> type[Array[Any]]: # noqa: PLR0911 + """Get a supported array class by qualname.""" + m = RE_ARRAY_QUAL.fullmatch(qualname) + assert m + match m["mod"], m["name"], m["inner"]: + case "numpy", "ndarray", None: + return np.ndarray + case "scipy.sparse", ( + "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix" + ) as cls_name, None: + import scipy.sparse + + return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return] + case "cupy", "ndarray", None: + import cupy as cp + + return cp.ndarray # type: ignore[no-any-return] + case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None: + import cupyx.scipy.sparse as cu_sparse + + return getattr(cu_sparse, cls_name) # type: ignore[no-any-return] + case "dask.array", cls_name, _: + if TYPE_CHECKING: + from dask.array.core import Array as DaskArray + else: + from dask.array import Array as DaskArray + + return DaskArray + case "h5py", "Dataset", _: + import h5py + + return h5py.Dataset # type: ignore[no-any-return] + case "zarr", "Array", _: + import zarr + + return zarr.Array + case _: + msg = f"Unknown array class: {qualname}" + raise ValueError(msg) + + +def random_mat( + shape: tuple[int, int], + *, + density: SupportsFloat = 0.01, + format: Literal["csr", "csc"] = "csr", # noqa: A002 + dtype: np.dtype[_SCT_float] | type[_SCT_float] | None = None, + container: Literal["array", "matrix"] = "array", + gen: np.random.Generator | None = None, +) -> CSBase[_SCT_float]: + """Create a random matrix.""" + from scipy.sparse import random as random_spmat + from scipy.sparse import random_array as random_sparr + + m, n = shape + return ( + random_spmat(m, n, 
density=density, format=format, dtype=dtype, random_state=gen) + if container == "matrix" + else random_sparr(shape, density=density, format=format, dtype=dtype, random_state=gen) + ) + + +def random_array( + qualname: str, + shape: tuple[int, int], + *, + dtype: np.dtype[_SCT_float] | type[_SCT_float] | None, + gen: np.random.Generator | None = None, +) -> Array[_SCT_float]: + """Create a random array.""" + gen = np.random.default_rng(gen) + + m = RE_ARRAY_QUAL.fullmatch(qualname) + assert m + match m["mod"], m["name"], m["inner"]: + case "numpy", "ndarray", None: + return gen.random(shape, dtype=dtype or np.float64) + case "scipy.sparse", ( + "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix" + ) as cls_name, None: + fmt, container = cls_name.split("_") + return random_mat(shape, format=fmt, container=container, dtype=dtype) # type: ignore[arg-type] + case "cupy", "ndarray", None: + raise NotImplementedError + case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None: + raise NotImplementedError + case "dask.array", cls_name, _: + raise NotImplementedError + case "h5py", "Dataset", _: + raise NotImplementedError + case "zarr", "Array", _: + raise NotImplementedError + case _: + msg = f"Unknown array class: {qualname}" + raise ValueError(msg) diff --git a/src/testing/fast_array_utils/__init__.pyi b/src/testing/fast_array_utils/__init__.pyi deleted file mode 100644 index f54511a..0000000 --- a/src/testing/fast_array_utils/__init__.pyi +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-License-Identifier: MPL-2.0 -from typing import Generic, Protocol, TypeAlias, TypeVar - -import numpy as np -from numpy.typing import ArrayLike, NDArray - -from fast_array_utils import types - -_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) -_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic) - -_Array: TypeAlias = ( - NDArray[_SCT_co] - | types.CSBase[_SCT_co] - | types.CupyArray[_SCT_co] - | types.DaskArray - | types.H5Dataset - | types.ZarrArray -) - -class _ToArray(Protocol, Generic[_SCT_contra]): - def __call__( - self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None - ) -> _Array[_SCT_contra]: ... - -__all__ = ["_Array", "_ToArray"] diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 9720d87..100769c 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -4,7 +4,6 @@ from __future__ import annotations import os -import re from importlib.util import find_spec from typing import TYPE_CHECKING, cast @@ -13,6 +12,8 @@ from fast_array_utils import types +from . import get_array_cls + if TYPE_CHECKING: from collections.abc import Generator @@ -20,9 +21,9 @@ from numpy.typing import ArrayLike, DTypeLike - from testing.fast_array_utils import _ToArray + from testing.fast_array_utils import ToArray - from . import _Array + from . 
import Array _SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) @@ -57,66 +58,24 @@ def array_cls_name(request: pytest.FixtureRequest) -> str: @pytest.fixture(scope="session") -def array_cls(array_cls_name: str) -> type[_Array[Any]]: +def array_cls(array_cls_name: str) -> type[Array[Any]]: """Fixture for a supported array class.""" return get_array_cls(array_cls_name) -def get_array_cls(qualname: str) -> type[_Array[Any]]: # noqa: PLR0911 - """Get a supported array class by qualname.""" - m = re.fullmatch( - r"(?P(?:\w+\.)*\w+)\.(?P[^\[]+)(?:\[(?P[\w.]+)\])?", qualname - ) - assert m - match m["mod"], m["name"], m["inner"]: - case "numpy", "ndarray", None: - return np.ndarray - case "scipy.sparse", ( - "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix" - ) as cls_name, None: - import scipy.sparse - - return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return] - case "cupy", "ndarray", None: - import cupy as cp - - return cp.ndarray # type: ignore[no-any-return] - case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None: - import cupyx.scipy.sparse as cu_sparse - - return getattr(cu_sparse, cls_name) # type: ignore[no-any-return] - case "dask.array", cls_name, _: - if TYPE_CHECKING: - from dask.array.core import Array as DaskArray - else: - from dask.array import Array as DaskArray - - return DaskArray - case "h5py", "Dataset", _: - import h5py - - return h5py.Dataset # type: ignore[no-any-return] - case "zarr", "Array", _: - import zarr - - return zarr.Array - case _: - pytest.fail(f"Unknown array class: {qualname}") - - @pytest.fixture(scope="session") def to_array( - request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str -) -> _ToArray[_SCT_co]: + request: pytest.FixtureRequest, array_cls: type[Array[_SCT_co]], array_cls_name: str +) -> ToArray[_SCT_co]: """Fixture for conversion into a supported array.""" return get_to_array(array_cls, array_cls_name, request) def get_to_array( - array_cls: type[_Array[_SCT_co]], + array_cls: type[Array[_SCT_co]], array_cls_name: str | None = None, request: pytest.FixtureRequest | None = None, -) -> _ToArray[_SCT_co]: +) -> ToArray[_SCT_co]: """Create a function to convert to a supported array.""" if array_cls is np.ndarray: return np.asarray # type: ignore[return-value] @@ -144,7 +103,7 @@ def half_rounded_up(x: int) -> int: return tuple(half_rounded_up(x) for x in a) -def to_dask_array(array_cls_name: str) -> _ToArray[Any]: +def to_dask_array(array_cls_name: str) -> ToArray[Any]: """Convert to a dask array.""" if TYPE_CHECKING: import dask.array.core as da @@ -153,7 +112,7 @@ def to_dask_array(array_cls_name: str) -> _ToArray[Any]: inner_cls_name = array_cls_name.removeprefix("dask.array.Array[").removesuffix("]") inner_cls = get_array_cls(inner_cls_name) - to_array_fn: _ToArray[Any] = get_to_array(array_cls=inner_cls) + to_array_fn: ToArray[Any] = get_to_array(array_cls=inner_cls) def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray: x = np.asarray(x, dtype=dtype) @@ -167,7 +126,7 @@ def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.Dask def to_h5py_dataset( tmp_path_factory: pytest.TempPathFactory, worker_id: str = "serial", -) -> Generator[_ToArray[Any], None, None]: +) -> Generator[ToArray[Any], None, None]: """Convert to a h5py dataset.""" import h5py diff --git a/tests/test_asarray.py b/tests/test_asarray.py index 1c855ef..ba98e68 100644 --- a/tests/test_asarray.py +++ b/tests/test_asarray.py @@ -13,10 +13,10 @@ 
from numpy.typing import NDArray - from testing.fast_array_utils import _ToArray + from testing.fast_array_utils import ToArray -def test_asarray(to_array: _ToArray[Any]) -> None: +def test_asarray(to_array: ToArray[Any]) -> None: x = to_array([[1, 2, 3], [4, 5, 6]]) arr: NDArray[Any] = asarray(x) # type: ignore[arg-type] assert isinstance(arr, np.ndarray) diff --git a/tests/test_sparse.py b/tests/test_sparse.py index dd1972e..4f260ca 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -8,15 +8,14 @@ import pytest from fast_array_utils.conv.scipy import to_dense +from testing.fast_array_utils import random_mat if TYPE_CHECKING: - from typing import Literal, SupportsFloat, TypeVar + from typing import Literal, TypeVar from pytest_codspeed import BenchmarkFixture - from fast_array_utils.types import CSBase - DType = TypeVar("DType", bound=np.generic) DType_float = TypeVar("DType_float", np.float32, np.float64) @@ -39,24 +38,6 @@ def dtype(request: pytest.FixtureRequest) -> np.dtype[np.float32 | np.float64]: return np.dtype(request.param) -def random_mat( - shape: tuple[int, int], - *, - density: SupportsFloat = 0.01, - format: Literal["csr", "csc"] = "csr", # noqa: A002 - dtype: np.dtype[DType_float] | None = None, - container: Literal["array", "matrix"] = "array", -) -> CSBase[DType_float]: - from scipy.sparse import random, random_array - - m, n = shape - return ( - random(m, n, density=density, format=format, dtype=dtype) - if container == "matrix" - else random_array(shape, density=density, format=format, dtype=dtype) - ) - - @pytest.mark.parametrize("order", ["C", "F"]) def test_to_dense( order: Literal["C", "F"], diff --git a/tests/test_stats.py b/tests/test_stats.py index 23a4ebb..91eb374 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -7,6 +7,8 @@ import numpy as np import pytest +from testing.fast_array_utils import random_array + if TYPE_CHECKING or find_spec("scipy"): from scipy.sparse import sparray, spmatrix @@ -19,24 +21,41 @@ if TYPE_CHECKING: from typing import Any, Literal - from testing.fast_array_utils import _Array, _ToArray + from pytest_codspeed import BenchmarkFixture + + from testing.fast_array_utils import Array, ToArray + + DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool_] + DTypeOut = type[np.float32 | np.float64 | np.int64] + + +@pytest.fixture(scope="session", params=[0, 1, None]) +def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]: + return request.param # type: ignore[no-any-return] + + +@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool_]) +def dtype_in(request: pytest.FixtureRequest) -> DTypeIn: + return request.param # type: ignore[no-any-return] + + +@pytest.fixture(scope="session", params=[np.float32, np.float64, None]) +def dtype_arg(request: pytest.FixtureRequest) -> DTypeOut | None: + return request.param # type: ignore[no-any-return] -@pytest.mark.parametrize("dtype_in", [np.float32, np.float64, np.int32, np.bool_]) -@pytest.mark.parametrize("dtype_arg", [np.float32, np.float64, None]) -@pytest.mark.parametrize("axis", [0, 1, None]) def test_sum( - array_cls: type[_Array[Any]], - to_array: _ToArray[Any], - dtype_in: type[np.generic], - dtype_arg: type[np.generic] | None, + array_cls: type[Array[Any]], + to_array: ToArray[Any], + dtype_in: DTypeIn, + dtype_arg: DTypeOut | None, axis: Literal[0, 1, None], ) -> None: np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in) arr = to_array(np_arr.copy()) assert arr.dtype == dtype_in - sum_: _Array[Any] | np.floating = 
stats.sum(arr, axis=axis, dtype=dtype_arg) # type: ignore[type-arg,arg-type] + sum_: Array[Any] | np.floating = stats.sum(arr, axis=axis, dtype=dtype_arg) # type: ignore[type-arg,arg-type] match axis, arr: case _, types.DaskArray(): @@ -61,3 +80,21 @@ def test_sum( assert sum_.dtype == dtype_in np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis, dtype=dtype_arg)) # type: ignore[arg-type] + + +@pytest.mark.benchmark +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float +def test_sum_benchmark( + benchmark: BenchmarkFixture, + array_cls_name: str, + axis: Literal[0, 1, None], + dtype: type[np.float32 | np.float64], +) -> None: + try: + shape = (1_000, 1_000) if "sparse" in array_cls_name else (100, 100) + arr = random_array(array_cls_name, shape, dtype=dtype) # type: ignore # noqa: PGH003 + except NotImplementedError: + pytest.skip("random_array not implemented for dtype") + + stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile + benchmark(stats.sum, arr, axis=axis) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 74ceab8..cb69bcc 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -12,14 +12,14 @@ if TYPE_CHECKING: from typing import TypeVar - from testing.fast_array_utils import _Array, _ToArray + from testing.fast_array_utils import Array, ToArray DType_float = TypeVar("DType_float", np.float32, np.float64) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_conv( - array_cls: type[_Array[DType_float]], to_array: _ToArray[DType_float], dtype: DType_float + array_cls: type[Array[DType_float]], to_array: ToArray[DType_float], dtype: DType_float ) -> None: arr = to_array(np.arange(12).reshape(3, 4), dtype=dtype) assert isinstance(arr, array_cls) From f3f683c73aed69c545586c60bbc54b84b4aa3f98 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 13:41:03 +0100 Subject: [PATCH 10/11] Test PR to see sum speed (#17) --- tests/test_stats.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 91eb374..818c8cf 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,5 +96,10 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile - benchmark(stats.sum, arr, axis=axis) + def sum(arr: Array[Any], axis: int | None) -> Array[Any]: # noqa: A001 + if hasattr(arr, "sum"): + return arr.sum(axis=axis) + return np.sum(arr, axis=axis) # type: ignore[arg-type] + + sum(arr, axis=axis) # warmup: numba compile + benchmark(sum, arr, axis=axis) From fbd1d7d54b5815a47c4bb4dc50295f957ab3adea Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 18 Feb 2025 13:51:02 +0100 Subject: [PATCH 11/11] Use our sum (#18) --- tests/test_stats.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 818c8cf..91eb374 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,10 +96,5 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - def sum(arr: Array[Any], axis: int | None) -> Array[Any]: # noqa: A001 - if hasattr(arr, "sum"): - return arr.sum(axis=axis) - return np.sum(arr, axis=axis) # type: ignore[arg-type] - - sum(arr, axis=axis) # warmup: numba compile - benchmark(sum, arr, axis=axis) + stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile + benchmark(stats.sum, arr, axis=axis)
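
A minimal usage sketch of the `stats.sum` API added in this series, assuming numpy, scipy, and dask are installed; the example values and chunk sizes are illustrative:

    import dask.array as da
    import numpy as np
    import scipy.sparse as sp

    from fast_array_utils import stats

    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

    stats.sum(x)                           # axis=None: 0-d result, here 21.0
    stats.sum(x, axis=0)                   # column sums as a 1D numpy array
    stats.sum(sp.csr_array(x), axis=1, dtype=np.float64)  # sparse in, dense 1D numpy array out

    # Dask input stays lazy: the registered reduction returns a dask array to `.compute()`.
    stats.sum(da.from_array(x, chunks=(1, 3)), axis=0).compute()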