diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 651940f..11f3909 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,5 +32,8 @@ repos:
           - numpy
           - scipy
           - types-scipy-sparse
+          - dask
+          - zarr
+          - h5py
 ci:
   skip: [mypy]  # too big
diff --git a/pyproject.toml b/pyproject.toml
index 67fe21d..0f6af6a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ extras = [ "min", "full" ]
 
 [tool.ruff]
 line-length = 100
+namespace-packages = [ "src/testing" ]
 lint.select = [ "ALL" ]
 lint.ignore = [
   "A005", # submodules never shadow builtins.
@@ -74,6 +75,7 @@ lint.ignore = [
   "TID252", # relative imports are fine
 ]
 lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
+lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
 lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
 lint.per-file-ignores."tests/**/test_*.py" = [
   "D100", # tests need no module docstrings
@@ -95,6 +97,7 @@ addopts = [
   "--import-mode=importlib",
   "--strict-markers",
   "--pyargs",
+  "-ptesting.fast_array_utils.pytest",
 ]
 filterwarnings = [
   "error",
diff --git a/src/fast_array_utils/__init__.py b/src/fast_array_utils/__init__.py
index 8d55e48..c203223 100644
--- a/src/fast_array_utils/__init__.py
+++ b/src/fast_array_utils/__init__.py
@@ -3,7 +3,9 @@
 
 from __future__ import annotations
 
-from . import conv, types
+from . import _patches, conv, types
 
 
 __all__ = ["conv", "types"]
+
+_patches.patch_dask()
diff --git a/src/fast_array_utils/_patches.py b/src/fast_array_utils/_patches.py
new file mode 100644
index 0000000..1e7fc71
--- /dev/null
+++ b/src/fast_array_utils/_patches.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: MPL-2.0
+from __future__ import annotations
+
+import numpy as np
+
+
+# TODO(flying-sheep): upstream
+# https://github.com/dask/dask/issues/11749
+def patch_dask() -> None:
+    """Patch dask to support sparse arrays.
+
+    See <https://github.com/dask/dask/issues/11749>
+    """
+    try:
+        # Other lookup candidates: tensordot_lookup and take_lookup
+        from dask.array.dispatch import concatenate_lookup
+        from scipy.sparse import sparray, spmatrix
+    except ImportError:
+        return  # No need to patch if dask or scipy is not installed
+
+    # Avoid patch if already patched or upstream support has been added
+    if concatenate_lookup.dispatch(sparray) is not np.concatenate:  # type: ignore[no-untyped-call]
+        return
+
+    concatenate = concatenate_lookup.dispatch(spmatrix)  # type: ignore[no-untyped-call]
+    concatenate_lookup.register(sparray, concatenate)  # type: ignore[no-untyped-call]
diff --git a/src/fast_array_utils/conv/_asarray.py b/src/fast_array_utils/conv/_asarray.py
index 1c121bd..9b25669 100644
--- a/src/fast_array_utils/conv/_asarray.py
+++ b/src/fast_array_utils/conv/_asarray.py
@@ -46,8 +46,8 @@ def _(x: CSBase[DT_co]) -> NDArray[DT_co]:
 
 
 @asarray.register(DaskArray)
-def _(x: DaskArray[DT_co]) -> NDArray[DT_co]:
-    return asarray(x.compute())
+def _(x: DaskArray) -> NDArray[DT_co]:
+    return asarray(x.compute())  # type: ignore[no-untyped-call]
 
 
 @asarray.register(OutOfCoreDataset)
diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py
index c1fe15a..3898f56 100644
--- a/src/fast_array_utils/types.py
+++ b/src/fast_array_utils/types.py
@@ -53,31 +53,33 @@
 CSBase = CSMatrix | CSArray
 
 
-if find_spec("cupy") or TYPE_CHECKING:
+if TYPE_CHECKING or find_spec("cupy"):
     from cupy import ndarray as CupyArray
 else:
     CupyArray = type("ndarray", (), {})
 
 
-if find_spec("cupyx") or TYPE_CHECKING:
+if TYPE_CHECKING or find_spec("cupyx"):
     from cupyx.scipy.sparse import spmatrix as CupySparseMatrix
 else:
     CupySparseMatrix = type("spmatrix", (), {})
 
 
-if find_spec("dask") or TYPE_CHECKING:
+if TYPE_CHECKING:  # https://github.com/dask/dask/issues/8853
+    from dask.array.core import Array as DaskArray
+elif find_spec("dask"):
     from dask.array import Array as DaskArray
 else:
     DaskArray = type("array", (), {})
 
 
-if find_spec("h5py") or TYPE_CHECKING:
+if TYPE_CHECKING or find_spec("h5py"):
     from h5py import Dataset as H5Dataset
 else:
     H5Dataset = type("Dataset", (), {})
 
 
-if find_spec("zarr") or TYPE_CHECKING:
+if TYPE_CHECKING or find_spec("zarr"):
     from zarr import Array as ZarrArray
 else:
     ZarrArray = type("Array", (), {})
diff --git a/src/testing/fast_array_utils/__init__.pyi b/src/testing/fast_array_utils/__init__.pyi
new file mode 100644
index 0000000..f54511a
--- /dev/null
+++ b/src/testing/fast_array_utils/__init__.pyi
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: MPL-2.0
+from typing import Generic, Protocol, TypeAlias, TypeVar
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+
+from fast_array_utils import types
+
+_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
+_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic)
+
+_Array: TypeAlias = (
+    NDArray[_SCT_co]
+    | types.CSBase[_SCT_co]
+    | types.CupyArray[_SCT_co]
+    | types.DaskArray
+    | types.H5Dataset
+    | types.ZarrArray
+)
+
+class _ToArray(Protocol, Generic[_SCT_contra]):
+    def __call__(
+        self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None
+    ) -> _Array[_SCT_contra]: ...
+
+__all__ = ["_Array", "_ToArray"]
diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py
new file mode 100644
index 0000000..9720d87
--- /dev/null
+++ b/src/testing/fast_array_utils/pytest.py
@@ -0,0 +1,194 @@
+# SPDX-License-Identifier: MPL-2.0
+"""Testing utilities."""
+
+from __future__ import annotations
+
+import os
+import re
+from importlib.util import find_spec
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+import pytest
+
+from fast_array_utils import types
+
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+    from typing import Any, TypeVar
+
+    from numpy.typing import ArrayLike, DTypeLike
+
+    from testing.fast_array_utils import _ToArray
+
+    from . import _Array
+
+    _SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
+
+
+def _skip_if_no(dist: str) -> pytest.MarkDecorator:
+    return pytest.mark.skipif(not find_spec(dist), reason=f"{dist} not installed")
+
+
+@pytest.fixture(
+    scope="session",
+    params=[
+        pytest.param("numpy.ndarray"),
+        pytest.param("scipy.sparse.csr_array", marks=_skip_if_no("scipy")),
+        pytest.param("scipy.sparse.csc_array", marks=_skip_if_no("scipy")),
+        pytest.param("scipy.sparse.csr_matrix", marks=_skip_if_no("scipy")),
+        pytest.param("scipy.sparse.csc_matrix", marks=_skip_if_no("scipy")),
+        pytest.param("dask.array.Array[numpy.ndarray]", marks=_skip_if_no("dask")),
+        pytest.param("dask.array.Array[scipy.sparse.csr_array]", marks=_skip_if_no("dask")),
+        pytest.param("dask.array.Array[scipy.sparse.csc_array]", marks=_skip_if_no("dask")),
+        pytest.param("dask.array.Array[scipy.sparse.csr_matrix]", marks=_skip_if_no("dask")),
+        pytest.param("dask.array.Array[scipy.sparse.csc_matrix]", marks=_skip_if_no("dask")),
+        pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")),
+        pytest.param("zarr.Array", marks=_skip_if_no("zarr")),
+        pytest.param("cupy.ndarray", marks=_skip_if_no("cupy")),
+        pytest.param("cupyx.scipy.sparse.csr_matrix", marks=_skip_if_no("cupy")),
+        pytest.param("cupyx.scipy.sparse.csc_matrix", marks=_skip_if_no("cupy")),
+    ],
+)
+def array_cls_name(request: pytest.FixtureRequest) -> str:
+    """Fixture for a supported array class."""
+    return cast(str, request.param)
+
+
+@pytest.fixture(scope="session")
+def array_cls(array_cls_name: str) -> type[_Array[Any]]:
+    """Fixture for a supported array class."""
+    return get_array_cls(array_cls_name)
+
+
+def get_array_cls(qualname: str) -> type[_Array[Any]]:  # noqa: PLR0911
+    """Get a supported array class by qualname."""
+    m = re.fullmatch(
+        r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?", qualname
+    )
+    assert m
+    match m["mod"], m["name"], m["inner"]:
+        case "numpy", "ndarray", None:
+            return np.ndarray
+        case "scipy.sparse", (
+            "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
+        ) as cls_name, None:
+            import scipy.sparse
+
+            return getattr(scipy.sparse, cls_name)  # type: ignore[no-any-return]
+        case "cupy", "ndarray", None:
+            import cupy as cp
+
+            return cp.ndarray  # type: ignore[no-any-return]
+        case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
+            import cupyx.scipy.sparse as cu_sparse
+
+            return getattr(cu_sparse, cls_name)  # type: ignore[no-any-return]
+        case "dask.array", cls_name, _:
+            if TYPE_CHECKING:
+                from dask.array.core import Array as DaskArray
+            else:
+                from dask.array import Array as DaskArray
+
+            return DaskArray
+        case "h5py", "Dataset", _:
+            import h5py
+
+            return h5py.Dataset  # type: ignore[no-any-return]
+        case "zarr", "Array", _:
+            import zarr
+
+            return zarr.Array
+        case _:
+            pytest.fail(f"Unknown array class: {qualname}")
+
+
+@pytest.fixture(scope="session")
+def to_array(
+    request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str
+) -> _ToArray[_SCT_co]:
+    """Fixture for conversion into a supported array."""
+    return get_to_array(array_cls, array_cls_name, request)
+
+
+def get_to_array(
+    array_cls: type[_Array[_SCT_co]],
+    array_cls_name: str | None = None,
+    request: pytest.FixtureRequest | None = None,
+) -> _ToArray[_SCT_co]:
+    """Create a function to convert to a supported array."""
+    if array_cls is np.ndarray:
+        return np.asarray  # type: ignore[return-value]
+    if array_cls is types.DaskArray:
+        assert array_cls_name is not None
+        return to_dask_array(array_cls_name)
+    if array_cls is types.H5Dataset:
+        assert request is not None
+        return request.getfixturevalue("to_h5py_dataset")  # type: ignore[no-any-return]
+    if array_cls is types.ZarrArray:
+        return to_zarr_array
+    if array_cls is types.CupyArray:
+        import cupy as cu
+
+        return cu.asarray  # type: ignore[no-any-return]
+
+    return array_cls  # type: ignore[return-value]
+
+
+def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]:
+    def half_rounded_up(x: int) -> int:
+        div, mod = divmod(x, 2)
+        return div + (mod > 0)
+
+    return tuple(half_rounded_up(x) for x in a)
+
+
+def to_dask_array(array_cls_name: str) -> _ToArray[Any]:
+    """Convert to a dask array."""
+    if TYPE_CHECKING:
+        import dask.array.core as da
+    else:
+        import dask.array as da
+
+    inner_cls_name = array_cls_name.removeprefix("dask.array.Array[").removesuffix("]")
+    inner_cls = get_array_cls(inner_cls_name)
+    to_array_fn: _ToArray[Any] = get_to_array(array_cls=inner_cls)
+
+    def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray:
+        x = np.asarray(x, dtype=dtype)
+        return da.from_array(to_array_fn(x), _half_chunk_size(x.shape))  # type: ignore[no-untyped-call,no-any-return]
+
+    return to_dask_array
+
+
+@pytest.fixture(scope="session")
+# worker_id for xdist since we don't want to override open files
+def to_h5py_dataset(
+    tmp_path_factory: pytest.TempPathFactory,
+    worker_id: str = "serial",
+) -> Generator[_ToArray[Any], None, None]:
+    """Convert to a h5py dataset."""
+    import h5py
+
+    tmp_path = tmp_path_factory.mktemp("backed_adata")
+    tmp_path = tmp_path / f"test_{worker_id}.h5ad"
+
+    with h5py.File(tmp_path, "x") as f:
+
+        def to_h5py_dataset(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset:
+            arr = np.asarray(x, dtype=dtype)
+            test_name = os.environ["PYTEST_CURRENT_TEST"].rsplit(":", 1)[-1].split(" ", 1)[0]
+            return f.create_dataset(test_name, arr.shape, arr.dtype, data=arr)
+
+        yield to_h5py_dataset
+
+
+def to_zarr_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.ZarrArray:
+    """Convert to a zarr array."""
+    import zarr
+
+    arr = np.asarray(x, dtype=dtype)
+    za = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype)
+    za[...] = arr
+    return za
diff --git a/tests/test_asarray.py b/tests/test_asarray.py
index f9863fc..1c855ef 100644
--- a/tests/test_asarray.py
+++ b/tests/test_asarray.py
@@ -1,115 +1,23 @@
 # SPDX-License-Identifier: MPL-2.0
 from __future__ import annotations
 
-from importlib.util import find_spec
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
 
 import numpy as np
-import pytest
 
 from fast_array_utils.conv import asarray
 
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Generator
-    from typing import Any, TypeAlias
+    from typing import Any
 
-    from numpy.typing import ArrayLike, NDArray
+    from numpy.typing import NDArray
 
-    from fast_array_utils import types
+    from testing.fast_array_utils import _ToArray
 
-    SupportedArray: TypeAlias = (
-        NDArray[Any]
-        | types.DaskArray
-        | types.H5Dataset
-        | types.ZarrArray
-        | types.CupyArray
-        | types.CupySparseMatrix
-    )
 
-
-def skip_if_no(dist: str) -> pytest.MarkDecorator:
-    return pytest.mark.skipif(not find_spec(dist), reason=f"{dist} not installed")
-
-
-@pytest.fixture(scope="session")
-# worker_id for xdist since we don't want to override open files
-def to_h5py_dataset(
-    tmp_path_factory: pytest.TempPathFactory, worker_id: str = "serial"
-) -> Generator[Callable[[ArrayLike], types.H5Dataset], None, None]:
-    import h5py
-
-    tmp_path = tmp_path_factory.mktemp("backed_adata")
-    tmp_path = tmp_path / f"test_{worker_id}.h5ad"
-
-    with h5py.File(tmp_path, "x") as f:
-
-        def to_h5py_dataset(x: ArrayLike) -> types.H5Dataset:
-            arr = np.asarray(x)
-            return f.create_dataset("data", arr.shape, arr.dtype)
-
-        yield to_h5py_dataset
-
-
-def to_zarr_array(x: ArrayLike) -> types.ZarrArray:
-    import zarr
-
-    arr = np.asarray(x)
-    za = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype)
-    za[...] = arr
-    return za
-
-
-@pytest.fixture(
-    scope="session",
-    params=[
-        pytest.param("numpy.ndarray"),
-        pytest.param("scipy.sparse.csr_array", marks=skip_if_no("scipy")),
-        pytest.param("scipy.sparse.csc_array", marks=skip_if_no("scipy")),
-        pytest.param("scipy.sparse.csr_matrix", marks=skip_if_no("scipy")),
-        pytest.param("scipy.sparse.csc_matrix", marks=skip_if_no("scipy")),
-        pytest.param("dask.array.Array", marks=skip_if_no("dask")),
-        pytest.param("h5py.Dataset", marks=skip_if_no("h5py")),
-        pytest.param("zarr.Array", marks=skip_if_no("zarr")),
-        pytest.param("cupy.ndarray", marks=skip_if_no("cupy")),
-        pytest.param("cupyx.scipy.sparse.csr_matrix", marks=skip_if_no("cupy")),
-        pytest.param("cupyx.scipy.sparse.csc_matrix", marks=skip_if_no("cupy")),
-    ],
-)
-def array_cls(  # noqa: PLR0911
-    request: pytest.FixtureRequest,
-) -> Callable[[ArrayLike], SupportedArray]:
-    qualname = cast(str, request.param)
-    match qualname.split("."):
-        case "numpy", "ndarray":
-            return np.asarray
-        case "scipy", "sparse", ("csr_array" | "csc_array" | "csr_matrix" | "csc_matrix") as n:
-            import scipy.sparse
-
-            return getattr(scipy.sparse, n)  # type: ignore[no-any-return]
-        case "dask", "array", "Array":
-            import dask.array as da
-
-            return da.asarray  # type: ignore[no-any-return]
-        case "h5py", "Dataset":
-            return request.getfixturevalue("to_h5py_dataset")  # type: ignore[no-any-return]
-        case "zarr", "Array":
-            return to_zarr_array
-        case "cupy", "ndarray":
-            import cupy
-
-            return cupy.asarray  # type: ignore[no-any-return]
-        case "cupyx", "scipy", "sparse", ("csr_matrix" | "csc_matrix") as n:
-            import cupyx.scipy.sparse
-
-            return getattr(cupyx.scipy.sparse, n)  # type: ignore[no-any-return]
-        case _:
-            msg = f"Unknown array type: {qualname}"
-            raise AssertionError(msg)
-
-
-def test_asarray(array_cls: Callable[[ArrayLike], SupportedArray]) -> None:
-    x = array_cls([[1, 2, 3], [4, 5, 6]])
-    arr: NDArray[Any] = asarray(x)
+def test_asarray(to_array: _ToArray[Any]) -> None:
+    x = to_array([[1, 2, 3], [4, 5, 6]])
+    arr: NDArray[Any] = asarray(x)  # type: ignore[arg-type]
     assert isinstance(arr, np.ndarray)
     assert arr.shape == (2, 3)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
new file mode 100644
index 0000000..74ceab8
--- /dev/null
+++ b/tests/test_test_utils.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: MPL-2.0
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+from fast_array_utils import types
+
+
+if TYPE_CHECKING:
+    from typing import TypeVar
+
+    from testing.fast_array_utils import _Array, _ToArray
+
+    DType_float = TypeVar("DType_float", np.float32, np.float64)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_conv(
+    array_cls: type[_Array[DType_float]], to_array: _ToArray[DType_float], dtype: DType_float
+) -> None:
+    arr = to_array(np.arange(12).reshape(3, 4), dtype=dtype)
+    assert isinstance(arr, array_cls)
+    if isinstance(arr, types.DaskArray):
+        arr = arr.compute()  # type: ignore[no-untyped-call]
+    elif isinstance(arr, types.CupyArray):
+        arr = arr.get()
+    assert arr.shape == (3, 4)
+    assert arr.dtype == dtype
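Not part of the diff: a minimal, illustrative sketch of how a downstream test module could consume the session fixtures (array_cls_name, array_cls, to_array) that testing.fast_array_utils.pytest registers once pytest loads it, for example via the "-ptesting.fast_array_utils.pytest" entry added to addopts above. The test name and the round-trip check below are assumptions for illustration, not code from this changeset.

# Illustrative usage only (not in this changeset): runs once per parametrized backend.
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from fast_array_utils.conv import asarray

if TYPE_CHECKING:
    from typing import Any

    from testing.fast_array_utils import _Array, _ToArray


def test_roundtrip_example(array_cls: type[_Array[Any]], to_array: _ToArray[Any]) -> None:
    # `to_array` builds an instance of the currently parametrized array class,
    # and `array_cls` is the matching class object for isinstance checks.
    x = to_array(np.eye(3), dtype=np.float64)
    assert isinstance(x, array_cls)
    arr = asarray(x)  # type: ignore[arg-type]
    assert arr.shape == (3, 3)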