Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ repos:
- numpy
- scipy
- types-scipy-sparse
- dask
- zarr
- h5py
ci:
skip: [mypy] # too big
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ extras = [ "min", "full" ]

[tool.ruff]
line-length = 100
namespace-packages = [ "src/testing" ]
lint.select = [ "ALL" ]
lint.ignore = [
"A005", # submodules never shadow builtins.
Expand All @@ -74,6 +75,7 @@ lint.ignore = [
"TID252", # relative imports are fine
]
lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
lint.per-file-ignores."tests/**/test_*.py" = [
"D100", # tests need no module docstrings
Expand All @@ -95,6 +97,7 @@ addopts = [
"--import-mode=importlib",
"--strict-markers",
"--pyargs",
"-ptesting.fast_array_utils.pytest",
]
filterwarnings = [
"error",
Expand Down
4 changes: 3 additions & 1 deletion src/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from __future__ import annotations

from . import conv, types
from . import _patches, conv, types


__all__ = ["conv", "types"]

_patches.patch_dask()
26 changes: 26 additions & 0 deletions src/fast_array_utils/_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

import numpy as np


# TODO(flying-sheep): upstream
# https://github.com/dask/dask/issues/11749
def patch_dask() -> None:
    """Patch dask to support scipy sparse *arrays* (not just matrices).

    Dask ships a ``concatenate_lookup`` registration for ``spmatrix`` but not
    for the newer ``sparray`` classes; this reuses the spmatrix implementation
    for sparray. No-op when dask or scipy is missing, or when a registration
    already exists.

    See <https://github.com/dask/dask/blob/4d71629d1f22ced0dd780919f22e70a642ec6753/dask/array/backends.py#L212-L232>
    """
    try:
        # Other lookup candidates: tensordot_lookup and take_lookup
        from dask.array.dispatch import concatenate_lookup
        from scipy.sparse import sparray, spmatrix
    except ImportError:
        return  # No need to patch if dask or scipy is not installed

    # Avoid patch if already patched or upstream support has been added:
    # an unregistered type falls back to np.concatenate, so anything else
    # means a dedicated implementation is already in place.
    if concatenate_lookup.dispatch(sparray) is not np.concatenate:  # type: ignore[no-untyped-call]
        return

    # Reuse the spmatrix concatenate implementation for sparray.
    concatenate = concatenate_lookup.dispatch(spmatrix)  # type: ignore[no-untyped-call]
    concatenate_lookup.register(sparray, concatenate)  # type: ignore[no-untyped-call]
4 changes: 2 additions & 2 deletions src/fast_array_utils/conv/_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def _(x: CSBase[DT_co]) -> NDArray[DT_co]:


@asarray.register(DaskArray)
def _(x: DaskArray[DT_co]) -> NDArray[DT_co]:
return asarray(x.compute())
def _(x: DaskArray) -> NDArray[DT_co]:
return asarray(x.compute()) # type: ignore[no-untyped-call]


@asarray.register(OutOfCoreDataset)
Expand Down
12 changes: 7 additions & 5 deletions src/fast_array_utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,33 @@
CSBase = CSMatrix | CSArray


if find_spec("cupy") or TYPE_CHECKING:
if TYPE_CHECKING or find_spec("cupy"):
from cupy import ndarray as CupyArray
else:
CupyArray = type("ndarray", (), {})


if find_spec("cupyx") or TYPE_CHECKING:
if TYPE_CHECKING or find_spec("cupyx"):
from cupyx.scipy.sparse import spmatrix as CupySparseMatrix
else:
CupySparseMatrix = type("spmatrix", (), {})


if find_spec("dask") or TYPE_CHECKING:
if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853
from dask.array.core import Array as DaskArray

Check warning on line 69 in src/fast_array_utils/types.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/types.py#L69

Added line #L69 was not covered by tests
elif find_spec("dask"):
from dask.array import Array as DaskArray
else:
DaskArray = type("array", (), {})


if find_spec("h5py") or TYPE_CHECKING:
if TYPE_CHECKING or find_spec("h5py"):
from h5py import Dataset as H5Dataset
else:
H5Dataset = type("Dataset", (), {})


if find_spec("zarr") or TYPE_CHECKING:
if TYPE_CHECKING or find_spec("zarr"):
from zarr import Array as ZarrArray
else:
ZarrArray = type("Array", (), {})
Expand Down
26 changes: 26 additions & 0 deletions src/testing/fast_array_utils/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: MPL-2.0
# Stub-only module: typing helpers shared by the testing utilities.
from typing import Generic, Protocol, TypeAlias, TypeVar

import numpy as np
from numpy.typing import ArrayLike, NDArray

from fast_array_utils import types

# Scalar type variables bound to numpy scalar types:
# covariant for arrays that are produced, contravariant for the dtype
# a converter callable accepts.
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic)

# Union of all array types the test fixtures can produce.
# DaskArray/H5Dataset/ZarrArray are not parametrized by scalar type.
_Array: TypeAlias = (
    NDArray[_SCT_co]
    | types.CSBase[_SCT_co]
    | types.CupyArray[_SCT_co]
    | types.DaskArray
    | types.H5Dataset
    | types.ZarrArray
)

# Callable protocol for converting an array-like into a supported array,
# optionally coercing to the given dtype.
class _ToArray(Protocol, Generic[_SCT_contra]):
    def __call__(
        self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None
    ) -> _Array[_SCT_contra]: ...

__all__ = ["_Array", "_ToArray"]
194 changes: 194 additions & 0 deletions src/testing/fast_array_utils/pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: MPL-2.0
"""Testing utilities."""

from __future__ import annotations

import os
import re
from importlib.util import find_spec
from typing import TYPE_CHECKING, cast

import numpy as np
import pytest

from fast_array_utils import types


if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any, TypeVar

from numpy.typing import ArrayLike, DTypeLike

from testing.fast_array_utils import _ToArray

from . import _Array

_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)


def _skip_if_no(dist: str) -> pytest.MarkDecorator:
    """Return a ``skipif`` marker that is active when *dist* is not importable."""
    missing = find_spec(dist) is None
    return pytest.mark.skipif(missing, reason=f"{dist} not installed")


@pytest.fixture(
    scope="session",
    params=[
        pytest.param("numpy.ndarray"),
        # scipy sparse containers (array and matrix flavors)
        *(
            pytest.param(f"scipy.sparse.{n}", marks=_skip_if_no("scipy"))
            for n in ("csr_array", "csc_array", "csr_matrix", "csc_matrix")
        ),
        # dask arrays wrapping each supported chunk type
        *(
            pytest.param(f"dask.array.Array[{inner}]", marks=_skip_if_no("dask"))
            for inner in (
                "numpy.ndarray",
                "scipy.sparse.csr_array",
                "scipy.sparse.csc_array",
                "scipy.sparse.csr_matrix",
                "scipy.sparse.csc_matrix",
            )
        ),
        pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")),
        pytest.param("zarr.Array", marks=_skip_if_no("zarr")),
        pytest.param("cupy.ndarray", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csr_matrix", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csc_matrix", marks=_skip_if_no("cupy")),
    ],
)
def array_cls_name(request: pytest.FixtureRequest) -> str:
    """Session fixture yielding the qualified name of each supported array class."""
    return cast(str, request.param)


@pytest.fixture(scope="session")
def array_cls(array_cls_name: str) -> type[_Array[Any]]:
    """Session fixture resolving :func:`array_cls_name` to the class object."""
    resolved = get_array_cls(array_cls_name)
    return resolved


def get_array_cls(qualname: str) -> type[_Array[Any]]:  # noqa: PLR0911
    """Look up a supported array class from its qualified name.

    Accepts names like ``scipy.sparse.csr_array`` or
    ``dask.array.Array[numpy.ndarray]`` (the bracketed inner class, if any,
    is parsed but only the outer class is returned here).
    """
    m = re.fullmatch(
        r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?", qualname
    )
    assert m
    mod, name, inner = m["mod"], m["name"], m["inner"]

    if mod == "numpy" and name == "ndarray" and inner is None:
        return np.ndarray
    if (
        mod == "scipy.sparse"
        and name in {"csr_array", "csc_array", "csr_matrix", "csc_matrix"}
        and inner is None
    ):
        import scipy.sparse

        return getattr(scipy.sparse, name)  # type: ignore[no-any-return]
    if mod == "cupy" and name == "ndarray" and inner is None:
        import cupy as cp

        return cp.ndarray  # type: ignore[no-any-return]
    if mod == "cupyx.scipy.sparse" and name in {"csr_matrix", "csc_matrix"} and inner is None:
        import cupyx.scipy.sparse as cu_sparse

        return getattr(cu_sparse, name)  # type: ignore[no-any-return]
    if mod == "dask.array":
        # https://github.com/dask/dask/issues/8853: import location differs for typing
        if TYPE_CHECKING:
            from dask.array.core import Array as DaskArray
        else:
            from dask.array import Array as DaskArray

        return DaskArray
    if mod == "h5py" and name == "Dataset":
        import h5py

        return h5py.Dataset  # type: ignore[no-any-return]
    if mod == "zarr" and name == "Array":
        import zarr

        return zarr.Array
    pytest.fail(f"Unknown array class: {qualname}")


@pytest.fixture(scope="session")
def to_array(
    request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str
) -> _ToArray[_SCT_co]:
    """Session fixture returning a converter into the current array class."""
    converter = get_to_array(array_cls, array_cls_name, request)
    return converter


def get_to_array(
    array_cls: type[_Array[_SCT_co]],
    array_cls_name: str | None = None,
    request: pytest.FixtureRequest | None = None,
) -> _ToArray[_SCT_co]:
    """Create a function that converts array-likes into *array_cls* instances.

    Dispatches on identity of *array_cls*; the placeholder classes in
    :mod:`fast_array_utils.types` are all distinct objects, so the order of
    the checks is immaterial.
    """
    if array_cls is np.ndarray:
        return np.asarray  # type: ignore[return-value]
    if array_cls is types.CupyArray:
        import cupy as cu

        return cu.asarray  # type: ignore[no-any-return]
    if array_cls is types.ZarrArray:
        return to_zarr_array
    if array_cls is types.H5Dataset:
        # h5py datasets need a backing file, provided by the fixture.
        assert request is not None
        return request.getfixturevalue("to_h5py_dataset")  # type: ignore[no-any-return]
    if array_cls is types.DaskArray:
        # The chunk (inner) class is encoded in the qualified name.
        assert array_cls_name is not None
        return to_dask_array(array_cls_name)
    # Remaining supported classes (scipy sparse) construct directly from data.
    return array_cls  # type: ignore[return-value]


def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]:
def half_rounded_up(x: int) -> int:
div, mod = divmod(x, 2)
return div + (mod > 0)

return tuple(half_rounded_up(x) for x in a)


def to_dask_array(array_cls_name: str) -> _ToArray[Any]:
    """Build a converter producing dask arrays whose chunks have the inner class.

    *array_cls_name* looks like ``dask.array.Array[<inner qualname>]``.
    """
    # https://github.com/dask/dask/issues/8853: import location differs for typing
    if TYPE_CHECKING:
        import dask.array.core as da
    else:
        import dask.array as da

    inner_name = array_cls_name.removeprefix("dask.array.Array[").removesuffix("]")
    inner_to_array: _ToArray[Any] = get_to_array(array_cls=get_array_cls(inner_name))

    def to_dask_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray:
        arr = np.asarray(x, dtype=dtype)
        # Two chunks per axis (halves rounded up) so chunking paths get exercised.
        chunks = _half_chunk_size(arr.shape)
        return da.from_array(inner_to_array(arr), chunks)  # type: ignore[no-untyped-call,no-any-return]

    return to_dask_array


@pytest.fixture(scope="session")
# worker_id for xdist since we don't want to override open files
def to_h5py_dataset(
    tmp_path_factory: pytest.TempPathFactory,
    worker_id: str = "serial",
) -> Generator[_ToArray[Any], None, None]:
    """Session fixture yielding a converter into h5py datasets.

    All datasets of one session/worker live in a single open file; each is
    named after the currently running test.
    """
    import h5py

    path = tmp_path_factory.mktemp("backed_adata") / f"test_{worker_id}.h5ad"

    with h5py.File(path, "x") as f:

        def to_h5py_dataset(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset:
            arr = np.asarray(x, dtype=dtype)
            # Derive a unique dataset name from the running test's node ID.
            current = os.environ["PYTEST_CURRENT_TEST"]
            test_name = current.rsplit(":", 1)[-1].split(" ", 1)[0]
            return f.create_dataset(test_name, arr.shape, arr.dtype, data=arr)

        yield to_h5py_dataset


def to_zarr_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.ZarrArray:
    """Convert *x* into an in-memory zarr array, optionally coercing the dtype."""
    import zarr

    arr = np.asarray(x, dtype=dtype)
    out = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype)
    out[...] = arr
    return out
Loading
Loading