Skip to content

Commit f241148

Browse files
authored
Sparse in dask (#16)
1 parent 130038e commit f241148

File tree

10 files changed

+302
-107
lines changed

10 files changed

+302
-107
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,8 @@ repos:
3232
- numpy
3333
- scipy
3434
- types-scipy-sparse
35+
- dask
36+
- zarr
37+
- h5py
3538
ci:
3639
skip: [mypy] # too big

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ extras = [ "min", "full" ]
5959

6060
[tool.ruff]
6161
line-length = 100
62+
namespace-packages = [ "src/testing" ]
6263
lint.select = [ "ALL" ]
6364
lint.ignore = [
6465
"A005", # submodules never shadow builtins.
@@ -74,6 +75,7 @@ lint.ignore = [
7475
"TID252", # relative imports are fine
7576
]
7677
lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
78+
lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
7779
lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
7880
lint.per-file-ignores."tests/**/test_*.py" = [
7981
"D100", # tests need no module docstrings
@@ -95,6 +97,7 @@ addopts = [
9597
"--import-mode=importlib",
9698
"--strict-markers",
9799
"--pyargs",
100+
"-ptesting.fast_array_utils.pytest",
98101
]
99102
filterwarnings = [
100103
"error",

src/fast_array_utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from __future__ import annotations
55

6-
from . import conv, types
6+
from . import _patches, conv, types
77

88

99
__all__ = ["conv", "types"]
10+
11+
_patches.patch_dask()

src/fast_array_utils/_patches.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
import numpy as np
5+
6+
7+
# TODO(flying-sheep): upstream
8+
# https://github.com/dask/dask/issues/11749
9+
def patch_dask() -> None:
    """Make dask's ``concatenate_lookup`` treat ``sparray`` like ``spmatrix``.

    Returns without doing anything when dask or scipy cannot be imported, or
    when ``concatenate_lookup`` no longer resolves ``sparray`` to
    ``np.concatenate`` (already patched, or upstream support has been added).

    See <https://github.com/dask/dask/blob/4d71629d1f22ced0dd780919f22e70a642ec6753/dask/array/backends.py#L212-L232>
    """
    try:
        # Other lookup candidates: tensordot_lookup and take_lookup
        from dask.array.dispatch import concatenate_lookup
        from scipy.sparse import sparray, spmatrix
    except ImportError:
        # Without both dask and scipy there is nothing to patch.
        return

    current = concatenate_lookup.dispatch(sparray)  # type: ignore[no-untyped-call]
    if current is np.concatenate:
        # Reuse whatever implementation is registered for ``spmatrix``.
        handler = concatenate_lookup.dispatch(spmatrix)  # type: ignore[no-untyped-call]
        concatenate_lookup.register(sparray, handler)  # type: ignore[no-untyped-call]

src/fast_array_utils/conv/_asarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def _(x: CSBase[DT_co]) -> NDArray[DT_co]:
4646

4747

4848
@asarray.register(DaskArray)
49-
def _(x: DaskArray[DT_co]) -> NDArray[DT_co]:
50-
return asarray(x.compute())
49+
def _(x: DaskArray) -> NDArray[DT_co]:
50+
return asarray(x.compute()) # type: ignore[no-untyped-call]
5151

5252

5353
@asarray.register(OutOfCoreDataset)

src/fast_array_utils/types.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,31 +53,33 @@
5353
CSBase = CSMatrix | CSArray
5454

5555

56-
if find_spec("cupy") or TYPE_CHECKING:
56+
if TYPE_CHECKING or find_spec("cupy"):
5757
from cupy import ndarray as CupyArray
5858
else:
5959
CupyArray = type("ndarray", (), {})
6060

6161

62-
if find_spec("cupyx") or TYPE_CHECKING:
62+
if TYPE_CHECKING or find_spec("cupyx"):
6363
from cupyx.scipy.sparse import spmatrix as CupySparseMatrix
6464
else:
6565
CupySparseMatrix = type("spmatrix", (), {})
6666

6767

68-
if find_spec("dask") or TYPE_CHECKING:
68+
if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853
69+
from dask.array.core import Array as DaskArray
70+
elif find_spec("dask"):
6971
from dask.array import Array as DaskArray
7072
else:
7173
DaskArray = type("array", (), {})
7274

7375

74-
if find_spec("h5py") or TYPE_CHECKING:
76+
if TYPE_CHECKING or find_spec("h5py"):
7577
from h5py import Dataset as H5Dataset
7678
else:
7779
H5Dataset = type("Dataset", (), {})
7880

7981

80-
if find_spec("zarr") or TYPE_CHECKING:
82+
if TYPE_CHECKING or find_spec("zarr"):
8183
from zarr import Array as ZarrArray
8284
else:
8385
ZarrArray = type("Array", (), {})
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from typing import Generic, Protocol, TypeAlias, TypeVar
3+
4+
import numpy as np
5+
from numpy.typing import ArrayLike, NDArray
6+
7+
from fast_array_utils import types
8+
9+
# Scalar type variables, both bound to ``np.generic``.
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic)

# Union of the array types the testing helpers work with.
_Array: TypeAlias = (
    NDArray[_SCT_co]
    | types.CSBase[_SCT_co]
    | types.CupyArray[_SCT_co]
    | types.DaskArray
    | types.H5Dataset
    | types.ZarrArray
)


class _ToArray(Protocol, Generic[_SCT_contra]):
    """Protocol for callables that turn array-like data into an ``_Array``.

    The data is passed positionally; ``dtype`` is keyword-only and optional.
    """

    def __call__(
        self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None
    ) -> _Array[_SCT_contra]: ...


__all__ = ["_Array", "_ToArray"]
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
"""Testing utilities."""
3+
4+
from __future__ import annotations
5+
6+
import os
7+
import re
8+
from importlib.util import find_spec
9+
from typing import TYPE_CHECKING, cast
10+
11+
import numpy as np
12+
import pytest
13+
14+
from fast_array_utils import types
15+
16+
17+
if TYPE_CHECKING:
18+
from collections.abc import Generator
19+
from typing import Any, TypeVar
20+
21+
from numpy.typing import ArrayLike, DTypeLike
22+
23+
from testing.fast_array_utils import _ToArray
24+
25+
from . import _Array
26+
27+
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
28+
29+
30+
def _skip_if_no(dist: str) -> pytest.MarkDecorator:
    """Build a ``skipif`` mark that applies when ``find_spec(dist)`` finds nothing."""
    missing = not find_spec(dist)
    return pytest.mark.skipif(missing, reason=f"{dist} not installed")
32+
33+
34+
@pytest.fixture(
    scope="session",
    params=[
        pytest.param("numpy.ndarray"),
        pytest.param("scipy.sparse.csr_array", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csc_array", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csr_matrix", marks=_skip_if_no("scipy")),
        pytest.param("scipy.sparse.csc_matrix", marks=_skip_if_no("scipy")),
        pytest.param("dask.array.Array[numpy.ndarray]", marks=_skip_if_no("dask")),
        # Sparse-backed dask arrays import scipy as well, so skip when either is missing.
        pytest.param(
            "dask.array.Array[scipy.sparse.csr_array]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csc_array]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csr_matrix]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param(
            "dask.array.Array[scipy.sparse.csc_matrix]",
            marks=[_skip_if_no("dask"), _skip_if_no("scipy")],
        ),
        pytest.param("h5py.Dataset", marks=_skip_if_no("h5py")),
        pytest.param("zarr.Array", marks=_skip_if_no("zarr")),
        pytest.param("cupy.ndarray", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csr_matrix", marks=_skip_if_no("cupy")),
        pytest.param("cupyx.scipy.sparse.csc_matrix", marks=_skip_if_no("cupy")),
    ],
)
def array_cls_name(request: pytest.FixtureRequest) -> str:
    """Fixture for the qualified name of a supported array class.

    Each parameter is a string such as ``"numpy.ndarray"`` or
    ``"dask.array.Array[scipy.sparse.csr_array]"``; parameters whose backing
    package(s) cannot be found are skipped.
    """
    return cast(str, request.param)
57+
58+
59+
@pytest.fixture(scope="session")
def array_cls(array_cls_name: str) -> type[_Array[Any]]:
    """Fixture resolving ``array_cls_name`` to the array class via ``get_array_cls``."""
    resolved = get_array_cls(array_cls_name)
    return resolved
63+
64+
65+
def get_array_cls(qualname: str) -> type[_Array[Any]]: # noqa: PLR0911
66+
"""Get a supported array class by qualname."""
67+
m = re.fullmatch(
68+
r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?", qualname
69+
)
70+
assert m
71+
match m["mod"], m["name"], m["inner"]:
72+
case "numpy", "ndarray", None:
73+
return np.ndarray
74+
case "scipy.sparse", (
75+
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
76+
) as cls_name, None:
77+
import scipy.sparse
78+
79+
return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return]
80+
case "cupy", "ndarray", None:
81+
import cupy as cp
82+
83+
return cp.ndarray # type: ignore[no-any-return]
84+
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
85+
import cupyx.scipy.sparse as cu_sparse
86+
87+
return getattr(cu_sparse, cls_name) # type: ignore[no-any-return]
88+
case "dask.array", cls_name, _:
89+
if TYPE_CHECKING:
90+
from dask.array.core import Array as DaskArray
91+
else:
92+
from dask.array import Array as DaskArray
93+
94+
return DaskArray
95+
case "h5py", "Dataset", _:
96+
import h5py
97+
98+
return h5py.Dataset # type: ignore[no-any-return]
99+
case "zarr", "Array", _:
100+
import zarr
101+
102+
return zarr.Array
103+
case _:
104+
pytest.fail(f"Unknown array class: {qualname}")
105+
106+
107+
@pytest.fixture(scope="session")
def to_array(
    request: pytest.FixtureRequest, array_cls: type[_Array[_SCT_co]], array_cls_name: str
) -> _ToArray[_SCT_co]:
    """Fixture returning the converter that ``get_to_array`` builds for the current array class."""
    converter = get_to_array(array_cls, array_cls_name=array_cls_name, request=request)
    return converter
113+
114+
115+
def get_to_array(
    array_cls: type[_Array[_SCT_co]],
    array_cls_name: str | None = None,
    request: pytest.FixtureRequest | None = None,
) -> _ToArray[_SCT_co]:
    """Pick the conversion function that produces instances of ``array_cls``.

    ``array_cls_name`` must be given for dask arrays (it is forwarded to
    ``to_dask_array``), and ``request`` must be given for h5py datasets (the
    ``to_h5py_dataset`` fixture is looked up through it). Any class without a
    dedicated branch is returned unchanged and used as its own converter.
    """
    # In-memory numpy arrays: plain ``np.asarray``.
    if array_cls is np.ndarray:
        return np.asarray  # type: ignore[return-value]

    # Dask arrays: the qualified name is handed on to ``to_dask_array``.
    if array_cls is types.DaskArray:
        assert array_cls_name is not None
        return to_dask_array(array_cls_name)

    # h5py datasets: the converter is provided by a fixture.
    if array_cls is types.H5Dataset:
        assert request is not None
        return request.getfixturevalue("to_h5py_dataset")  # type: ignore[no-any-return]

    if array_cls is types.ZarrArray:
        return to_zarr_array

    if array_cls is types.CupyArray:
        import cupy as cp

        return cp.asarray  # type: ignore[no-any-return]

    # Everything else: call the class itself.
    return array_cls  # type: ignore[return-value]
137+
138+
139+
def _half_chunk_size(a: tuple[int, ...]) -> tuple[int, ...]:
140+
def half_rounded_up(x: int) -> int:
141+
div, mod = divmod(x, 2)
142+
return div + (mod > 0)
143+
144+
return tuple(half_rounded_up(x) for x in a)
145+
146+
147+
def to_dask_array(array_cls_name: str) -> _ToArray[Any]:
    """Build a converter that wraps data in a dask array.

    The class named inside ``dask.array.Array[...]`` is resolved with
    ``get_array_cls``; its converter from ``get_to_array`` is applied to the
    data before it is passed to ``da.from_array``.
    """
    if TYPE_CHECKING:
        import dask.array.core as da
    else:
        import dask.array as da

    # Strip the ``dask.array.Array[...]`` wrapper to get the inner class name.
    inner_name = array_cls_name.removeprefix("dask.array.Array[").removesuffix("]")
    to_inner: _ToArray[Any] = get_to_array(array_cls=get_array_cls(inner_name))

    def convert(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.DaskArray:
        dense = np.asarray(x, dtype=dtype)
        inner = to_inner(dense)
        # Chunk sizes are half of each dimension (rounded up).
        chunks = _half_chunk_size(dense.shape)
        return da.from_array(inner, chunks)  # type: ignore[no-untyped-call,no-any-return]

    return convert
163+
164+
165+
@pytest.fixture(scope="session")
# worker_id for xdist since we don't want to override open files
def to_h5py_dataset(
    tmp_path_factory: pytest.TempPathFactory,
    worker_id: str = "serial",
) -> Generator[_ToArray[Any], None, None]:
    """Yield a converter that stores data as h5py datasets.

    One file per ``worker_id`` is created in a fresh temporary directory and
    stays open until the fixture is torn down. Each converted array becomes a
    dataset named after the test read from ``PYTEST_CURRENT_TEST``.
    """
    import h5py

    h5_path = tmp_path_factory.mktemp("backed_adata") / f"test_{worker_id}.h5ad"

    with h5py.File(h5_path, "x") as f:

        def convert(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.H5Dataset:
            arr = np.asarray(x, dtype=dtype)
            # Keep the text after the last ":" and before the first following space.
            current = os.environ["PYTEST_CURRENT_TEST"]
            after_colon = current.rsplit(":", 1)[-1]
            test_name = after_colon.split(" ", 1)[0]
            # NOTE(review): the name depends only on the current test, so a second call
            # within one test reuses it - confirm how h5py handles the duplicate.
            return f.create_dataset(test_name, arr.shape, arr.dtype, data=arr)

        yield convert
185+
186+
187+
def to_zarr_array(x: ArrayLike, *, dtype: DTypeLike | None = None) -> types.ZarrArray:
    """Copy ``x`` into a new zarr array created on a fresh ``dict`` store."""
    import zarr

    source = np.asarray(x, dtype=dtype)
    target = zarr.create_array({}, shape=source.shape, dtype=source.dtype)
    # Fill the whole array in one assignment.
    target[...] = source
    return target

0 commit comments

Comments
 (0)