Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class DaskArray:
pass


try:
from zappy.base import ZappyArray
except ImportError:

class ZappyArray:
pass


__all__ = ["cache", "DaskArray", "fullname", "pkg_metadata", "pkg_version"]


Expand Down
51 changes: 41 additions & 10 deletions scanpy/preprocessing/_distributed.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING, overload

import numpy as np

# install dask if available
try:
import dask.array as da
except ImportError:
da = None
from scanpy._compat import DaskArray, ZappyArray

if TYPE_CHECKING:
from numpy.typing import ArrayLike


@overload
def materialize_as_ndarray(a: ArrayLike) -> np.ndarray:
...


@overload
def materialize_as_ndarray(a: tuple[ArrayLike]) -> tuple[np.ndarray]:
...


@overload
def materialize_as_ndarray(
a: tuple[ArrayLike, ArrayLike],
) -> tuple[np.ndarray, np.ndarray]:
...

def materialize_as_ndarray(a):

@overload
def materialize_as_ndarray(
a: tuple[ArrayLike, ArrayLike, ArrayLike],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
...


def materialize_as_ndarray(
a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
"""Convert distributed arrays to ndarrays."""
if type(a) in (list, tuple):
if da is not None and any(isinstance(arr, da.Array) for arr in a):
return da.compute(*a, sync=True)
if not isinstance(a, tuple):
return np.asarray(a)

if not any(isinstance(arr, DaskArray) for arr in a):
return tuple(np.asarray(arr) for arr in a)
return np.asarray(a)

import dask.array as da

return da.compute(*a, sync=True)
12 changes: 6 additions & 6 deletions scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def filter_cells(
if max_number is not None:
cell_subset = number_per_cell <= max_number

s = np.sum(~cell_subset)
s = materialize_as_ndarray(np.sum(~cell_subset))
if s > 0:
msg = f"filtered out {s} cells that have "
if min_genes is not None or min_counts is not None:
Expand Down Expand Up @@ -354,7 +354,7 @@ def log1p(


@log1p.register(spmatrix)
def log1p_sparse(X, *, base: Number | None = None, copy: bool = False):
def log1p_sparse(X: spmatrix, *, base: Number | None = None, copy: bool = False):
X = check_array(
X, accept_sparse=("csr", "csc"), dtype=(np.float64, np.float32), copy=copy
)
Expand All @@ -363,7 +363,7 @@ def log1p_sparse(X, *, base: Number | None = None, copy: bool = False):


@log1p.register(np.ndarray)
def log1p_array(X, *, base: Number | None = None, copy: bool = False):
def log1p_array(X: np.ndarray, *, base: Number | None = None, copy: bool = False):
# Can force arrays to be np.ndarrays, but would be useful to not
# X = check_array(X, dtype=(np.float64, np.float32), ensure_2d=False, copy=copy)
if copy:
Expand All @@ -381,7 +381,7 @@ def log1p_array(X, *, base: Number | None = None, copy: bool = False):

@log1p.register(AnnData)
def log1p_anndata(
adata,
adata: AnnData,
*,
base: Number | None = None,
copy: bool = False,
Expand Down Expand Up @@ -564,7 +564,7 @@ def normalize_per_cell( # noqa: PLR0917
else:
raise ValueError('use_rep should be "after", "X" or None')
for layer in layers:
subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts)
_subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts)
temp = normalize_per_cell(adata.layers[layer], after, counts, copy=True)
adata.layers[layer] = temp

Expand All @@ -589,7 +589,7 @@ def normalize_per_cell( # noqa: PLR0917
counts_per_cell += counts_per_cell == 0
counts_per_cell /= counts_per_cell_after
if not issparse(X):
X /= materialize_as_ndarray(counts_per_cell[:, np.newaxis])
X /= counts_per_cell[:, np.newaxis]
else:
sparsefuncs.inplace_row_scale(X, 1 / counts_per_cell)
return X if copy else None
Expand Down
90 changes: 65 additions & 25 deletions scanpy/tests/test_preprocessing_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from pathlib import Path

import anndata as ad
import numpy.testing as npt
import pytest
from anndata import AnnData, OldFormatWarning, read_zarr

from scanpy._compat import DaskArray, ZappyArray
from scanpy.preprocessing import (
filter_cells,
filter_genes,
Expand All @@ -17,15 +18,19 @@
from scanpy.testing._pytest.marks import needs

HERE = Path(__file__).parent / Path("_data/")
input_file = str(Path(HERE, "10x-10k-subset.zarr"))
input_file = Path(HERE, "10x-10k-subset.zarr")

DIST_TYPES = (DaskArray, ZappyArray)


pytestmark = [needs.zarr]


@pytest.fixture()
def adata():
a = ad.read_zarr(input_file) # regular anndata
def adata() -> AnnData:
with pytest.warns(OldFormatWarning):
a = read_zarr(input_file) # regular anndata
a.var_names_make_unique()
a.X = a.X[:] # convert to numpy array
return a

Expand All @@ -36,78 +41,111 @@ def adata():
pytest.param("dask", marks=[needs.dask]),
]
)
def adata_dist(request):
def adata_dist(request: pytest.FixtureRequest) -> AnnData:
# regular anndata except for X, which we replace on the next line
a = ad.read_zarr(input_file)
with pytest.warns(OldFormatWarning):
a = read_zarr(input_file)
a.var_names_make_unique()
a.uns["dist-mode"] = request.param
input_file_X = f"{input_file}/X"
if request.param == "direct":
import zappy.direct

a.X = zappy.direct.from_zarr(input_file_X)
yield a
elif request.param == "dask":
import dask.array as da
return a

assert request.param == "dask"
import dask.array as da

a.X = da.from_zarr(input_file_X)
yield a
a.X = da.from_zarr(input_file_X)
return a


def test_log1p(adata, adata_dist):
def test_log1p(adata: AnnData, adata_dist: AnnData):
log1p(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
log1p(adata)
assert result.shape == adata.shape
assert result.shape == (adata.n_obs, adata.n_vars)
npt.assert_allclose(result, adata.X)


def test_normalize_per_cell(adata, adata_dist):
if adata_dist.uns["dist-mode"] == "dask":
pytest.xfail("TODO: Test broken for dask")
def test_normalize_per_cell(
request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData
):
if isinstance(adata_dist.X, DaskArray):
request.node.add_marker(
pytest.mark.xfail(
reason="normalize_per_cell deprecated and broken for Dask"
)
)
normalize_per_cell(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
normalize_per_cell(adata)
assert result.shape == adata.shape
assert result.shape == (adata.n_obs, adata.n_vars)
npt.assert_allclose(result, adata.X)


def test_normalize_total(adata, adata_dist):
def test_normalize_total(adata: AnnData, adata_dist: AnnData):
normalize_total(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
normalize_total(adata)
assert result.shape == adata.shape
assert result.shape == (adata.n_obs, adata.n_vars)
npt.assert_allclose(result, adata.X)


def test_filter_cells(adata, adata_dist):
def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):
cell_subset_dist, number_per_cell_dist = filter_cells(adata_dist.X, min_genes=3)
assert isinstance(cell_subset_dist, DIST_TYPES)
assert isinstance(number_per_cell_dist, DIST_TYPES)

cell_subset, number_per_cell = filter_cells(adata.X, min_genes=3)
npt.assert_allclose(materialize_as_ndarray(cell_subset_dist), cell_subset)
npt.assert_allclose(materialize_as_ndarray(number_per_cell_dist), number_per_cell)


def test_filter_cells(adata: AnnData, adata_dist: AnnData):
filter_cells(adata_dist, min_genes=3)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
filter_cells(adata, min_genes=3)

assert result.shape == adata.shape
assert result.shape == (adata.n_obs, adata.n_vars)
npt.assert_array_equal(adata_dist.obs["n_genes"], adata.obs["n_genes"])
npt.assert_allclose(result, adata.X)


def test_filter_genes(adata, adata_dist):
def test_filter_genes_array(adata: AnnData, adata_dist: AnnData):
gene_subset_dist, number_per_gene_dist = filter_genes(adata_dist.X, min_cells=2)
assert isinstance(gene_subset_dist, DIST_TYPES)
assert isinstance(number_per_gene_dist, DIST_TYPES)

gene_subset, number_per_gene = filter_genes(adata.X, min_cells=2)
npt.assert_allclose(materialize_as_ndarray(gene_subset_dist), gene_subset)
npt.assert_allclose(materialize_as_ndarray(number_per_gene_dist), number_per_gene)


def test_filter_genes(adata: AnnData, adata_dist: AnnData):
filter_genes(adata_dist, min_cells=2)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
filter_genes(adata, min_cells=2)
assert result.shape == adata.shape
assert result.shape == (adata.n_obs, adata.n_vars)
npt.assert_allclose(result, adata.X)


def test_write_zarr(adata, adata_dist):
def test_write_zarr(adata: AnnData, adata_dist: AnnData):
import zarr

log1p(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
temp_store = zarr.TempStore()
chunks = adata_dist.X.chunks
if isinstance(chunks[0], tuple):
chunks = (chunks[0][0],) + chunks[1]

# write metadata using regular anndata
adata.write_zarr(temp_store, chunks)
if adata_dist.uns["dist-mode"] == "dask":
Expand All @@ -116,7 +154,9 @@ def test_write_zarr(adata, adata_dist):
adata_dist.X.to_zarr(temp_store.dir_path("X"), chunks)
else:
assert False, "add branch for new dist-mode"

# read back as zarr directly and check it is the same as adata.X
adata_log1p = ad.read_zarr(temp_store)
with pytest.warns(OldFormatWarning, match="without encoding metadata"):
adata_log1p = read_zarr(temp_store)
log1p(adata)
npt.assert_allclose(adata_log1p.X, adata.X)