From 9843e371651154520639b5067017772a0b0bf639 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 12:43:03 +0100 Subject: [PATCH 1/3] From fc61a8f00fbe6cece2d9eaddcd1cd2997cfb014e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 14:32:01 +0100 Subject: [PATCH 2/3] Fix warnings --- scanpy/_compat.py | 8 +++ scanpy/preprocessing/_distributed.py | 51 +++++++++++++---- scanpy/preprocessing/_simple.py | 8 +-- .../tests/test_preprocessing_distributed.py | 55 ++++++++++++------- 4 files changed, 88 insertions(+), 34 deletions(-) diff --git a/scanpy/_compat.py b/scanpy/_compat.py index 244f8588fa..08fd53bd03 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -22,6 +22,14 @@ class DaskArray: pass +try: + from zappy.base import ZappyArray +except ImportError: + + class ZappyArray: + pass + + __all__ = ["cache", "DaskArray", "fullname", "pkg_metadata", "pkg_version"] diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index a134efe758..748ec3d671 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,18 +1,49 @@ from __future__ import annotations +from typing import TYPE_CHECKING, overload + import numpy as np -# install dask if available -try: - import dask.array as da -except ImportError: - da = None +from scanpy._compat import DaskArray, ZappyArray + +if TYPE_CHECKING: + from numpy.typing import ArrayLike + + +@overload +def materialize_as_ndarray(a: ArrayLike) -> np.ndarray: + ... + + +@overload +def materialize_as_ndarray(a: tuple[ArrayLike]) -> tuple[np.ndarray]: + ... + +@overload +def materialize_as_ndarray( + a: tuple[ArrayLike, ArrayLike], +) -> tuple[np.ndarray, np.ndarray]: + ... -def materialize_as_ndarray(a): + +@overload +def materialize_as_ndarray( + a: tuple[ArrayLike, ArrayLike, ArrayLike], +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + ... + + +def materialize_as_ndarray( + a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...], +) -> tuple[np.ndarray] | np.ndarray: """Convert distributed arrays to ndarrays.""" - if type(a) in (list, tuple): - if da is not None and any(isinstance(arr, da.Array) for arr in a): - return da.compute(*a, sync=True) + if not isinstance(a, tuple): + return np.asarray(a) + + if not any(isinstance(arr, DaskArray) for arr in a): return tuple(np.asarray(arr) for arr in a) - return np.asarray(a) + + import dask.array as da + + return da.compute(*a, sync=True) diff --git a/scanpy/preprocessing/_simple.py b/scanpy/preprocessing/_simple.py index 0da7f9e961..a4a1efa72a 100644 --- a/scanpy/preprocessing/_simple.py +++ b/scanpy/preprocessing/_simple.py @@ -354,7 +354,7 @@ def log1p( @log1p.register(spmatrix) -def log1p_sparse(X, *, base: Number | None = None, copy: bool = False): +def log1p_sparse(X: spmatrix, *, base: Number | None = None, copy: bool = False): X = check_array( X, accept_sparse=("csr", "csc"), dtype=(np.float64, np.float32), copy=copy ) @@ -363,7 +363,7 @@ def log1p_sparse(X, *, base: Number | None = None, copy: bool = False): @log1p.register(np.ndarray) -def log1p_array(X, *, base: Number | None = None, copy: bool = False): +def log1p_array(X: np.ndarray, *, base: Number | None = None, copy: bool = False): # Can force arrays to be np.ndarrays, but would be useful to not # X = check_array(X, dtype=(np.float64, np.float32), ensure_2d=False, copy=copy) if copy: @@ -381,7 +381,7 @@ def log1p_array(X, *, base: Number | None = None, copy: bool = False): @log1p.register(AnnData) def log1p_anndata( - adata, + adata: AnnData, *, base: Number | None = None, copy: bool = False, @@ -564,7 +564,7 @@ def normalize_per_cell( # noqa: PLR0917 else: raise ValueError('use_rep should be "after", "X" or None') for layer in layers: - subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts) + _subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts) temp = normalize_per_cell(adata.layers[layer], after, counts, copy=True) adata.layers[layer] = temp diff --git a/scanpy/tests/test_preprocessing_distributed.py b/scanpy/tests/test_preprocessing_distributed.py index 7f5dce7db0..77d3297b37 100644 --- a/scanpy/tests/test_preprocessing_distributed.py +++ b/scanpy/tests/test_preprocessing_distributed.py @@ -2,10 +2,11 @@ from pathlib import Path -import anndata as ad import numpy.testing as npt import pytest +from anndata import AnnData, OldFormatWarning, read_zarr +from scanpy._compat import DaskArray, ZappyArray from scanpy.preprocessing import ( filter_cells, filter_genes, @@ -17,15 +18,19 @@ from scanpy.testing._pytest.marks import needs HERE = Path(__file__).parent / Path("_data/") -input_file = str(Path(HERE, "10x-10k-subset.zarr")) +input_file = Path(HERE, "10x-10k-subset.zarr") + +DIST_TYPES = (DaskArray, ZappyArray) pytestmark = [needs.zarr] @pytest.fixture() -def adata(): - a = ad.read_zarr(input_file) # regular anndata +def adata() -> AnnData: + with pytest.warns(OldFormatWarning): + a = read_zarr(input_file) # regular anndata + a.var_names_make_unique() a.X = a.X[:] # convert to numpy array return a @@ -36,25 +41,29 @@ def adata(): pytest.param("dask", marks=[needs.dask]), ] ) -def adata_dist(request): +def adata_dist(request: pytest.FixtureRequest) -> AnnData: # regular anndata except for X, which we replace on the next line - a = ad.read_zarr(input_file) + with pytest.warns(OldFormatWarning): + a = read_zarr(input_file) + a.var_names_make_unique() a.uns["dist-mode"] = request.param input_file_X = f"{input_file}/X" if request.param == "direct": import zappy.direct a.X = zappy.direct.from_zarr(input_file_X) - yield a - elif request.param == "dask": - import dask.array as da + return a + + assert request.param == "dask" + import dask.array as da - a.X = da.from_zarr(input_file_X) - yield a + a.X = da.from_zarr(input_file_X) + return a -def test_log1p(adata, adata_dist): +def test_log1p(adata: AnnData, adata_dist: AnnData): log1p(adata_dist) + assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) log1p(adata) assert result.shape == adata.shape @@ -62,10 +71,9 @@ def test_log1p(adata, adata_dist): npt.assert_allclose(result, adata.X) -def test_normalize_per_cell(adata, adata_dist): - if adata_dist.uns["dist-mode"] == "dask": - pytest.xfail("TODO: Test broken for dask") +def test_normalize_per_cell(adata: AnnData, adata_dist: AnnData): normalize_per_cell(adata_dist) + assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) normalize_per_cell(adata) assert result.shape == adata.shape @@ -73,8 +81,9 @@ def test_normalize_per_cell(adata, adata_dist): npt.assert_allclose(result, adata.X) -def test_normalize_total(adata, adata_dist): +def test_normalize_total(adata: AnnData, adata_dist: AnnData): normalize_total(adata_dist) + assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) normalize_total(adata) assert result.shape == adata.shape @@ -82,8 +91,9 @@ def test_normalize_total(adata, adata_dist): npt.assert_allclose(result, adata.X) -def test_filter_cells(adata, adata_dist): +def test_filter_cells(adata: AnnData, adata_dist: AnnData): filter_cells(adata_dist, min_genes=3) + assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) filter_cells(adata, min_genes=3) assert result.shape == adata.shape @@ -91,8 +101,9 @@ def test_filter_cells(adata, adata_dist): npt.assert_allclose(result, adata.X) -def test_filter_genes(adata, adata_dist): +def test_filter_genes(adata: AnnData, adata_dist: AnnData): filter_genes(adata_dist, min_cells=2) + assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) filter_genes(adata, min_cells=2) assert result.shape == adata.shape @@ -100,14 +111,16 @@ def test_filter_genes(adata, adata_dist): npt.assert_allclose(result, adata.X) -def test_write_zarr(adata, adata_dist): +def test_write_zarr(adata: AnnData, adata_dist: AnnData): import zarr log1p(adata_dist) + assert isinstance(adata_dist.X, DIST_TYPES) temp_store = zarr.TempStore() chunks = adata_dist.X.chunks if isinstance(chunks[0], tuple): chunks = (chunks[0][0],) + chunks[1] + # write metadata using regular anndata adata.write_zarr(temp_store, chunks) if adata_dist.uns["dist-mode"] == "dask": @@ -116,7 +129,9 @@ def test_write_zarr(adata, adata_dist): adata_dist.X.to_zarr(temp_store.dir_path("X"), chunks) else: assert False, "add branch for new dist-mode" + # read back as zarr directly and check it is the same as adata.X - adata_log1p = ad.read_zarr(temp_store) + with pytest.warns(OldFormatWarning, match="without encoding metadata"): + adata_log1p = read_zarr(temp_store) log1p(adata) npt.assert_allclose(adata_log1p.X, adata.X) From b4a913df278057e679547324db9aef008dc98d53 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 15:21:19 +0100 Subject: [PATCH 3/3] fix log --- scanpy/preprocessing/_simple.py | 4 +- .../tests/test_preprocessing_distributed.py | 37 ++++++++++++++++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/scanpy/preprocessing/_simple.py b/scanpy/preprocessing/_simple.py index a4a1efa72a..ad42105bc7 100644 --- a/scanpy/preprocessing/_simple.py +++ b/scanpy/preprocessing/_simple.py @@ -172,7 +172,7 @@ def filter_cells( if max_number is not None: cell_subset = number_per_cell <= max_number - s = np.sum(~cell_subset) + s = materialize_as_ndarray(np.sum(~cell_subset)) if s > 0: msg = f"filtered out {s} cells that have " if min_genes is not None or min_counts is not None: @@ -589,7 +589,7 @@ def normalize_per_cell( # noqa: PLR0917 counts_per_cell += counts_per_cell == 0 counts_per_cell /= counts_per_cell_after if not issparse(X): - X /= materialize_as_ndarray(counts_per_cell[:, np.newaxis]) + X /= counts_per_cell[:, np.newaxis] else: sparsefuncs.inplace_row_scale(X, 1 / counts_per_cell) return X if copy else None diff --git a/scanpy/tests/test_preprocessing_distributed.py b/scanpy/tests/test_preprocessing_distributed.py index 77d3297b37..aec8044d4c 100644 --- a/scanpy/tests/test_preprocessing_distributed.py +++ b/scanpy/tests/test_preprocessing_distributed.py @@ -67,17 +67,23 @@ def test_log1p(adata: AnnData, adata_dist: AnnData): result = materialize_as_ndarray(adata_dist.X) log1p(adata) assert result.shape == adata.shape - assert result.shape == (adata.n_obs, adata.n_vars) npt.assert_allclose(result, adata.X) -def test_normalize_per_cell(adata: AnnData, adata_dist: AnnData): +def test_normalize_per_cell( + request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData +): + if isinstance(adata_dist.X, DaskArray): + request.node.add_marker( + pytest.mark.xfail( + reason="normalize_per_cell deprecated and broken for Dask" + ) + ) normalize_per_cell(adata_dist) assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) normalize_per_cell(adata) assert result.shape == adata.shape - assert result.shape == (adata.n_obs, adata.n_vars) npt.assert_allclose(result, adata.X) @@ -87,27 +93,46 @@ def test_normalize_total(adata: AnnData, adata_dist: AnnData): result = materialize_as_ndarray(adata_dist.X) normalize_total(adata) assert result.shape == adata.shape - assert result.shape == (adata.n_obs, adata.n_vars) npt.assert_allclose(result, adata.X) +def test_filter_cells_array(adata: AnnData, adata_dist: AnnData): + cell_subset_dist, number_per_cell_dist = filter_cells(adata_dist.X, min_genes=3) + assert isinstance(cell_subset_dist, DIST_TYPES) + assert isinstance(number_per_cell_dist, DIST_TYPES) + + cell_subset, number_per_cell = filter_cells(adata.X, min_genes=3) + npt.assert_allclose(materialize_as_ndarray(cell_subset_dist), cell_subset) + npt.assert_allclose(materialize_as_ndarray(number_per_cell_dist), number_per_cell) + + def test_filter_cells(adata: AnnData, adata_dist: AnnData): filter_cells(adata_dist, min_genes=3) assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) filter_cells(adata, min_genes=3) + assert result.shape == adata.shape - assert result.shape == (adata.n_obs, adata.n_vars) + npt.assert_array_equal(adata_dist.obs["n_genes"], adata.obs["n_genes"]) npt.assert_allclose(result, adata.X) +def test_filter_genes_array(adata: AnnData, adata_dist: AnnData): + gene_subset_dist, number_per_gene_dist = filter_genes(adata_dist.X, min_cells=2) + assert isinstance(gene_subset_dist, DIST_TYPES) + assert isinstance(number_per_gene_dist, DIST_TYPES) + + gene_subset, number_per_gene = filter_genes(adata.X, min_cells=2) + npt.assert_allclose(materialize_as_ndarray(gene_subset_dist), gene_subset) + npt.assert_allclose(materialize_as_ndarray(number_per_gene_dist), number_per_gene) + + def test_filter_genes(adata: AnnData, adata_dist: AnnData): filter_genes(adata_dist, min_cells=2) assert isinstance(adata_dist.X, DIST_TYPES) result = materialize_as_ndarray(adata_dist.X) filter_genes(adata, min_cells=2) assert result.shape == adata.shape - assert result.shape == (adata.n_obs, adata.n_vars) npt.assert_allclose(result, adata.X)