From 584d3a92cdb18e0fc2458ce1c3eb5052102d1c5b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 3 Nov 2025 14:53:16 +0100 Subject: [PATCH 01/13] chore: add `csc`-in-`dask` tests --- src/testing/scanpy/_pytest/params.py | 26 +++++++++++-- tests/test_aggregated.py | 16 +++++--- tests/test_highly_variable_genes.py | 8 ++-- tests/test_metrics.py | 2 +- tests/test_pca.py | 43 ++++++++++++++------- tests/test_preprocessing.py | 10 ++++- tests/test_qc_metrics.py | 56 ++++++++++++++++++++++++---- tests/test_utils.py | 7 +++- 8 files changed, 129 insertions(+), 39 deletions(-) diff --git a/src/testing/scanpy/_pytest/params.py b/src/testing/scanpy/_pytest/params.py index 512359f3d5..abbfb473e1 100644 --- a/src/testing/scanpy/_pytest/params.py +++ b/src/testing/scanpy/_pytest/params.py @@ -2,7 +2,7 @@ from __future__ import annotations -from functools import wraps +from functools import partial, wraps from importlib.metadata import version from typing import TYPE_CHECKING @@ -48,7 +48,11 @@ def _chunked_1d( @wraps(f) def wrapper(a: np.ndarray) -> DaskArray: da = f(a) - return da.rechunk((da.chunksize[0], -1)) + return da.rechunk( + (da.chunksize[0], -1) + if not hasattr(da._meta, "format") or da._meta.format == "csr" + else (-1, da.chunksize[1]) + ) wrapper.__name__ = f"{wrapper.__name__}-1d_chunked" return wrapper @@ -78,7 +82,23 @@ def wrapper(a: np.ndarray) -> DaskArray: marks=[needs.dask, pytest.mark.anndata_dask_support], id=f"dask_array_sparse{suffix}", ) - for wrapper, suffix in [(lambda x: x, ""), (_chunked_1d, "-1d_chunked")] + for wrapper, suffix in [ + (lambda x: x, ""), + *( + ( + lambda func, + format=format, + matrix_or_array=matrix_or_array: _chunked_1d( + partial( + func, typ=getattr(sparse, f"{format}_{matrix_or_array}") + ) + ), + f"-1d_chunked-{format}_{matrix_or_array}", + ) + for format in ["csr", "csc"] + for matrix_or_array in ["matrix", "array"] + ), + ] ), } diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index ceb029e411..1e0bfa7fdd 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -22,10 +22,16 @@ from scanpy._compat import CSRBase -ARRAY_TYPES = [ +VALID_ARRAY_TYPES = [ at for at in ARRAY_TYPES_ALL - if at.id not in {"dask_array_dense", "dask_array_sparse"} + if at.id + not in { + "dask_array_dense", + "dask_array_sparse", + "dask_array_sparse-1d_chunked-csc_array", + "dask_array_sparse-1d_chunked-csc_matrix", + } ] @@ -118,7 +124,7 @@ def test_mask(axis): assert np.all(by_name["0"].layers["sum"] == 0) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) def test_aggregate_vs_pandas( metric: AggType, array_type, request: pytest.FixtureRequest ): @@ -160,7 +166,7 @@ def test_aggregate_vs_pandas( pd.testing.assert_frame_equal(result_df, expected, check_dtype=False, atol=1e-5) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest): adata = pbmc3k_processed().raw.to_adata() adata = adata[ @@ -445,7 +451,7 @@ def test_combine_categories(label_cols, cols, expected): pd.testing.assert_frame_equal(reconstructed_df, result_label_df) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) def test_aggregate_arraytype( array_type, metric: AggType, request: pytest.FixtureRequest ): diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 5c6693461d..e338edbc13 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -425,8 +425,8 @@ def test_compare_to_upstream( np.testing.assert_allclose( hvg_info["dispersions_norm"], pbmc.var["dispersions_norm"], - rtol=2e-05 if "dask" not in array_type.__name__ else 1e-4, - atol=2e-05 if "dask" not in array_type.__name__ else 1e-4, + rtol=2e-05 if "dask" not in request.node.name else 1e-4, + atol=2e-05 if "dask" not in request.node.name else 1e-4, ) @@ -687,9 +687,7 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) -@pytest.mark.parametrize( - "to_dask", [p for p in ARRAY_TYPES if "dask" in p.values[0].__name__] -) +@pytest.mark.parametrize("to_dask", [p for p in ARRAY_TYPES if "dask" in p.id]) def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): adata.X = np.abs(adata.X).astype(int) if batch_key is not None: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index e4b5bed022..1326ab77c4 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -105,7 +105,7 @@ def test_correctness(metric, size, expected): def test_graph_metrics_w_constant_values( request: pytest.FixtureRequest, metric, array_type ): - if "dask" in array_type.__name__: + if "dask" in request.node.name: reason = "DaskArray not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) diff --git a/tests/test_pca.py b/tests/test_pca.py index 61ace19cd2..0740fb00aa 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -58,18 +58,24 @@ [-1.50180389, 5.56886849, 1.64034442, 2.24476032, -0.05109001], ]) - -ARRAY_TYPES = [ +# These are array types which are expected to work with the current PCA implementation. +VALID_ARRAY_TYPES = [ param_with( at, marks=[needs.dask_ml] if at.id == "dask_array_dense-1d_chunked" else [], ) for at in ARRAY_TYPES_ALL - if at.id not in {"dask_array_dense", "dask_array_sparse"} + if at.id + not in { + "dask_array_dense", + "dask_array_sparse", + "dask_array_sparse-1d_chunked-csc_array", + "dask_array_sparse-1d_chunked-csc_matrix", + } ] -@pytest.fixture(params=ARRAY_TYPES) +@pytest.fixture(params=VALID_ARRAY_TYPES) def array_type(request: pytest.FixtureRequest) -> ArrayType: return request.param @@ -93,10 +99,14 @@ def gen_pca_params( xfail_reason = "dask without 1d chunking scheme not supported" yield None, None, xfail_reason return - if id == "dask_array_sparse-1d_chunked" and not zero_center: + if "dask_array_sparse-1d_chunked" in id and not zero_center: xfail_reason = "Sparse-in-dask with zero_center=False not implemented yet" yield None, None, xfail_reason return + if "dask_array_sparse-1d_chunked-csc" in id: + xfail_reason = "Sparse-in-dask with csc blocks not implemented yet" + yield None, None, xfail_reason + return if svd_solver_type is None: yield None, None, None return @@ -137,7 +147,7 @@ def possible_solvers( svd_solvers = {"auto", "full", "tsqr", "randomized", "covariance_eigh"} case (dc, False) if id == "dask_array_dense-1d_chunked": svd_solvers = {"tsqr", "randomized"} - case (dc, True) if id == "dask_array_sparse-1d_chunked": + case (dc, True) if "dask_array_sparse-1d_chunked-csr" in id: svd_solvers = {"covariance_eigh"} case (type() as dc, True) if issubclass(dc, CSBase): svd_solvers = {"arpack"} | SKLEARN_ADDITIONAL @@ -148,7 +158,7 @@ def possible_solvers( case (helpers.asarray, False): svd_solvers = {"arpack", "randomized"} case _: - pytest.fail(f"Unknown {array_type=} ({zero_center=})") + pytest.fail(f"Unknown {array_type=} ({zero_center=}) ({id=})") if svd_solver_type == "invalid": svd_solvers = all_svd_solvers - svd_solvers @@ -178,7 +188,7 @@ def possible_solvers( f"{svd_solver or svd_solver_type}-{'xfail' if xfail_reason else warn_pat_expected}" ), ) - for array_type in ARRAY_TYPES + for array_type in VALID_ARRAY_TYPES for zero_center in [True, False] for svd_solver_type in [None, "valid", "invalid"] for svd_solver, warn_pat_expected, xfail_reason in gen_pca_params( @@ -515,10 +525,13 @@ def test_pca_layer(): @pytest.mark.parametrize( "other_array_type", [ - lambda x: x.toarray(), - *(at.values[0] for at in ARRAY_TYPES if "1d_chunked" in at.id), + pytest.param(lambda x: x.toarray(), id="dense"), + *( + pytest.param(at.values[0], id=at.id) + for at in VALID_ARRAY_TYPES + if "1d_chunked" in at.id + ), ], - ids=["dense-mem", "sparse-dask", "dense-dask"], ) def test_covariance_eigh_impls(other_array_type): warnings.filterwarnings("error") @@ -564,8 +577,8 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr adata_sparse.X = op( next( at.values[0] - for at in ARRAY_TYPES - if at.id == "dask_array_sparse-1d_chunked" + for at in VALID_ARRAY_TYPES + if "dask_array_sparse-1d_chunked" in at.id )(adata_sparse.X) ) @@ -587,7 +600,9 @@ def test_sparse_dask_input_errors(msg_re: str, op: Callable[[DaskArray], DaskArr def test_cov_sparse_dask(dtype, dtype_arg, rtol): x_arr = A_list.astype(dtype) x = next( - at.values[0] for at in ARRAY_TYPES if at.id == "dask_array_sparse-1d_chunked" + at.values[0] + for at in VALID_ARRAY_TYPES + if "dask_array_sparse-1d_chunked" in at.id )(x_arr) cov, gram, mean = _cov_sparse_dask(x, return_gram=True, dtype=dtype_arg) np.testing.assert_allclose(mean, np.mean(x_arr, axis=0)) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index c88dcde9b3..28f9c22e4d 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -281,7 +281,13 @@ def test_sample_copy_backed_error(tmp_path): @pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("max_value", [None, 1.0], ids=["no_clip", "clip"]) -def test_scale_matrix_types(array_type, zero_center, max_value): +def test_scale_matrix_types( + *, + request: pytest.FixtureRequest, + array_type: Callable, + zero_center: bool, + max_value: float | None, +): adata = pbmc68k_reduced() adata.X = adata.raw.X adata_casted = adata.copy() @@ -294,7 +300,7 @@ def test_scale_matrix_types(array_type, zero_center, max_value): ( warn_ctx if zero_center - and any(pat in array_type.__name__ for pat in ("sparse", "csc", "csr")) + and any(pat in request.node.name for pat in ("sparse", "csc", "csr")) else nullcontext() ), maybe_dask_process_context(), diff --git a/tests/test_qc_metrics.py b/tests/test_qc_metrics.py index 29adfef8f9..cf60e26873 100644 --- a/tests/test_qc_metrics.py +++ b/tests/test_qc_metrics.py @@ -1,5 +1,7 @@ from __future__ import annotations +from contextlib import nullcontext + import numpy as np import pandas as pd import pytest @@ -9,7 +11,7 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskArray +from scanpy._compat import CSCBase, DaskArray from scanpy.preprocessing._qc import ( describe_obs, describe_var, @@ -79,12 +81,20 @@ def test_segments_binary(): "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] ) def test_top_segments(request: pytest.FixtureRequest, array_type): - if "dask" in array_type.__name__ and "1d_chunked" not in array_type.__name__: + if "dask" in request.node.name and "1d_chunked" not in request.node.name: reason = "DaskArray with feature axis chunking not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) a = array_type(np.ones((300, 100))) - with maybe_dask_process_context(): + is_csc_dask = isinstance(a, DaskArray) and isinstance(a._meta, CSCBase) + with ( + maybe_dask_process_context(), + pytest.raises(ValueError, match=r"DaskArray must have csr") + if is_csc_dask + else nullcontext(), + ): seg = top_segment_proportions(a, [50, 100]) + if is_csc_dask: + return assert (seg[:, 0] == 0.5).all() assert (seg[:, 1] == 1.0).all() @@ -93,7 +103,7 @@ def test_top_segments(request: pytest.FixtureRequest, array_type): "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] ) def test_top_proportions(request: pytest.FixtureRequest, array_type): - if "dask" in array_type.__name__: + if "dask" in request.node.name: reason = "DaskArray not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) a = array_type(np.ones((300, 100))) @@ -107,10 +117,22 @@ def test_top_proportions(request: pytest.FixtureRequest, array_type): # While many of these are trivial, # they’re also just making sure the metrics are there def test_qc_metrics(adata_prepared: AnnData): - with maybe_dask_process_context(): + is_csc_dask = isinstance(adata_prepared.X, DaskArray) and isinstance( + adata_prepared.X._meta, CSCBase + ) + with ( + maybe_dask_process_context(), + ( + pytest.raises(ValueError, match=r"DaskArray must have csr") + if is_csc_dask + else nullcontext() + ), + ): sc.pp.calculate_qc_metrics( adata_prepared, qc_vars=["mito", "negative"], inplace=True ) + if is_csc_dask: + return x = ( adata_prepared.X.compute() if isinstance(adata_prepared.X, DaskArray) @@ -159,7 +181,17 @@ def test_qc_metrics(adata_prepared: AnnData): def test_qc_metrics_idempotent(adata_prepared: AnnData): - with maybe_dask_process_context(): + is_csc_dask = isinstance(adata_prepared.X, DaskArray) and isinstance( + adata_prepared.X._meta, CSCBase + ) + with ( + maybe_dask_process_context(), + ( + pytest.raises(ValueError, match=r"DaskArray must have csr") + if is_csc_dask + else nullcontext() + ), + ): sc.pp.calculate_qc_metrics( adata_prepared, qc_vars=["mito", "negative"], inplace=True ) @@ -167,6 +199,8 @@ def test_qc_metrics_idempotent(adata_prepared: AnnData): sc.pp.calculate_qc_metrics( adata_prepared, qc_vars=["mito", "negative"], inplace=True ) + if is_csc_dask: + return assert set(adata_prepared.obs.columns) == set(old_obs.columns) assert set(adata_prepared.var.columns) == set(old_var.columns) for col in adata_prepared.obs: @@ -176,7 +210,15 @@ def test_qc_metrics_idempotent(adata_prepared: AnnData): def test_qc_metrics_no_log1p(adata_prepared: AnnData): - with maybe_dask_process_context(): + with ( + maybe_dask_process_context(), + ( + pytest.raises(ValueError, match=r"DaskArray must have csr") + if isinstance(adata_prepared.X, DaskArray) + and isinstance(adata_prepared.X._meta, CSCBase) + else nullcontext() + ), + ): sc.pp.calculate_qc_metrics( adata_prepared, qc_vars=["mito", "negative"], log1p=False, inplace=True ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8981f81503..36e86319e2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,6 +30,7 @@ ) if TYPE_CHECKING: + from collections.abc import Callable from typing import Any @@ -112,12 +113,14 @@ def test_divide_by_zero(array_type): @pytest.mark.parametrize("array_type", ARRAY_TYPES_SPARSE) -def test_scale_out_with_dask_or_sparse_raises(array_type): +def test_scale_out_with_dask_or_sparse_raises( + *, request: pytest.FixtureRequest, array_type: Callable +): dividend = array_type(asarray([[0, 1.0, 2.0], [3.0, 0, 4.0]])) divisor = np.array([0.1, 0.2, 0.5]) if isinstance(dividend, DaskArray): with pytest.raises( - TypeError if "dask" in array_type.__name__ else ValueError, + TypeError if "dask" in request.node.name else ValueError, match="`out`", ): axis_mul_or_truediv(dividend, divisor, op=truediv, axis=1, out=dividend) From a21a88098fd30152b89684edb4b50a79514e87f0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 3 Nov 2025 15:37:52 +0100 Subject: [PATCH 02/13] feat: support `csc` in `dask` arrays in `get.aggregate` --- src/scanpy/get/_aggregated.py | 45 +++++++++++++++++++++++------------ tests/test_aggregated.py | 3 --- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 46355bff19..b9b822b8e2 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -10,7 +10,7 @@ from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 -from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray +from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals from .get import _check_mask @@ -354,9 +354,6 @@ def aggregate_dask_mean_var( # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. if isinstance(data._meta, CSRBase): sq_mean = sq_mean.compute() - elif isinstance(data._meta, CSCBase): # pragma: no-cover - msg = "Cannot handle CSC matrices as dask meta." - raise ValueError(msg) var = sq_mean - fau_power(mean, 2) if dof != 0: group_counts = np.bincount(by.codes) @@ -373,10 +370,13 @@ def aggregate_dask( mask: NDArray[np.bool_] | None = None, dof: int = 1, ) -> dict[AggType, DaskArray]: - if not isinstance(data._meta, CSRBase | np.ndarray): + if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) - if data.chunksize[1] != data.shape[1]: + # i.e., if data._meta is CSR/np.ndarray, this is 1 because we use row-chunking, but otherwise 0 and column chunking + unchunked_axis = int(isinstance(data._meta, CSRBase | np.ndarray)) + chunked_axis = (unchunked_axis - 1) % 2 + if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: msg = "Feature axis must be unchunked" raise ValueError(msg) @@ -385,11 +385,16 @@ def aggregate_chunk_sum_or_count_nonzero( ): # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html # for what is contained in `block_info`. - subset = slice(*block_info[0]["array-location"][0]) - by_subsetted = by[subset] - mask_subsetted = mask[subset] if mask is not None else mask + if chunked_axis == 0: + # only subset the mask and by if we need to i.e., there is chunking along the same axis as by and mask + subset = slice(*block_info[0]["array-location"][0]) + by_subsetted = by[subset] + mask_subsetted = mask[subset] if mask is not None else mask + else: + by_subsetted = by + mask_subsetted = mask res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func] - return res[None, :] + return res[None, :] if unchunked_axis == 1 else res funcs = set([func] if isinstance(func, str) else func) if "median" in funcs: @@ -397,23 +402,33 @@ def aggregate_chunk_sum_or_count_nonzero( raise NotImplementedError(msg) has_mean, has_var = (v in funcs for v in ["mean", "var"]) funcs_no_var_or_mean = funcs - {"var", "mean"} - # aggregate each row chunk individually, - # producing a #chunks × #categories × #features array, + # aggregate each row chunk or column chunk individually, + # producing a #chunks × #categories × #features or a #categories x #chunks array, # then aggregate the per-chunk results. + chunks = ( + ((1,) * data.blocks.size, (len(by.categories),), data.shape[1]) + if unchunked_axis == 1 + else (len(by.categories), data.chunks[1]) + ) aggregated = { f: data.map_blocks( partial(aggregate_chunk_sum_or_count_nonzero, func=func), - new_axis=(1,), - chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)), + new_axis=(1,) if unchunked_axis == 1 else None, + chunks=chunks, meta=np.array( [], dtype=np.float64 if func not in get_args(ConstantDtypeAgg) else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original ), - ).sum(axis=0) + ) for f in funcs_no_var_or_mean } + # If we have row chunking, we need to handle the extra axis by summing over all category x feature matrices. + # Otherwise, dask internally concatenates the #categories x #chunks arrays i.e., the column chunks are concatenated together toget a #categories x #features matrix. + if unchunked_axis == 1: + for k, v in aggregated.items(): + aggregated[k] = v.sum(axis=chunked_axis) if has_var: aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) aggregated["var"] = aggredated_mean_var["var"] diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 1e0bfa7fdd..6f0412908a 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -29,8 +29,6 @@ not in { "dask_array_dense", "dask_array_sparse", - "dask_array_sparse-1d_chunked-csc_array", - "dask_array_sparse-1d_chunked-csc_matrix", } ] @@ -247,7 +245,6 @@ def to_csc(x: CSRBase): @pytest.mark.parametrize( ("func", "error_msg"), [ - pytest.param(to_csc, r"only csr_matrix", id="csc"), pytest.param( to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking" ), From 73a9735a3b8a743069b4dc886987921da1e8bade Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 6 Nov 2025 15:21:17 +0100 Subject: [PATCH 03/13] fix: handle `anndata` testing utils versions --- src/testing/scanpy/_helpers/__init__.py | 7 ++++++- src/testing/scanpy/_pytest/params.py | 27 +++++++++++++++---------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/testing/scanpy/_helpers/__init__.py b/src/testing/scanpy/_helpers/__init__.py index 055274818c..eec415bfa9 100644 --- a/src/testing/scanpy/_helpers/__init__.py +++ b/src/testing/scanpy/_helpers/__init__.py @@ -5,6 +5,7 @@ import warnings from contextlib import AbstractContextManager, contextmanager from dataclasses import dataclass +from importlib.metadata import version from importlib.util import find_spec from itertools import permutations from types import MappingProxyType @@ -13,6 +14,7 @@ import numpy as np from anndata import AnnData from anndata.tests.helpers import asarray, assert_equal +from packaging.version import Version import scanpy as sc @@ -130,7 +132,10 @@ def as_dense_dask_array(*args, **kwargs) -> DaskArray: def as_sparse_dask_array(*args, **kwargs) -> DaskArray: - from anndata.tests.helpers import as_sparse_dask_array + if Version(version("anndata")) < Version("0.12.5"): + from anndata.tests.helpers import as_sparse_dask_array + else: + from anndata.tests.helpers import as_sparse_dask_matrix as as_sparse_dask_array return as_sparse_dask_array(*args, **kwargs) diff --git a/src/testing/scanpy/_pytest/params.py b/src/testing/scanpy/_pytest/params.py index abbfb473e1..38b6ea642c 100644 --- a/src/testing/scanpy/_pytest/params.py +++ b/src/testing/scanpy/_pytest/params.py @@ -85,18 +85,23 @@ def wrapper(a: np.ndarray) -> DaskArray: for wrapper, suffix in [ (lambda x: x, ""), *( - ( - lambda func, - format=format, - matrix_or_array=matrix_or_array: _chunked_1d( - partial( - func, typ=getattr(sparse, f"{format}_{matrix_or_array}") - ) - ), - f"-1d_chunked-{format}_{matrix_or_array}", + ((_chunked_1d, "-1d_chunked"),) + if Version(version("anndata")) < Version("0.12.5") + else ( + ( + lambda func, + format=format, + matrix_or_array=matrix_or_array: _chunked_1d( + partial( + func, typ=getattr(sparse, f"{format}_{matrix_or_array}") + ) + ), + f"-1d_chunked-{format}_{matrix_or_array}", + ) + for format in ["csr", "csc"] + # TODO: use `array` as well once anndata 0.13 drops + for matrix_or_array in ["matrix"] ) - for format in ["csr", "csc"] - for matrix_or_array in ["matrix", "array"] ), ] ), From df5698e7918c9e650aab7f7814d4d948ee5d7e17 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 7 Nov 2025 10:23:13 +0100 Subject: [PATCH 04/13] fix: recognize old chunking args --- tests/test_pca.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_pca.py b/tests/test_pca.py index 0740fb00aa..6f016fd1b1 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -147,7 +147,10 @@ def possible_solvers( svd_solvers = {"auto", "full", "tsqr", "randomized", "covariance_eigh"} case (dc, False) if id == "dask_array_dense-1d_chunked": svd_solvers = {"tsqr", "randomized"} - case (dc, True) if "dask_array_sparse-1d_chunked-csr" in id: + case (dc, True) if ( + "dask_array_sparse-1d_chunked-csr" in id + or id == "dask_array_sparse-1d_chunked" + ): svd_solvers = {"covariance_eigh"} case (type() as dc, True) if issubclass(dc, CSBase): svd_solvers = {"arpack"} | SKLEARN_ADDITIONAL From 216b21d91312b899e939db9636d9ab20e7c29d77 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 10 Nov 2025 10:20:39 +0100 Subject: [PATCH 05/13] fix: only test `csr` --- tests/test_highly_variable_genes.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 1161c56919..21171e7e2b 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -649,7 +649,12 @@ def test_seurat_v3_bad_chunking(adata, array_type, flavor): ], ) @pytest.mark.parametrize( - "array_type", [p for p in ARRAY_TYPES if "dask" not in p.id or "1d_chunked" in p.id] + "array_type", + [ + p + for p in ARRAY_TYPES + if "dask" not in p.id or ("1d_chunked" in p.id and "csr" in p.id) + ], ) @pytest.mark.parametrize("batch_key", [None, "batch"]) def test_subset_inplace_consistency(flavor, array_type, batch_key): @@ -728,7 +733,9 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): ], ) @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) -@pytest.mark.parametrize("to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id]) +@pytest.mark.parametrize( + "to_dask", [p for p in ARRAY_TYPES if "1d_chunked" in p.id and "csr" in p.id] +) def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): # current blob produces singularities in loess....maybe a bad sign of the data? if "seurat_v3" in flavor: From f529d952feb13cbf79f6593d2ea9ff90ec6fb5c5 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 10 Nov 2025 11:47:32 +0100 Subject: [PATCH 06/13] Update src/scanpy/get/_aggregated.py Co-authored-by: Philipp A. --- src/scanpy/get/_aggregated.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index c223fe8bc3..7589867aa5 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -376,9 +376,7 @@ def aggregate_dask( if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) - # i.e., if data._meta is CSR/np.ndarray, this is 1 because we use row-chunking, but otherwise 0 and column chunking - unchunked_axis = int(isinstance(data._meta, CSRBase | np.ndarray)) - chunked_axis = (unchunked_axis - 1) % 2 + chunked_axis, unchunked_axis = (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: msg = "Feature axis must be unchunked" raise ValueError(msg) From d26a63c68fb7b0e86929a1fbc6e2f47070c2da56 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 10:47:42 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/get/_aggregated.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 7589867aa5..65c8530b3a 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -376,7 +376,9 @@ def aggregate_dask( if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." raise ValueError(msg) - chunked_axis, unchunked_axis = (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) + chunked_axis, unchunked_axis = ( + (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) + ) if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: msg = "Feature axis must be unchunked" raise ValueError(msg) From 7339239b98c3163d0285a1a1a0d8f15fc0bcb236 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 10 Nov 2025 11:48:58 +0100 Subject: [PATCH 08/13] Apply suggestion from @flying-sheep Co-authored-by: Philipp A. --- src/scanpy/get/_aggregated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 65c8530b3a..9b48d4012b 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -406,7 +406,7 @@ def aggregate_chunk_sum_or_count_nonzero( has_mean, has_var = (v in funcs for v in ["mean", "var"]) funcs_no_var_or_mean = funcs - {"var", "mean"} # aggregate each row chunk or column chunk individually, - # producing a #chunks × #categories × #features or a #categories x #chunks array, + # producing a #chunks × #categories × #features or a #categories × #chunks array, # then aggregate the per-chunk results. chunks = ( ((1,) * data.blocks.size, (len(by.categories),), data.shape[1]) From 510cec5b3ea01bc486839f425a3e61a124b450e5 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Mon, 10 Nov 2025 11:49:05 +0100 Subject: [PATCH 09/13] Apply suggestion from @flying-sheep Co-authored-by: Philipp A. --- src/scanpy/get/_aggregated.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 9b48d4012b..03f5dd3692 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -427,8 +427,8 @@ def aggregate_chunk_sum_or_count_nonzero( ) for f in funcs_no_var_or_mean } - # If we have row chunking, we need to handle the extra axis by summing over all category x feature matrices. - # Otherwise, dask internally concatenates the #categories x #chunks arrays i.e., the column chunks are concatenated together toget a #categories x #features matrix. + # If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices. + # Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix. if unchunked_axis == 1: for k, v in aggregated.items(): aggregated[k] = v.sum(axis=chunked_axis) From ea0b2ea2b5d1ddcf6defb19f2657932b66ede2ae Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 10 Nov 2025 12:00:36 +0100 Subject: [PATCH 10/13] fix: comment order --- src/scanpy/get/_aggregated.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 03f5dd3692..93194aea0d 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -386,10 +386,11 @@ def aggregate_dask( def aggregate_chunk_sum_or_count_nonzero( chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None ): - # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html - # for what is contained in `block_info`. + # only subset the mask and by if we need to i.e., + # there is chunking along the same axis as by and mask if chunked_axis == 0: - # only subset the mask and by if we need to i.e., there is chunking along the same axis as by and mask + # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html + # for what is contained in `block_info`. subset = slice(*block_info[0]["array-location"][0]) by_subsetted = by[subset] mask_subsetted = mask[subset] if mask is not None else mask From df66b35e5c44a21c6dac1a22ed0a2ffee8508700 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 10 Nov 2025 12:08:23 +0100 Subject: [PATCH 11/13] fix: add comment --- tests/test_pca.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_pca.py b/tests/test_pca.py index f0b26d7f76..3e411573ed 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -148,6 +148,8 @@ def possible_solvers( case (dc, False) if id == "dask_array_dense-1d_chunked": svd_solvers = {"tsqr", "randomized"} case (dc, True) if ( + # See https://github.com/scverse/scanpy/blob/216b21d91312b899e939db9636d9ab20e7c29d77/src/testing/scanpy/_pytest/params.py#L88-L103 + # for why we need two checks (i.e., before and after allowing CSC matrices) "dask_array_sparse-1d_chunked-csr" in id or id == "dask_array_sparse-1d_chunked" ): From 4a8bcf7a07effe5ccff41fb8ebe3d439b0b8a137 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 10 Nov 2025 13:28:29 +0100 Subject: [PATCH 12/13] fix: use `array_type.__name__` --- src/testing/scanpy/_pytest/params.py | 54 ++++++++++++++++++---------- tests/test_highly_variable_genes.py | 4 +-- tests/test_metrics.py | 2 +- tests/test_preprocessing.py | 3 +- tests/test_qc_metrics.py | 4 +-- tests/test_utils.py | 6 ++-- 6 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/testing/scanpy/_pytest/params.py b/src/testing/scanpy/_pytest/params.py index fa56b5f7b5..be9bae9db1 100644 --- a/src/testing/scanpy/_pytest/params.py +++ b/src/testing/scanpy/_pytest/params.py @@ -29,6 +29,24 @@ reason="scipy cs{rc}_array not supported in anndata<0.11", ) +anndata_test_utils_supports_typ_kwarg = Version(version("anndata")) >= Version("0.12.6") + + +def gen_csr_csc_params_wrapper( + func: Callable, + format: Literal["csr", "csc"], + matrix_or_array: Literal["matrix", "array"], +): + def wrapper(arr): + if anndata_test_utils_supports_typ_kwarg: + return _chunked_1d( + partial(func, typ=getattr(sparse, f"{format}_{matrix_or_array}")) + )(arr) + return _chunked_1d(func)(arr) + + wrapper.__name__ = f"{func.__name__}-1d_chunked-{format}_{matrix_or_array}" + return wrapper + def param_with( at: ParameterSet, @@ -79,29 +97,29 @@ def wrapper(a: np.ndarray) -> DaskArray: ("dask", "sparse"): tuple( pytest.param( wrapper(as_sparse_dask_matrix), - marks=[needs.dask], + marks=[needs.dask, skip_csc_mark] + if skip_csc_mark is not None + else [needs.dask], id=f"dask_array_sparse{suffix}", ) - for wrapper, suffix in [ - (lambda x: x, ""), + for wrapper, suffix, skip_csc_mark in [ + (lambda x: x, "", None), *( - ((_chunked_1d, "-1d_chunked"),) - if Version(version("anndata")) < Version("0.12.5") - else ( - ( - lambda func, + ( + partial( + gen_csr_csc_params_wrapper, format=format, - matrix_or_array=matrix_or_array: _chunked_1d( - partial( - func, typ=getattr(sparse, f"{format}_{matrix_or_array}") - ) - ), - f"-1d_chunked-{format}_{matrix_or_array}", - ) - for format in ["csr", "csc"] - # TODO: use `array` as well once anndata 0.13 drops - for matrix_or_array in ["matrix"] + matrix_or_array=matrix_or_array, + ), + f"-1d_chunked-{format}_{matrix_or_array}", + pytest.mark.skipif( + not anndata_test_utils_supports_typ_kwarg and format == "csc", + reason="anndata < 0.12.6 lacked the required kwargs to enable csc matrix test utils.", + ), ) + for format in ["csr", "csc"] + # TODO: use `array` as well once anndata 0.13 drops + for matrix_or_array in ["matrix"] ), ] ), diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 21171e7e2b..847f1ae75c 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -437,8 +437,8 @@ def test_compare_to_upstream( np.testing.assert_allclose( hvg_info["dispersions_norm"], pbmc.var["dispersions_norm"], - rtol=2e-05 if "dask" not in request.node.name else 1e-4, - atol=2e-05 if "dask" not in request.node.name else 1e-4, + rtol=2e-05 if "dask" not in array_type.__name__ else 1e-4, + atol=2e-05 if "dask" not in array_type.__name__ else 1e-4, ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1326ab77c4..e4b5bed022 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -105,7 +105,7 @@ def test_correctness(metric, size, expected): def test_graph_metrics_w_constant_values( request: pytest.FixtureRequest, metric, array_type ): - if "dask" in request.node.name: + if "dask" in array_type.__name__: reason = "DaskArray not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 1d2c2ee4c8..c29e6fd302 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -282,7 +282,6 @@ def test_sample_copy_backed_error(tmp_path): @pytest.mark.parametrize("max_value", [None, 1.0], ids=["no_clip", "clip"]) def test_scale_matrix_types( *, - request: pytest.FixtureRequest, array_type: Callable, zero_center: bool, max_value: float | None, @@ -299,7 +298,7 @@ def test_scale_matrix_types( ( warn_ctx if zero_center - and any(pat in request.node.name for pat in ("sparse", "csc", "csr")) + and any(pat in array_type.__name__ for pat in ("sparse", "csc", "csr")) else nullcontext() ), maybe_dask_process_context(), diff --git a/tests/test_qc_metrics.py b/tests/test_qc_metrics.py index bf14c9f725..fd9a4ae187 100644 --- a/tests/test_qc_metrics.py +++ b/tests/test_qc_metrics.py @@ -81,7 +81,7 @@ def test_segments_binary(): "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] ) def test_top_segments(request: pytest.FixtureRequest, array_type): - if "dask" in request.node.name and "1d_chunked" not in request.node.name: + if "dask" in array_type.__name__ and "1d_chunked" not in array_type.__name__: reason = "DaskArray with feature axis chunking not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) a = array_type(np.ones((300, 100))) @@ -103,7 +103,7 @@ def test_top_segments(request: pytest.FixtureRequest, array_type): "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] ) def test_top_proportions(request: pytest.FixtureRequest, array_type): - if "dask" in request.node.name: + if "dask" in array_type.__name__: reason = "DaskArray not yet supported" request.applymarker(pytest.mark.xfail(reason=reason)) a = array_type(np.ones((300, 100))) diff --git a/tests/test_utils.py b/tests/test_utils.py index 36e86319e2..b82deea324 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -113,14 +113,12 @@ def test_divide_by_zero(array_type): @pytest.mark.parametrize("array_type", ARRAY_TYPES_SPARSE) -def test_scale_out_with_dask_or_sparse_raises( - *, request: pytest.FixtureRequest, array_type: Callable -): +def test_scale_out_with_dask_or_sparse_raises(array_type: Callable): dividend = array_type(asarray([[0, 1.0, 2.0], [3.0, 0, 4.0]])) divisor = np.array([0.1, 0.2, 0.5]) if isinstance(dividend, DaskArray): with pytest.raises( - TypeError if "dask" in request.node.name else ValueError, + TypeError if "dask" in array_type.__name__ else ValueError, match="`out`", ): axis_mul_or_truediv(dividend, divisor, op=truediv, axis=1, out=dividend) From ffad853f4f3ae5789346c0954209806453c0b4e1 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 10 Nov 2025 15:00:19 +0100 Subject: [PATCH 13/13] chore: relnote --- docs/release-notes/3872.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/3872.feat.md diff --git a/docs/release-notes/3872.feat.md b/docs/release-notes/3872.feat.md new file mode 100644 index 0000000000..6a7db9b86b --- /dev/null +++ b/docs/release-notes/3872.feat.md @@ -0,0 +1 @@ +Add in `csc`-in-{doc}`dask:index` support for {func}`scanpy.get.aggregate` {smaller}`I Gold`