@@ -1,15 +1,19 @@
 from __future__ import annotations

+from functools import singledispatch
 from typing import TYPE_CHECKING
 from warnings import warn

 import numba
 import numpy as np
 import pandas as pd
-from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr
-from sklearn.utils.sparsefuncs import mean_variance_axis
+from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr, spmatrix

-from .._utils import _doc_params
+from scanpy.preprocessing._distributed import materialize_as_ndarray
+from scanpy.preprocessing._utils import _get_mean_var
+
+from .._compat import DaskArray
+from .._utils import _doc_params, axis_nnz, axis_sum
 from ._docs import (
     doc_adata_basic,
     doc_expr_reps,
@@ -23,7 +27,6 @@
     from collections.abc import Collection

     from anndata import AnnData
-    from scipy.sparse import spmatrix


 def _choose_mtx_rep(adata, *, use_raw: bool = False, layer: str | None = None):
@@ -104,15 +107,14 @@ def describe_obs(
     if issparse(X):
         X.eliminate_zeros()
     obs_metrics = pd.DataFrame(index=adata.obs_names)
-    if issparse(X):
-        obs_metrics[f"n_{var_type}_by_{expr_type}"] = X.getnnz(axis=1)
-    else:
-        obs_metrics[f"n_{var_type}_by_{expr_type}"] = np.count_nonzero(X, axis=1)
+    obs_metrics[f"n_{var_type}_by_{expr_type}"] = materialize_as_ndarray(
+        axis_nnz(X, axis=1)
+    )
     if log1p:
         obs_metrics[f"log1p_n_{var_type}_by_{expr_type}"] = np.log1p(
             obs_metrics[f"n_{var_type}_by_{expr_type}"]
         )
-    obs_metrics[f"total_{expr_type}"] = np.ravel(X.sum(axis=1))
+    obs_metrics[f"total_{expr_type}"] = np.ravel(axis_sum(X, axis=1))
     if log1p:
         obs_metrics[f"log1p_total_{expr_type}"] = np.log1p(
             obs_metrics[f"total_{expr_type}"]
@@ -126,7 +128,7 @@ def describe_obs(
         )
     for qc_var in qc_vars:
         obs_metrics[f"total_{expr_type}_{qc_var}"] = np.ravel(
-            X[:, adata.var[qc_var].values].sum(axis=1)
+            axis_sum(X[:, adata.var[qc_var].values], axis=1)
         )
         if log1p:
             obs_metrics[f"log1p_total_{expr_type}_{qc_var}"] = np.log1p(
@@ -141,6 +143,7 @@ def describe_obs(
         adata.obs[obs_metrics.columns] = obs_metrics
     else:
         return obs_metrics
+    return None


 @_doc_params(
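
For context, the dense/sparse branch that used to live in `describe_obs` is now hidden behind `axis_nnz` and `axis_sum`, which pick the right reduction for the container type (NumPy, scipy.sparse, or Dask), while `materialize_as_ndarray` turns a possibly lazy result into a plain array. A rough sketch of that dispatch idea, using a hypothetical `axis_nnz_sketch` name so it is not confused with the real helper in `scanpy._utils`:

from functools import singledispatch

import numpy as np
from scipy import sparse


@singledispatch
def axis_nnz_sketch(X: np.ndarray, axis: int) -> np.ndarray:
    # Dense arrays: count the non-zero entries along the requested axis.
    return np.count_nonzero(X, axis=axis)


@axis_nnz_sketch.register(sparse.spmatrix)
def _(X: sparse.spmatrix, axis: int) -> np.ndarray:
    # Sparse matrices already track this information via getnnz().
    return X.getnnz(axis=axis)
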
@@ -191,21 +194,17 @@ def describe_var(
     if issparse(X):
         X.eliminate_zeros()
     var_metrics = pd.DataFrame(index=adata.var_names)
-    if issparse(X):
-        # Current memory bottleneck for csr matrices:
-        var_metrics["n_cells_by_{expr_type}"] = X.getnnz(axis=0)
-        var_metrics["mean_{expr_type}"] = mean_variance_axis(X, axis=0)[0]
-    else:
-        var_metrics["n_cells_by_{expr_type}"] = np.count_nonzero(X, axis=0)
-        var_metrics["mean_{expr_type}"] = X.mean(axis=0)
+    var_metrics["n_cells_by_{expr_type}"], var_metrics["mean_{expr_type}"] = (
+        materialize_as_ndarray((axis_nnz(X, axis=0), _get_mean_var(X, axis=0)[0]))
+    )
     if log1p:
         var_metrics["log1p_mean_{expr_type}"] = np.log1p(
             var_metrics["mean_{expr_type}"]
         )
     var_metrics["pct_dropout_by_{expr_type}"] = (
         1 - var_metrics["n_cells_by_{expr_type}"] / X.shape[0]
     ) * 100
-    var_metrics["total_{expr_type}"] = np.ravel(X.sum(axis=0))
+    var_metrics["total_{expr_type}"] = np.ravel(axis_sum(X, axis=0))
     if log1p:
         var_metrics["log1p_total_{expr_type}"] = np.log1p(
             var_metrics["total_{expr_type}"]
@@ -217,8 +216,8 @@ def describe_var(
         var_metrics.columns = new_colnames
     if inplace:
         adata.var[var_metrics.columns] = var_metrics
-    else:
-        return var_metrics
+        return None
+    return var_metrics


 @_doc_params(
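
For a small dense matrix, the per-gene metrics computed in `describe_var` reduce to plain NumPy operations; the tuple handed to `materialize_as_ndarray` simply bundles the two reductions so that a Dask-backed `X` can evaluate them together (that pairing is an assumption about the helper, not something the diff spells out):

import numpy as np

X = np.array([[0.0, 2.0], [1.0, 0.0], [3.0, 4.0]])  # 3 cells x 2 genes
n_cells = np.count_nonzero(X, axis=0)            # array([2, 2])
mean = X.mean(axis=0)                            # array([1.333..., 2.])
pct_dropout = (1 - n_cells / X.shape[0]) * 100   # array([33.33..., 33.33...])
total = X.sum(axis=0)                            # array([4., 6.])
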
@@ -387,9 +386,18 @@ def top_proportions_sparse_csr(data, indptr, n):
     return values


-def top_segment_proportions(
-    mtx: np.ndarray | spmatrix, ns: Collection[int]
-) -> np.ndarray:
+def check_ns(func):
+    def check_ns_inner(mtx: np.ndarray | spmatrix | DaskArray, ns: Collection[int]):
+        if not (max(ns) <= mtx.shape[1] and min(ns) > 0):
+            raise IndexError("Positions outside range of features.")
+        return func(mtx, ns)
+
+    return check_ns_inner
+
+
+@singledispatch
+@check_ns
+def top_segment_proportions(mtx: np.ndarray, ns: Collection[int]) -> np.ndarray:
     """
     Calculates total percentage of counts in top ns genes.

@@ -402,20 +410,6 @@ def top_segment_proportions(
         1-indexed, e.g. `ns=[50]` will calculate cumulative proportion up to
         the 50th most expressed gene.
     """
-    # Pretty much just does dispatch
-    if not (max(ns) <= mtx.shape[1] and min(ns) > 0):
-        raise IndexError("Positions outside range of features.")
-    if issparse(mtx):
-        if not isspmatrix_csr(mtx):
-            mtx = csr_matrix(mtx)
-        return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))
-    else:
-        return top_segment_proportions_dense(mtx, ns)
-
-
-def top_segment_proportions_dense(
-    mtx: np.ndarray | spmatrix, ns: Collection[int]
-) -> np.ndarray:
     # Currently ns is considered to be 1 indexed
     ns = np.sort(ns)
     sums = mtx.sum(axis=1)
@@ -432,6 +426,25 @@ def top_segment_proportions_dense(
     return values / sums[:, None]


+@top_segment_proportions.register(DaskArray)
+@check_ns
+def _(mtx: DaskArray, ns: Collection[int]) -> DaskArray:
+    if not isinstance(mtx._meta, csr_matrix | np.ndarray):
+        msg = f"DaskArray must have csr matrix or ndarray meta, got {mtx._meta}."
+        raise ValueError(msg)
+    return mtx.map_blocks(
+        lambda x: top_segment_proportions(x, ns), meta=np.array([])
+    ).compute()
+
+
+@top_segment_proportions.register(spmatrix)
+@check_ns
+def _(mtx: spmatrix, ns: Collection[int]) -> np.ndarray:
+    if not isspmatrix_csr(mtx):
+        mtx = csr_matrix(mtx)
+    return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))
+
+
 @numba.njit(cache=True, parallel=True)
 def top_segment_proportions_sparse_csr(data, indptr, ns):
     # work around https://github.com/numba/numba/issues/5056
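
A short sketch of how the new single-dispatch entry point might be exercised once the module is imported; the expected values follow from the dense implementation kept above (cumulative top-n counts divided by row totals):

import numpy as np
from scipy.sparse import csr_matrix

counts = np.array([[5.0, 3.0, 1.0, 0.0], [2.0, 2.0, 2.0, 2.0]])

dense_res = top_segment_proportions(counts, [1, 2])
sparse_res = top_segment_proportions(csr_matrix(counts), [1, 2])
# Both should agree: row 0 gives 5/9 and 8/9, row 1 gives 0.25 and 0.5.

# check_ns rejects positions outside the feature range:
# top_segment_proportions(counts, [10])  # would raise IndexError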