|
| 1 | +# SPDX-License-Identifier: MPL-2.0 |
| 2 | +"""Testing utilities.""" |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import os |
| 7 | +import re |
| 8 | +from importlib.util import find_spec |
| 9 | +from typing import TYPE_CHECKING, cast |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pytest |
| 13 | + |
| 14 | +from fast_array_utils import types |
| 15 | + |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from collections.abc import Generator |
| 19 | + from typing import Any, TypeVar |
| 20 | + |
| 21 | + from numpy.typing import ArrayLike, DTypeLike |
| 22 | + |
| 23 | + from testing.fast_array_utils import _ToArray |
| 24 | + |
| 25 | + from . import _Array |
| 26 | + |
| 27 | + _SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic) |
| 28 | + |
| 29 | + |
def _skip_if_no(dist: str) -> pytest.MarkDecorator:
    """Build a ``skipif`` mark that fires when ``find_spec(dist)`` finds nothing."""
    spec = find_spec(dist)
    return pytest.mark.skipif(not spec, reason=f"{dist} not installed")
| 32 | + |
| 33 | + |
@pytest.fixture(
    scope="session",
    params=[
        pytest.param("numpy.ndarray"),
        pytest.param("scipy.sparse.csr_array", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csc_array", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csr_matrix", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csc_matrix", marks=_skip_if_no("scipy")),
        pytest.param("dask.array.Array[numpy.ndarray]", marks=_skip_if_no("dask")),
        # Dask arrays with sparse chunks need both dask and scipy, so they
        # carry both skip marks; either missing package skips the param.
        pytest.param(
            "dask.array.Array[scipy.sparse.csr_array]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csc_array]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csr_matrix]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csc_matrix]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")),
        pytest.param("zarr.Array", marks=_skip_if_no("zarr")),
        pytest.param("cupy.ndarray", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csr_matrix", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csc_matrix", marks=_skip_if_no("cupy")),
    ],
)
def array_cls_name(request: pytest.FixtureRequest) -> str:
    """Fixture for the qualified name of a supported array class.

    Parametrized over every supported class name; a param is skipped when
    a package it needs cannot be found.  Dask arrays use the form
    ``dask.array.Array[<inner class name>]`` to name their chunk type.
    """
    return cast(str, request.param)
| 57 | + |
| 58 | + |
@pytest.fixture(scope="session")
def array_cls(array_cls_name: str) -> type[_Array[Any]]:
    """Fixture resolving the parametrized class name to the array class itself."""
    resolved = get_array_cls(array_cls_name)
    return resolved
| 63 | + |
| 64 | + |
| 65 | +def get_array_cls(qualname: str) -> type[_Array[Any]]: # noqa: PLR0911 |
| 66 | + """Get a supported array class by qualname.""" |
| 67 | + m = re.fullmatch( |
| 68 | + r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?", qualname |
| 69 | + ) |
| 70 | + assert m |
| 71 | + match m["mod"], m["name"], m["inner"]: |
| 72 | + case "numpy", "ndarray", None: |
| 73 | + return np.ndarray |
| 74 | + case "scipy.sparse", ( |
| 75 | + "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix" |
| 76 | + ) as cls_name, None: |
| 77 | + import scipy.sparse |
| 78 | + |
| 79 | + return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return] |
| 80 | + case "cupy", "ndarray", None: |
| 81 | + import cupy as cp |
| 82 | + |
| 83 | + return cp.ndarray # type: ignore[no-any-return] |
| 84 | + case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None: |
| 85 | + import cupyx.scipy.sparse as cu_sparse |
| 86 | + |
| 87 | + return getattr(cu_sparse, cls_name) # type: ignore[no-any-return] |
| 88 | + case "dask.array", cls_name, _: |
| 89 | + if TYPE_CHECKING: |
| 90 | + from dask.array.core import Array as DaskArray |
| 91 | + else: |
| 92 | + from dask.array import Array as DaskArray |
| 93 | + |
| 94 | + return DaskArray |
| 95 | + case "h5py", "Dataset", _: |
| 96 | + import h5py |
| 97 | + |
| 98 | + return h5py.Dataset # type: ignore[no-any-return] |
| 99 | + case "zarr", "Array", _: |
| 100 | + import zarr |
| 101 | + |
| 102 | + return zarr.Array |
| 103 | + case _: |
| 104 | + pytest.fail(f"Unknown array class: {qualname}") |
| 105 | + |
| 106 | + |
@pytest.fixture(scope="session")
def to_array(
    request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str
) -> _ToArray[_SCT_co]:
    """Fixture providing a converter into the parametrized array class."""
    converter = get_to_array(array_cls, array_cls_name=array_cls_name, request=request)
    return converter
| 113 | + |
| 114 | + |
def get_to_array(
    array_cls: type[_Array[_SCT_co]],
    array_cls_name: str | None = None,
    request: pytest.FixtureRequest | None = None,
) -> _ToArray[_SCT_co]:
    """Pick the callable that converts array-likes into ``array_cls``.

    ``array_cls_name`` is required for dask arrays, since it names the
    chunk type.  ``request`` is required for h5py datasets, whose
    converter is looked up as the ``to_h5py_dataset`` fixture.  Any class
    without a dedicated branch is returned as-is, so that the class
    itself is called as the converter.
    """
    # Plain numpy comes first; it needs nothing beyond ``np.asarray``.
    if array_cls is np.ndarray:
        return np.asarray  # type: ignore[return-value]

    if array_cls is types.DaskArray:
        assert array_cls_name is not None
        return to_dask_array(array_cls_name)

    if array_cls is types.H5Dataset:
        assert request is not None
        h5_converter = request.getfixturevalue("to_h5py_dataset")
        return h5_converter  # type: ignore[no-any-return]

    if array_cls is types.ZarrArray:
        return to_zarr_array

    if array_cls is types.CupyArray:
        import cupy

        return cupy.asarray  # type: ignore[no-any-return]

    # Fallback: use the class constructor itself.
    return array_cls  # type: ignore[return-value]
| 137 | + |
| 138 | + |
def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]:
    """Halve every entry of ``a``, rounding odd values up."""
    halved = []
    for extent in a:
        quotient, remainder = divmod(extent, 2)
        # A non-zero remainder bumps the result to the next integer.
        halved.append(quotient + 1 if remainder > 0 else quotient)
    return tuple(halved)
| 145 | + |
| 146 | + |
def to_dask_array(array_cls_name: str) -> _ToArray[Any]:
    """Create a converter that produces dask arrays.

    ``array_cls_name`` is expected in the form
    ``dask.array.Array[<inner class name>]``; the inner name selects the
    converter applied to the data before it is handed to dask.
    """
    if TYPE_CHECKING:
        import dask.array.core as da
    else:
        import dask.array as da

    # Strip the "dask.array.Array[" / "]" wrapper to get the inner class name.
    without_prefix = array_cls_name.removeprefix("dask.array.Array[")
    chunk_cls = get_array_cls(without_prefix.removesuffix("]"))
    convert_inner: _ToArray[Any] = get_to_array(array_cls=chunk_cls)

    def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray:
        dense = np.asarray(x, dtype=dtype)
        inner = convert_inner(dense)
        # Chunk sizes are half of each dimension, rounded up.
        chunks = _half_chunk_size(dense.shape)
        return da.from_array(inner, chunks)  # type: ignore[no-untyped-call,no-any-return]

    return to_dask_array
| 163 | + |
| 164 | + |
@pytest.fixture(scope="session")
def to_h5py_dataset(
    tmp_path_factory: pytest.TempPathFactory,
    worker_id: str = "serial",
) -> Generator[_ToArray[Any], None, None]:
    """Fixture providing a converter into h5py datasets.

    One HDF5 file is created per session in a fresh temporary directory
    and stays open until the session ends.  Each call to the yielded
    converter writes a new dataset into that file.  The dataset is named
    after the currently running test; a numeric suffix is added if that
    name is already taken.

    The file name includes the xdist worker name so that workers do not
    share a file.
    """
    import h5py

    # pytest does not inject fixtures into parameters that have a default,
    # so ``worker_id`` is always "serial" here.  Read the worker name from
    # the environment variable pytest-xdist sets in its workers instead.
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", worker_id)
    tmp_path = tmp_path_factory.mktemp("backed_adata") / f"test_{worker_id}.h5ad"

    # Mode "x" creates the file and fails if it already exists.
    with h5py.File(tmp_path, "x") as f:

        def to_h5py_dataset(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset:
            arr = np.asarray(x, dtype=dtype)
            # e.g. "tests/test_x.py::test_foo[param] (call)" -> "test_foo[param]"
            test_name = os.environ["PYTEST_CURRENT_TEST"].rsplit(":", 1)[-1].split(" ", 1)[0]
            # A test may convert several arrays, and tests in different
            # modules may share a name: pick the first free dataset name.
            name = test_name
            n_taken = 0
            while name in f:
                n_taken += 1
                name = f"{test_name}-{n_taken}"
            return f.create_dataset(name, arr.shape, arr.dtype, data=arr)

        yield to_h5py_dataset
| 185 | + |
| 186 | + |
def to_zarr_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.ZarrArray:
    """Copy ``x`` into a newly created zarr array of the same shape and dtype."""
    import zarr

    source = np.asarray(x, dtype=dtype)
    # An empty dict is passed as the store (presumably an in-memory store; confirm against zarr docs).
    target = zarr.create_array({}, shape=source.shape, dtype=source.dtype)
    target[...] = source
    return target
0 commit comments