Commit 2f15f79

(feat): calculate_qc_metrics with dask (#3307)
Co-authored-by: Philipp A. <[email protected]>
1 parent 5c0e89e · commit 2f15f79

5 files changed, +208 -86 lines changed


docs/release-notes/3307.feature.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Add support for {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold`
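
Not part of the commit, but a minimal usage sketch of what this note describes: running QC metrics on a dask-backed AnnData. The shapes, chunking, and random data below are illustrative assumptions.

```python
# Hypothetical usage sketch: QC metrics on a dask-backed AnnData.
import dask.array as da
import numpy as np
from anndata import AnnData

import scanpy as sc

rng = np.random.default_rng(0)
X = da.from_array(rng.poisson(1.0, size=(1000, 200)), chunks=(100, 200))
adata = AnnData(X=X)

# Same call as for in-memory arrays; metrics land in .obs and .var.
sc.pp.calculate_qc_metrics(adata, percent_top=[50, 100], log1p=True, inplace=True)
print(adata.obs["n_genes_by_counts"].head())
```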

src/scanpy/_utils/__init__.py

Lines changed: 22 additions & 1 deletion
@@ -54,7 +54,7 @@
 from typing import Any, TypeVar
 
 from anndata import AnnData
-from numpy.typing import DTypeLike, NDArray
+from numpy.typing import ArrayLike, DTypeLike, NDArray
 
 from ..neighbors import NeighborsParams, RPForestDict
 
@@ -738,6 +738,27 @@ def _(
     )
 
 
+@singledispatch
+def axis_nnz(X: ArrayLike, axis: Literal[0, 1]) -> np.ndarray:
+    return np.count_nonzero(X, axis=axis)
+
+
+@axis_nnz.register(sparse.spmatrix)
+def _(X: sparse.spmatrix, axis: Literal[0, 1]) -> np.ndarray:
+    return X.getnnz(axis=axis)
+
+
+@axis_nnz.register(DaskArray)
+def _(X: DaskArray, axis: Literal[0, 1]) -> DaskArray:
+    return X.map_blocks(
+        partial(axis_nnz, axis=axis),
+        dtype=np.int64,
+        meta=np.array([], dtype=np.int64),
+        drop_axis=0,
+        chunks=len(X.to_delayed()) * (X.chunksize[int(not axis)],),
+    )
+
+
 @overload
 def axis_sum(
     X: sparse.spmatrix,
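
A toy illustration of the new `axis_nnz` helper (import path as added in this file): the dense and sparse registrations return eager per-row nonzero counts, while the `DaskArray` registration maps the count over blocks and stays lazy until computed. The data here is made up.

```python
# Small sketch of axis_nnz dispatching on input type (toy data).
import dask.array as da
import numpy as np
from scipy import sparse

from scanpy._utils import axis_nnz

dense = np.array([[0, 1, 2], [0, 0, 3]])
print(axis_nnz(dense, axis=1))                     # [2 1]
print(axis_nnz(sparse.csr_matrix(dense), axis=1))  # [2 1]

# The DaskArray registration returns a lazy array; compute() materializes it.
lazy = axis_nnz(da.from_array(dense, chunks=(1, 3)), axis=1)
print(lazy.compute())                              # [2 1]
```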

src/scanpy/preprocessing/_qc.py

Lines changed: 50 additions & 37 deletions
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
+from functools import singledispatch
 from typing import TYPE_CHECKING
 from warnings import warn
 
 import numba
 import numpy as np
 import pandas as pd
-from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr
-from sklearn.utils.sparsefuncs import mean_variance_axis
+from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr, spmatrix
 
-from .._utils import _doc_params
+from scanpy.preprocessing._distributed import materialize_as_ndarray
+from scanpy.preprocessing._utils import _get_mean_var
+
+from .._compat import DaskArray
+from .._utils import _doc_params, axis_nnz, axis_sum
 from ._docs import (
     doc_adata_basic,
     doc_expr_reps,
@@ -23,7 +27,6 @@
     from collections.abc import Collection
 
     from anndata import AnnData
-    from scipy.sparse import spmatrix
 
 
 def _choose_mtx_rep(adata, *, use_raw: bool = False, layer: str | None = None):
@@ -104,15 +107,14 @@ def describe_obs(
     if issparse(X):
         X.eliminate_zeros()
     obs_metrics = pd.DataFrame(index=adata.obs_names)
-    if issparse(X):
-        obs_metrics[f"n_{var_type}_by_{expr_type}"] = X.getnnz(axis=1)
-    else:
-        obs_metrics[f"n_{var_type}_by_{expr_type}"] = np.count_nonzero(X, axis=1)
+    obs_metrics[f"n_{var_type}_by_{expr_type}"] = materialize_as_ndarray(
+        axis_nnz(X, axis=1)
+    )
     if log1p:
         obs_metrics[f"log1p_n_{var_type}_by_{expr_type}"] = np.log1p(
             obs_metrics[f"n_{var_type}_by_{expr_type}"]
         )
-    obs_metrics[f"total_{expr_type}"] = np.ravel(X.sum(axis=1))
+    obs_metrics[f"total_{expr_type}"] = np.ravel(axis_sum(X, axis=1))
     if log1p:
         obs_metrics[f"log1p_total_{expr_type}"] = np.log1p(
             obs_metrics[f"total_{expr_type}"]
@@ -126,7 +128,7 @@ def describe_obs(
     )
     for qc_var in qc_vars:
         obs_metrics[f"total_{expr_type}_{qc_var}"] = np.ravel(
-            X[:, adata.var[qc_var].values].sum(axis=1)
+            axis_sum(X[:, adata.var[qc_var].values], axis=1)
         )
         if log1p:
             obs_metrics[f"log1p_total_{expr_type}_{qc_var}"] = np.log1p(
@@ -141,6 +143,7 @@ def describe_obs(
         adata.obs[obs_metrics.columns] = obs_metrics
     else:
         return obs_metrics
+    return None
 
 
 @_doc_params(
@@ -191,21 +194,17 @@ def describe_var(
     if issparse(X):
         X.eliminate_zeros()
     var_metrics = pd.DataFrame(index=adata.var_names)
-    if issparse(X):
-        # Current memory bottleneck for csr matrices:
-        var_metrics["n_cells_by_{expr_type}"] = X.getnnz(axis=0)
-        var_metrics["mean_{expr_type}"] = mean_variance_axis(X, axis=0)[0]
-    else:
-        var_metrics["n_cells_by_{expr_type}"] = np.count_nonzero(X, axis=0)
-        var_metrics["mean_{expr_type}"] = X.mean(axis=0)
+    var_metrics["n_cells_by_{expr_type}"], var_metrics["mean_{expr_type}"] = (
+        materialize_as_ndarray((axis_nnz(X, axis=0), _get_mean_var(X, axis=0)[0]))
+    )
     if log1p:
         var_metrics["log1p_mean_{expr_type}"] = np.log1p(
             var_metrics["mean_{expr_type}"]
         )
     var_metrics["pct_dropout_by_{expr_type}"] = (
         1 - var_metrics["n_cells_by_{expr_type}"] / X.shape[0]
     ) * 100
-    var_metrics["total_{expr_type}"] = np.ravel(X.sum(axis=0))
+    var_metrics["total_{expr_type}"] = np.ravel(axis_sum(X, axis=0))
     if log1p:
         var_metrics["log1p_total_{expr_type}"] = np.log1p(
             var_metrics["total_{expr_type}"]
@@ -217,8 +216,8 @@ def describe_var(
     var_metrics.columns = new_colnames
     if inplace:
         adata.var[var_metrics.columns] = var_metrics
-    else:
-        return var_metrics
+        return None
+    return var_metrics
 
 
 @_doc_params(
@@ -387,9 +386,18 @@ def top_proportions_sparse_csr(data, indptr, n):
     return values
 
 
-def top_segment_proportions(
-    mtx: np.ndarray | spmatrix, ns: Collection[int]
-) -> np.ndarray:
+def check_ns(func):
+    def check_ns_inner(mtx: np.ndarray | spmatrix | DaskArray, ns: Collection[int]):
+        if not (max(ns) <= mtx.shape[1] and min(ns) > 0):
+            raise IndexError("Positions outside range of features.")
+        return func(mtx, ns)
+
+    return check_ns_inner
+
+
+@singledispatch
+@check_ns
+def top_segment_proportions(mtx: np.ndarray, ns: Collection[int]) -> np.ndarray:
     """
     Calculates total percentage of counts in top ns genes.
@@ -402,20 +410,6 @@ def top_segment_proportions(
         1-indexed, e.g. `ns=[50]` will calculate cumulative proportion up to
         the 50th most expressed gene.
     """
-    # Pretty much just does dispatch
-    if not (max(ns) <= mtx.shape[1] and min(ns) > 0):
-        raise IndexError("Positions outside range of features.")
-    if issparse(mtx):
-        if not isspmatrix_csr(mtx):
-            mtx = csr_matrix(mtx)
-        return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))
-    else:
-        return top_segment_proportions_dense(mtx, ns)
-
-
-def top_segment_proportions_dense(
-    mtx: np.ndarray | spmatrix, ns: Collection[int]
-) -> np.ndarray:
     # Currently ns is considered to be 1 indexed
     ns = np.sort(ns)
     sums = mtx.sum(axis=1)
@@ -432,6 +426,25 @@ def top_segment_proportions_dense(
     return values / sums[:, None]
 
 
+@top_segment_proportions.register(DaskArray)
+@check_ns
+def _(mtx: DaskArray, ns: Collection[int]) -> np.ndarray:
+    if not isinstance(mtx._meta, csr_matrix | np.ndarray):
+        msg = f"DaskArray must have csr matrix or ndarray meta, got {mtx._meta}."
+        raise ValueError(msg)
+    return mtx.map_blocks(
+        lambda x: top_segment_proportions(x, ns), meta=np.array([])
+    ).compute()
+
+
+@top_segment_proportions.register(spmatrix)
+@check_ns
+def _(mtx: spmatrix, ns: Collection[int]) -> np.ndarray:
+    if not isspmatrix_csr(mtx):
+        mtx = csr_matrix(mtx)
+    return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))
+
+
 @numba.njit(cache=True, parallel=True)
 def top_segment_proportions_sparse_csr(data, indptr, ns):
     # work around https://github.com/numba/numba/issues/5056
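
Worth noting about the decorator order above: the base function becomes `singledispatch(check_ns(top_segment_proportions))`, and each registered implementation re-applies `@check_ns`, so the bounds check on `ns` runs whichever backend dispatch selects. A toy call (internal module path as in this commit) gives identical results for dense and sparse inputs:

```python
# Proportion of counts in the top-1 and top-2 genes per cell (toy matrix).
import numpy as np
from scipy.sparse import csr_matrix

from scanpy.preprocessing._qc import top_segment_proportions

mtx = np.array([[3.0, 1.0, 0.0], [1.0, 1.0, 2.0]])
print(top_segment_proportions(mtx, ns=[1, 2]))
# Row 0: top-1 = 3/4 = 0.75, top-2 = 4/4 = 1.0
# Row 1: top-1 = 2/4 = 0.5,  top-2 = 3/4 = 0.75
print(top_segment_proportions(csr_matrix(mtx), ns=[1, 2]))  # same values
```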

src/testing/scanpy/_helpers/__init__.py

Lines changed: 23 additions & 1 deletion
@@ -5,8 +5,9 @@
 from __future__ import annotations
 
 import warnings
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, contextmanager
 from dataclasses import dataclass
+from importlib.util import find_spec
 from itertools import permutations
 from typing import TYPE_CHECKING
 
@@ -158,3 +159,24 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, traceback):
         for ctx in reversed(self.contexts):
             ctx.__exit__(exc_type, exc_value, traceback)
+
+
+@contextmanager
+def maybe_dask_process_context():
+    """
+    Running numba with dask's threaded scheduler causes crashes,
+    so we need to switch to single-threaded (or processes, which is slower)
+    scheduler for tests that use numba.
+    """
+    if not find_spec("dask"):
+        yield
+        return
+
+    import dask.config
+
+    prev_scheduler = dask.config.get("scheduler", "threads")
+    dask.config.set(scheduler="single-threaded")
+    try:
+        yield
+    finally:
+        dask.config.set(scheduler=prev_scheduler)
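
A hypothetical test sketch (test name and data made up) of the intended use: the body runs under dask's single-threaded scheduler so numba kernels are safe, and the previous scheduler is restored on exit.

```python
# Sketch: exercising numba-backed QC code on a dask array inside the context.
import dask.array as da
import numpy as np

from scanpy.preprocessing._qc import top_segment_proportions
from testing.scanpy._helpers import maybe_dask_process_context


def test_top_segment_proportions_dask():
    # Chunk only along rows so each block sees all columns.
    mtx = da.from_array(np.random.poisson(1.0, (40, 10)), chunks=(10, 10))
    with maybe_dask_process_context():
        props = top_segment_proportions(mtx, ns=[1, 5])
    assert props.shape == (40, 2)
```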
