diff --git a/docs/release-notes/3872.feat.md b/docs/release-notes/3872.feat.md
new file mode 100644
index 0000000000..6a7db9b86b
--- /dev/null
+++ b/docs/release-notes/3872.feat.md
@@ -0,0 +1 @@
+Add `csc`-in-{doc}`dask:index` support for {func}`scanpy.get.aggregate` {smaller}`I Gold`
diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py
index 219fa6e83a..93194aea0d 100644
--- a/src/scanpy/get/_aggregated.py
+++ b/src/scanpy/get/_aggregated.py
@@ -10,7 +10,7 @@
 from scipy import sparse
 from sklearn.utils.sparsefuncs import csc_median_axis_0
 
-from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
+from scanpy._compat import CSBase, CSRBase, DaskArray
 
 from .._utils import _resolve_axis, get_literal_vals
 from .get import _check_mask
@@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
     # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
     if isinstance(data._meta, CSRBase):
         sq_mean = sq_mean.compute()
-    elif isinstance(data._meta, CSCBase):  # pragma: no-cover
-        msg = "Cannot handle CSC matrices as dask meta."
-        raise ValueError(msg)
     var = sq_mean - np.power(mean, 2)
     if dof != 0:
         group_counts = np.bincount(by.codes)
@@ -376,23 +373,32 @@ def aggregate_dask(
     mask: NDArray[np.bool_] | None = None,
     dof: int = 1,
 ) -> dict[AggType, DaskArray]:
-    if not isinstance(data._meta, CSRBase | np.ndarray):
-        msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
+    if not isinstance(data._meta, CSBase | np.ndarray):
+        msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array, csc_matrix/csc_array, and ndarray are supported."
         raise ValueError(msg)
-    if data.chunksize[1] != data.shape[1]:
+    chunked_axis, unchunked_axis = (
+        (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
+    )
+    if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
         msg = "Feature axis must be unchunked"
         raise ValueError(msg)
 
     def aggregate_chunk_sum_or_count_nonzero(
         chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
     ):
-        # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
-        # for what is contained in `block_info`.
-        subset = slice(*block_info[0]["array-location"][0])
-        by_subsetted = by[subset]
-        mask_subsetted = mask[subset] if mask is not None else mask
+        # Only subset `by` and `mask` if needed, i.e. when the chunking runs
+        # along the same (observation) axis that `by` and `mask` index.
+        if chunked_axis == 0:
+            # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
+            # for what is contained in `block_info`.
+            subset = slice(*block_info[0]["array-location"][0])
+            by_subsetted = by[subset]
+            mask_subsetted = mask[subset] if mask is not None else mask
+        else:
+            by_subsetted = by
+            mask_subsetted = mask
         res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
-        return res[None, :]
+        return res[None, :] if unchunked_axis == 1 else res
 
     funcs = set([func] if isinstance(func, str) else func)
     if "median" in funcs:
@@ -400,23 +406,33 @@ def aggregate_chunk_sum_or_count_nonzero(
         raise NotImplementedError(msg)
     has_mean, has_var = (v in funcs for v in ["mean", "var"])
     funcs_no_var_or_mean = funcs - {"var", "mean"}
-    # aggregate each row chunk individually,
-    # producing a #chunks × #categories × #features array,
+    # aggregate each row chunk or column chunk individually,
+    # producing a #chunks × #categories × #features or a #categories × #chunk-columns array,
     # then aggregate the per-chunk results.
+    chunks = (
+        ((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
+        if unchunked_axis == 1
+        else (len(by.categories), data.chunks[1])
+    )
     aggregated = {
         f: data.map_blocks(
             partial(aggregate_chunk_sum_or_count_nonzero, func=f),
-            new_axis=(1,),
-            chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
+            new_axis=(1,) if unchunked_axis == 1 else None,
+            chunks=chunks,
             meta=np.array(
                 [],
                 dtype=np.float64
                 if f not in get_args(ConstantDtypeAgg)
                 else data.dtype,  # TODO: figure out best dtype for aggs like sum where dtype can change from original
             ),
-        ).sum(axis=0)
+        )
         for f in funcs_no_var_or_mean
     }
+    # With row chunking, collapse the extra chunk axis by summing the per-chunk #categories × #features matrices;
+    # with column chunking, dask already concatenates the per-chunk results into one #categories × #features matrix.
+    if unchunked_axis == 1:
+        for k, v in aggregated.items():
+            aggregated[k] = v.sum(axis=chunked_axis)
     if has_var:
         aggregated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
         aggregated["var"] = aggregated_mean_var["var"]
diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py
index d6f18195d3..5ced6a4634 100644
--- a/tests/test_aggregated.py
+++ b/tests/test_aggregated.py
@@ -29,8 +29,6 @@
     not in {
         "dask_array_dense",
         "dask_array_sparse",
-        "dask_array_sparse-1d_chunked-csc_array",
-        "dask_array_sparse-1d_chunked-csc_matrix",
     }
 ]
 
@@ -246,7 +244,6 @@ def to_csc(x: CSRBase):
 @pytest.mark.parametrize(
     ("func", "error_msg"),
     [
-        pytest.param(to_csc, r"only csr_matrix", id="csc"),
         pytest.param(
             to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking"
         ),
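For reviewers who want to exercise the new path end to end, here is a minimal usage sketch. The shapes, density, `group` column, and chunk sizes are invented for illustration, and it assumes `dask`, `anndata`, and `scipy` are installed alongside this branch:

```python
# Hypothetical end-to-end check of the csc-in-dask path added above.
# A csc-backed dask array must be chunked along the feature axis only;
# the observation axis (which `by` indexes) stays unchunked.
import dask.array as da
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scipy import sparse

rng = np.random.default_rng(0)
n_obs, n_var = 60, 20

X = sparse.random(n_obs, n_var, density=0.3, format="csc", random_state=0)
# 1d (feature-axis) chunking: one 60-row chunk, four 5-column chunks.
X_dask = da.from_array(X, chunks=(n_obs, 5))

adata = AnnData(
    X_dask,
    obs=pd.DataFrame(
        {"group": pd.Categorical(rng.choice(["a", "b", "c"], n_obs))},
        index=[f"cell_{i}" for i in range(n_obs)],
    ),
)

out = sc.get.aggregate(adata, by="group", func=["sum", "mean", "count_nonzero"])
# One #categories × #features matrix per requested function.
print(out.layers["sum"].shape)  # (3, 20)
```

Chunking only along the feature axis means every block sees all observations, which is why the csc branch in `aggregate_chunk_sum_or_count_nonzero` can pass `by` and `mask` through unsubsetted.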