Merged
21 commits
584d3a9
chore: add `csc`-in-`dask` tests
ilan-gold Nov 3, 2025
a21a880
feat: support `csc` in `dask` arrays in `get.aggregate`
ilan-gold Nov 3, 2025
73a9735
fix: handle `anndata` testing utils versions
ilan-gold Nov 6, 2025
df5698e
fix: recognize old chunking args
ilan-gold Nov 7, 2025
6882c44
Merge branch 'main' into ig/csc_dask_tests
ilan-gold Nov 7, 2025
45cdcb9
Merge branch 'ig/csc_dask_tests' into ig/aggregate_csc
flying-sheep Nov 7, 2025
57e2af6
Merge branch 'main' into ig/csc_dask_tests
ilan-gold Nov 10, 2025
7cd39fc
Merge branch 'ig/csc_dask_tests' of github.com:scverse/scanpy into ig…
ilan-gold Nov 10, 2025
216b21d
fix: only test `csr`
ilan-gold Nov 10, 2025
9655e0d
Merge branch 'ig/csc_dask_tests' into ig/aggregate_csc
ilan-gold Nov 10, 2025
f529d95
Update src/scanpy/get/_aggregated.py
ilan-gold Nov 10, 2025
d26a63c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
7339239
Apply suggestion from @flying-sheep
ilan-gold Nov 10, 2025
510cec5
Apply suggestion from @flying-sheep
ilan-gold Nov 10, 2025
ea0b2ea
fix: comment order
ilan-gold Nov 10, 2025
df66b35
fix: add comment
ilan-gold Nov 10, 2025
4a8bcf7
fix: use `array_type.__name__`
ilan-gold Nov 10, 2025
ebb3d92
Merge branch 'main' into ig/csc_dask_tests
ilan-gold Nov 10, 2025
de1165c
Merge branch 'ig/csc_dask_tests' into ig/aggregate_csc
ilan-gold Nov 10, 2025
9a4b723
Merge branch 'main' into ig/aggregate_csc
ilan-gold Nov 10, 2025
ffad853
chore: relnote
ilan-gold Nov 10, 2025
1 change: 1 addition & 0 deletions docs/release-notes/3872.feat.md
@@ -0,0 +1 @@
+Add in `csc`-in-{doc}`dask:index` support for {func}`scanpy.get.aggregate` {smaller}`I Gold`
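A minimal usage sketch of the capability this release note describes, assuming `dask.array.from_array` can wrap a `scipy.sparse.csc_matrix` directly and that the feature axis is the chunked one (the obs axis stays unchunked); the array sizes, group labels, and variable names are illustrative, not taken from the PR:

```python
import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse

rng = np.random.default_rng(0)
# 100 cells × 20 genes stored as csc, wrapped in dask with feature-axis chunking
# (the obs axis is left unchunked, as the implementation requires)
x = da.from_array(
    sparse.random(100, 20, density=0.1, format="csc", random_state=rng),
    chunks=(100, 5),
)
adata = ad.AnnData(
    X=x,
    obs=pd.DataFrame(
        {"group": pd.Categorical(rng.choice(["a", "b", "c"], size=100))},
        index=[f"cell_{i}" for i in range(100)],
    ),
)
agg = sc.get.aggregate(adata, by="group", func=["sum", "count_nonzero"])
print(agg.layers["sum"].shape)  # (3, 20): one row per group
```

The returned `AnnData` has one observation per group, with the requested aggregations stored in `.layers` keyed by function name.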
50 changes: 33 additions & 17 deletions src/scanpy/get/_aggregated.py
@@ -10,7 +10,7 @@
 from scipy import sparse
 from sklearn.utils.sparsefuncs import csc_median_axis_0
 
-from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
+from scanpy._compat import CSBase, CSRBase, DaskArray
 
 from .._utils import _resolve_axis, get_literal_vals
 from .get import _check_mask
Expand Down Expand Up @@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
elif isinstance(data._meta, CSCBase): # pragma: no-cover
msg = "Cannot handle CSC matrices as dask meta."
raise ValueError(msg)
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
@@ -376,47 +373,66 @@ def aggregate_dask(
     mask: NDArray[np.bool_] | None = None,
     dof: int = 1,
 ) -> dict[AggType, DaskArray]:
-    if not isinstance(data._meta, CSRBase | np.ndarray):
+    if not isinstance(data._meta, CSBase | np.ndarray):
         msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
         raise ValueError(msg)
-    if data.chunksize[1] != data.shape[1]:
+    chunked_axis, unchunked_axis = (
+        (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
+    )
+    if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
         msg = "Feature axis must be unchunked"
         raise ValueError(msg)
 
     def aggregate_chunk_sum_or_count_nonzero(
         chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
     ):
-        # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
-        # for what is contained in `block_info`.
-        subset = slice(*block_info[0]["array-location"][0])
-        by_subsetted = by[subset]
-        mask_subsetted = mask[subset] if mask is not None else mask
+        # only subset the mask and by if we need to i.e.,
+        # there is chunking along the same axis as by and mask
+        if chunked_axis == 0:
+            # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
+            # for what is contained in `block_info`.
+            subset = slice(*block_info[0]["array-location"][0])
+            by_subsetted = by[subset]
+            mask_subsetted = mask[subset] if mask is not None else mask
+        else:
+            by_subsetted = by
+            mask_subsetted = mask
         res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
-        return res[None, :]
+        return res[None, :] if unchunked_axis == 1 else res
 
     funcs = set([func] if isinstance(func, str) else func)
     if "median" in funcs:
         msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
         raise NotImplementedError(msg)
     has_mean, has_var = (v in funcs for v in ["mean", "var"])
     funcs_no_var_or_mean = funcs - {"var", "mean"}
-    # aggregate each row chunk individually,
-    # producing a #chunks × #categories × #features array,
+    # aggregate each row chunk or column chunk individually,
+    # producing a #chunks × #categories × #features or a #categories × #chunks array,
     # then aggregate the per-chunk results.
+    chunks = (
+        ((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
+        if unchunked_axis == 1
+        else (len(by.categories), data.chunks[1])
+    )
     aggregated = {
         f: data.map_blocks(
             partial(aggregate_chunk_sum_or_count_nonzero, func=func),
-            new_axis=(1,),
-            chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
+            new_axis=(1,) if unchunked_axis == 1 else None,
+            chunks=chunks,
             meta=np.array(
                 [],
                 dtype=np.float64
                 if func not in get_args(ConstantDtypeAgg)
                 else data.dtype,  # TODO: figure out best dtype for aggs like sum where dtype can change from original
             ),
-        ).sum(axis=0)
+        )
         for f in funcs_no_var_or_mean
     }
+    # If we have row chunking, we need to handle the extra axis by summing over all category × feature matrices.
+    # Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix.
+    if unchunked_axis == 1:
+        for k, v in aggregated.items():
+            aggregated[k] = v.sum(axis=chunked_axis)
     if has_var:
         aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
         aggregated["var"] = aggredated_mean_var["var"]
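The implementation above boils down to a map-blocks-then-reduce pattern: each chunk is aggregated into a #categories × #features block, a leading per-chunk axis is added in the row-chunked case, and that axis is then summed away. Below is a standalone sketch of that pattern for the dense, row-chunked case only; the `chunk_sum` helper is made up here and merely stands in for `aggregate_chunk_sum_or_count_nonzero`:

```python
import dask.array as da
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
x = da.from_array(rng.random((6, 4)), chunks=(2, 4))  # rows chunked, features unchunked
by = pd.Categorical(["a", "b", "a", "c", "b", "a"])
n_cat = len(by.categories)


def chunk_sum(chunk, block_info=None):
    # block_info tells us which rows of the full array this chunk covers
    start, stop = block_info[0]["array-location"][0]
    codes = by.codes[start:stop]
    out = np.zeros((n_cat, chunk.shape[1]))
    np.add.at(out, codes, chunk)  # per-category column sums for this chunk
    return out[None, :, :]  # prepend a per-chunk axis


per_chunk = x.map_blocks(
    chunk_sum,
    new_axis=(1,),
    chunks=((1,) * x.numblocks[0], (n_cat,), (x.shape[1],)),
    dtype=np.float64,
)
result = per_chunk.sum(axis=0).compute()  # #categories × #features
```

For the csc (column-chunked) case this PR adds, no leading axis is needed: each chunk already yields a #categories × #chunk-features block, and dask concatenates those blocks along the feature axis.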
3 changes: 0 additions & 3 deletions tests/test_aggregated.py
@@ -29,8 +29,6 @@
     not in {
         "dask_array_dense",
         "dask_array_sparse",
-        "dask_array_sparse-1d_chunked-csc_array",
-        "dask_array_sparse-1d_chunked-csc_matrix",
     }
 ]
 
@@ -246,7 +244,6 @@ def to_csc(x: CSRBase):
 @pytest.mark.parametrize(
     ("func", "error_msg"),
     [
-        pytest.param(to_csc, r"only csr_matrix", id="csc"),
         pytest.param(
             to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking"
         ),