diff --git a/pyproject.toml b/pyproject.toml index 3fde3ce..5515447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ dynamic = [ "description", "version" ] dependencies = [ "numba", "numpy" ] optional-dependencies.doc = [ "furo", "scanpydoc>=0.15.2", "sphinx>=8", "sphinx-autodoc-typehints" ] -optional-dependencies.full = [ "dask", "h5py", "hatch-docstring-description[sparse]" ] +optional-dependencies.full = [ "dask", "h5py", "hatch-docstring-description[sparse]", "zarr" ] optional-dependencies.sparse = [ "scipy>=1.8", "types-scipy-sparse" ] optional-dependencies.test = [ "coverage[toml]", "pytest" ] urls.'Documentation' = "https://icb-fast-array-utils.readthedocs-hosted.com/" diff --git a/src/fast_array_utils/types.py b/src/fast_array_utils/types.py index 9c49bc6..c1fe15a 100644 --- a/src/fast_array_utils/types.py +++ b/src/fast_array_utils/types.py @@ -13,7 +13,15 @@ import numpy as np -__all__ = ["CSBase", "CupyArray", "CupySparseMatrix", "DaskArray", "H5Dataset", "OutOfCoreDataset"] +__all__ = [ + "CSBase", + "CupyArray", + "CupySparseMatrix", + "DaskArray", + "H5Dataset", + "OutOfCoreDataset", + "ZarrArray", +] T_co = TypeVar("T_co", covariant=True) @@ -69,6 +77,12 @@ H5Dataset = type("Dataset", (), {}) +if find_spec("zarr") or TYPE_CHECKING: + from zarr import Array as ZarrArray +else: + ZarrArray = type("Array", (), {}) + + @runtime_checkable class OutOfCoreDataset(Protocol, Generic[T_co]): """An out-of-core dataset.""" diff --git a/tests/test_asarray.py b/tests/test_asarray.py index 40b9176..8939644 100644 --- a/tests/test_asarray.py +++ b/tests/test_asarray.py @@ -12,12 +12,21 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator - from typing import Any + from typing import Any, TypeAlias from numpy.typing import ArrayLike, NDArray from fast_array_utils import types + SupportedArray: TypeAlias = ( + NDArray[Any] + | types.DaskArray + | types.H5Dataset + | types.ZarrArray + | types.CupyArray + | types.CupySparseMatrix + ) + def skip_if_no(dist: str) -> pytest.MarkDecorator: return pytest.mark.skipif(not find_spec(dist), reason=f"{dist} not installed") @@ -42,6 +51,15 @@ def to_h5py_dataset(x: ArrayLike) -> types.H5Dataset: yield to_h5py_dataset +def to_zarr_array(x: ArrayLike) -> types.ZarrArray: + import zarr + + arr = np.asarray(x) + za = zarr.create_array({}, shape=arr.shape, dtype=arr.dtype) + za[...] = arr + return za + + @pytest.fixture( scope="session", params=[ @@ -52,17 +70,15 @@ def to_h5py_dataset(x: ArrayLike) -> types.H5Dataset: pytest.param("scipy.sparse.csc_matrix", marks=skip_if_no("scipy")), pytest.param("dask.array.Array", marks=skip_if_no("dask")), pytest.param("h5py.Dataset", marks=skip_if_no("h5py")), + pytest.param("zarr.Array", marks=skip_if_no("zarr")), pytest.param("cupy.ndarray", marks=skip_if_no("cupy")), pytest.param("cupyx.scipy.sparse.csr_matrix", marks=skip_if_no("cupy")), pytest.param("cupyx.scipy.sparse.csc_matrix", marks=skip_if_no("cupy")), ], ) -def array_cls( +def array_cls( # noqa: PLR0911 request: pytest.FixtureRequest, -) -> Callable[ - [ArrayLike], - NDArray[Any] | types.DaskArray | types.H5Dataset | types.CupyArray | types.CupySparseMatrix, -]: +) -> Callable[[ArrayLike], SupportedArray]: qualname = cast(str, request.param) match qualname.split("."): case "numpy", "ndarray": @@ -77,6 +93,8 @@ def array_cls( return da.asarray # type: ignore[no-any-return] case "h5py", "Dataset": return request.getfixturevalue("to_h5py_dataset") # type: ignore[no-any-return] + case "zarr", "Array": + return to_zarr_array case "cupy", "ndarray": import cupy @@ -90,7 +108,7 @@ def array_cls( raise AssertionError(msg) -def test_asarray(array_cls: Callable[[ArrayLike], Any]) -> None: +def test_asarray(array_cls: Callable[[ArrayLike], SupportedArray]) -> None: x = array_cls([[1, 2, 3], [4, 5, 6]]) arr: NDArray[Any] = asarray(x) assert isinstance(arr, np.ndarray)