diff --git a/docs/extensions/array_support.py b/docs/extensions/array_support.py new file mode 100644 index 0000000000..a56dc0de2d --- /dev/null +++ b/docs/extensions/array_support.py @@ -0,0 +1,131 @@ +"""Add `array-support` directive.""" + +from __future__ import annotations + +from itertools import groupby +from typing import TYPE_CHECKING + +from docutils import nodes +from sphinx.util.docutils import SphinxDirective + +from scanpy._utils import _docs + +if TYPE_CHECKING: + from collections.abc import Collection, Sequence + from typing import ClassVar + + from sphinx.application import Sphinx + + +class ArraySupport(SphinxDirective): + """In the scanpy-tutorials repo, this links to the canonical location (here!).""" + + required_arguments: ClassVar = 1 + optional_arguments: ClassVar = 999 + + option_spec: ClassVar = { + "except": lambda arg: arg.split(" "), + } + + def run(self) -> list[nodes.Node]: # noqa: D102 + array_types = list(_docs.parse(self.arguments, self.options.get("except", ()))) + headers = ("Array type", "supported", "… in dask :class:`~dask.array.Array`") + data: list[tuple[_docs.Inner, bool, bool]] = [] + for array_type in _docs.parse(["np", "sp"], inner=True): + dask_array_type = _docs.DaskArray(array_type) + data.append(( + array_type, + array_type in array_types, + dask_array_type in array_types, + )) + + return self._render_table(headers, data) + + def _render_table( + self, + headers: tuple[str, str, str], + data: list[tuple[_docs.Inner, bool, bool]], + ) -> list[nodes.Node]: + + colspecs = [nodes.colspec(stub=True), *(nodes.colspec() for _ in range(2))] + thead = nodes.thead( + "", + nodes.row( + "", + *( + nodes.entry("", nodes.paragraph("", "", *self.parse_inline(t)[0])) + for t in headers + ), + ), + ) + tbody = nodes.tbody() + for t, group in groupby(data, key=lambda r: type(r[0])): + group = list(group) # noqa: PLW2901 + if ( # if all sparse types have the same support, just one row + t is _docs.ScipySparse + and (support := one({s for _, s, _ in group})) is not None + and (in_dask := one({d for _, _, d in group})) is not None + ): + refs: list[nodes.Node] = [ + nodes.inline("", "scipy.sparse.{"), + *self.parse_inline(":class:`csr `")[0], + nodes.inline("", ","), + *self.parse_inline(":class:`csc `")[0], + nodes.inline("", "}_{"), + *self.parse_inline(":class:`array `")[0], + nodes.inline("", ","), + *self.parse_inline(":class:`matrix `")[0], + nodes.inline("", "}"), + ] + header = [nodes.literal("", "", *refs)] + tbody += [self._render_row(header, support=support, in_dask=in_dask)] + else: # otherwise, show them individually + tbody += [ + self._render_row( + self._render_array_type(array_type), + support=support, + in_dask=in_dask, + ) + for array_type, support, in_dask in group + ] + return [ + nodes.table( + "", + nodes.title("", "Array type support"), + nodes.tgroup("", *colspecs, thead, tbody, cols=3), + ids=["array-support"], + ) + ] + + def _render_row( + self, header: Sequence[nodes.Node], *, support: bool, in_dask: bool + ) -> nodes.Node: + cells: list[Sequence[nodes.Node]] = [ + header, + self._render_support(support), + self._render_support(in_dask), + ] + children = (nodes.entry("", nodes.paragraph("", "", *cell)) for cell in cells) + return nodes.row("", *children) + + def _render_array_type(self, array_type: _docs.ArrayType, /) -> list[nodes.Node]: + nodes_, msgs = self.parse_inline(array_type.rst()) + assert not msgs, msgs + return nodes_ + + def _render_support(self, support: bool, /) -> list[nodes.Node]: # noqa: FBT001 + return [nodes.Text("✅" if support else "❌")] + + +def one[T](arg: Collection[T]) -> T | None: + """Return the only item in `arg` or None if `arg` is not of length 1.""" + try: + [item] = arg + except ValueError: + return None + return item + + +def setup(app: Sphinx) -> None: + """App setup hook.""" + app.add_directive("array-support", ArraySupport) diff --git a/docs/release-notes/3895.docs.md b/docs/release-notes/3895.docs.md new file mode 100644 index 0000000000..bb6c119a3a --- /dev/null +++ b/docs/release-notes/3895.docs.md @@ -0,0 +1 @@ +Document array type support for most functions in {mod}`~scanpy.pp` and {mod}`~scanpy.tl` {smaller}`P Angerer` diff --git a/src/scanpy/_utils/_docs.py b/src/scanpy/_utils/_docs.py new file mode 100644 index 0000000000..8000a74bfd --- /dev/null +++ b/src/scanpy/_utils/_docs.py @@ -0,0 +1,117 @@ +"""Add `array-support` directive.""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from itertools import product +from typing import TYPE_CHECKING, overload + +if TYPE_CHECKING: + from collections.abc import Collection, Generator + from typing import Literal + + +__all__ = ["ArrayType", "DaskArray", "Numpy", "ScipySparse", "parse"] + + +class ArrayType(ABC): + def rst(self) -> str: # pragma: no cover + return f":class:`{self}`" + + @abstractmethod + def __hash__(self) -> int: ... + + +@dataclass(unsafe_hash=True, frozen=True) +class Numpy(ArrayType): + def __str__(self) -> str: # pragma: no cover + return "numpy.ndarray" + + def rst(self) -> str: # pragma: no cover + return f":class:`{self}`" + + +@dataclass(unsafe_hash=True, frozen=True) +class ScipySparse(ArrayType): + format: Literal["csr", "csc"] + container: Literal["array", "matrix"] + + def __str__(self) -> str: # pragma: no cover + return f"scipy.sparse.{self.format}_{self.container}" + + +type Inner = Numpy | ScipySparse + + +@dataclass(unsafe_hash=True, frozen=True) +class DaskArray(ArrayType): + chunk: Inner + + def __str__(self) -> str: # pragma: no cover + return f"dask.array.Array[{self.chunk}]" + + def rst(self) -> str: # pragma: no cover + return rf":class:`dask.array.Array`\ \[{self.chunk.rst()}\]" + + +@overload +def parse( + include: Collection[str], + exclude: Collection[str] = (), + *, + inner: Literal[False] = False, +) -> Generator[ArrayType]: ... +@overload +def parse( + include: Collection[str], exclude: Collection[str] = (), *, inner: Literal[True] +) -> Generator[Inner]: ... +def parse( + include: Collection[str], exclude: Collection[str] = (), *, inner: bool = False +) -> Generator[ArrayType]: + if exclude: + excluded = dict.fromkeys(parse(exclude)).keys() + yield from (t for t in parse(include) if t not in excluded) + return + + inner_includes = [i for i in include if not i.startswith("da")] + for t in include: + if ( + match := re.fullmatch(r"([^\[]+)(?:\[(.+)\])?", t) + ) is None: # pragma: no cover + msg = f"invalid {t!r}" + raise ValueError(msg) + mod, tags = match.groups("") + if mod == "da" and inner: # pragma: no cover + msg = "Can’t nest dask arrays" + raise ValueError(msg) + tags = set(re.split(r",(?![^\[]+\])", tags)) if tags else set() + yield from _parse_mod(mod, tags, inner_includes=inner_includes) + + +def _parse_mod( + mod: str, tags: set[str], *, inner_includes: Collection[str] +) -> Generator[ArrayType]: + match mod: + case "np": + if tags: # pragma: no cover + msg = f"`np` takes no tags {tags!r}" + raise ValueError(msg) + yield Numpy() + case "sp": + if tags - {"csr", "csc", "array", "matrix"}: # pragma: no cover + msg = f"invalid tags {tags!r}" + raise ValueError(msg) + for format, container in product(("csr", "csc"), ("array", "matrix")): + if tags & {"csr", "csc"} and format not in tags: + continue + if tags & {"array", "matrix"} and container not in tags: + continue + yield ScipySparse(format=format, container=container) + case "da": + for chunk in parse(tags if tags else inner_includes, inner=True): + yield DaskArray(chunk=chunk) + case _: # pragma: no cover + msg = f"invalid module {mod!r}" + raise ValueError(msg) diff --git a/src/scanpy/experimental/pp/_highly_variable_genes.py b/src/scanpy/experimental/pp/_highly_variable_genes.py index c1f08cbd9a..78a92b953d 100644 --- a/src/scanpy/experimental/pp/_highly_variable_genes.py +++ b/src/scanpy/experimental/pp/_highly_variable_genes.py @@ -320,6 +320,8 @@ def highly_variable_genes( # noqa: PLR0913 Expects raw count input. + .. array-support:: np sp + Parameters ---------- {adata} diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 93194aea0d..64b6f01f4d 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -202,6 +202,8 @@ def aggregate( # noqa: PLR0912 If none of `layer`, `obsm`, or `varm` are passed in, `X` will be used for aggregation data. + .. array-support:: np sp da + Params ------ adata diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 3882cced68..ba526b4808 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -91,6 +91,8 @@ def neighbors( # noqa: PLR0913 in the adaption of :cite:t:`Haghverdi2016`. If `method=='jaccard'`, connectivities are computed as in PhenoGraph :cite:p:`Levine2015`. + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index 02e8323c91..c893ffb99f 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -149,6 +149,8 @@ def combat( # noqa: PLR0915 .. _combat.py: https://github.com/brentp/combat.py + .. array-support:: np + Parameters ---------- adata diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index f84c31a26d..6966fb592b 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -662,6 +662,9 @@ def highly_variable_genes( # noqa: PLR0913 See also `scanpy.experimental.pp._highly_variable_genes` for additional flavors (e.g. Pearson residuals). + .. array-support:: np sp da + :except: da[sp[csc]] + Parameters ---------- adata diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py index 7ff8d65c82..a8ad3d1267 100644 --- a/src/scanpy/preprocessing/_normalization.py +++ b/src/scanpy/preprocessing/_normalization.py @@ -158,6 +158,8 @@ def normalize_total( # noqa: PLR0912 Similar functions are used, for example, by Seurat :cite:p:`Satija2015`, Cell Ranger :cite:p:`Zheng2017` or SPRING :cite:p:`Weinreb2017`. + .. array-support:: np sp[csr] da + .. note:: When used with a :class:`~dask.array.Array` in `adata.X`, this function will have to call functions that trigger `.compute()` on the :class:`~dask.array.Array` if `exclude_highly_expressed` diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 9ca5983441..01c0856cd8 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -99,6 +99,9 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 .. [#dense-only] This implementation can not handle sparse chunks, try manually densifying them. .. [#densifies] This implementation densifies sparse chunks and therefore has increased memory usage. + .. array-support:: np sp da + :except: da[sp[csc]] + Parameters ---------- data diff --git a/src/scanpy/preprocessing/_qc.py b/src/scanpy/preprocessing/_qc.py index 0d3933bc13..74154fc053 100644 --- a/src/scanpy/preprocessing/_qc.py +++ b/src/scanpy/preprocessing/_qc.py @@ -225,6 +225,8 @@ def calculate_qc_metrics( Note that this method can take a while to compile on the first call. That result is then cached to disk to be used later. + .. array-support:: np sp da + Parameters ---------- {doc_adata_basic} diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index 3c77f38f03..0be8f02a1b 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -87,6 +87,8 @@ def scale[A: _Array]( all observations) are retained and (for zero_center==True) set to 0 during this operation. In the future, they might be set to NaNs. + .. array-support:: np sp da + Parameters ---------- data diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 5c6d39fad4..8fa26e1ec6 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -73,6 +73,8 @@ def scrublet( # noqa: PLR0913 :func:`~scanpy.pp.scrublet_simulate_doublets`, and run the core scrublet function :func:`~scanpy.pp.scrublet` with ``adata_sim`` set. + .. array-support:: np sp + Parameters ---------- adata @@ -323,6 +325,8 @@ def _scrublet_call_doublets( # noqa: PLR0913 Predict cell doublets using a nearest-neighbor classifier of observed transcriptomes and simulated doublets. + .. array-support:: np sp + Parameters ---------- adata_obs diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index d0e2dac04c..2733413378 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -94,6 +94,8 @@ def filter_cells( Depending on what was thresholded (`counts` or `genes`), the array stores `n_counts` or `n_genes` per cell. + .. array-support:: np sp da + Examples -------- >>> import scanpy as sc @@ -213,6 +215,8 @@ def filter_genes( Only provide one of the optional parameters `min_counts`, `min_cells`, `max_counts`, `max_cells` per call. + .. array-support:: np sp da + Parameters ---------- data @@ -319,6 +323,8 @@ def log1p( Computes :math:`X = \log(X + 1)`, where :math:`log` denotes the natural logarithm unless a different base is given. + .. array-support:: np sp da + Parameters ---------- data @@ -681,6 +687,8 @@ def regress_out( function in R :cite:p:`Satija2015`. Note that this function tends to overcorrect in certain circumstances as described in :issue:`526`. + .. array-support:: np + Parameters ---------- adata @@ -888,6 +896,8 @@ def sample( # noqa: PLR0912 ) -> AnnData | None | tuple[np.ndarray | CSBase | DaskArray, NDArray[np.int64]]: r"""Sample observations or variables with or without replacement. + .. array-support:: np sp da + Parameters ---------- data @@ -996,6 +1006,8 @@ def downsample_counts( If `total_counts` is specified, expression matrix will be downsampled to contain at most `total_counts`. + .. array-support:: np sp[csr] + Parameters ---------- adata @@ -1045,7 +1057,7 @@ def downsample_counts( def _downsample_per_cell( - x: CSBase, + x: np.ndarray | CSBase, /, counts_per_cell: int, *, @@ -1101,7 +1113,7 @@ def _downsample_per_cell( def _downsample_total_counts( - x: CSBase, + x: np.ndarray | CSBase, /, total_counts: int, *, diff --git a/src/scanpy/tools/_dendrogram.py b/src/scanpy/tools/_dendrogram.py index 8badbfaafa..214c1a54f6 100644 --- a/src/scanpy/tools/_dendrogram.py +++ b/src/scanpy/tools/_dendrogram.py @@ -67,6 +67,8 @@ def dendrogram( # noqa: PLR0913 groups and not per cell. The correlation matrix is computed using by default pearson but other methods are available. + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index baaa21653b..89479dd9c0 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -34,6 +34,8 @@ def diffmap( `method=='umap'`. Differences between these options shouldn't usually be dramatic. + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 1a654a2cd6..3ee764ab72 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -76,6 +76,8 @@ def dpt( you need to pass ``n_comps=10`` in :func:`~scanpy.tl.diffmap` in order to exactly reproduce previous :func:`~scanpy.tl.dpt` results. + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index f273577f24..6412e54b00 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -68,6 +68,9 @@ def draw_graph( # noqa: PLR0913 .. _fa2-modified: https://github.com/AminAlam/fa2_modified .. _Force-directed graph drawing: https://en.wikipedia.org/wiki/Force-directed_graph_drawing + .. only uses graph in obsp + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py index ed57789bfa..b36fb9012c 100644 --- a/src/scanpy/tools/_embedding_density.py +++ b/src/scanpy/tools/_embedding_density.py @@ -59,6 +59,8 @@ def embedding_density( # noqa: PLR0912 This function was written by Sophie Tritschler and implemented into Scanpy by Malte Luecken. + .. array-support:: np + Parameters ---------- adata diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 831b3b6f2a..253cb500aa 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -63,6 +63,8 @@ def ingest( You need to run :func:`~scanpy.pp.neighbors` on `adata_ref` before passing it. + .. array-support:: np sp + Parameters ---------- adata @@ -365,7 +367,7 @@ def _same_rep(self): return adata.obsm[self._use_rep] return adata.X - def fit(self, adata_new): + def fit(self, adata_new: AnnData) -> None: """Map `adata_new` to the same representation as `adata`. This function identifies the representation which was used to diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index f49efcca36..f3078ad8f5 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -56,6 +56,9 @@ def leiden( # noqa: PLR0912, PLR0913, PLR0915 This requires having run :func:`~scanpy.pp.neighbors` or :func:`~scanpy.external.pp.bbknn` first. + .. only uses graph in obsp + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index c2d6590467..5aaafd11f0 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -76,6 +76,9 @@ def louvain( # noqa: PLR0912, PLR0913, PLR0915 :func:`~scanpy.external.pp.bbknn` first, or explicitly passing a ``adjacency`` matrix. + .. only uses graph in obsp + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index 9d98538405..b362d0f6c3 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -81,7 +81,7 @@ def marker_gene_overlap( # noqa: PLR0912, PLR0915 adj_pval_threshold: float | None = None, key_added: str = "marker_gene_overlap", inplace: bool = False, -): +) -> pd.DataFrame: """Calculate an overlap score between data-derived marker genes and provided markers. Marker gene overlap scores can be quoted as overlap counts, overlap diff --git a/src/scanpy/tools/_paga.py b/src/scanpy/tools/_paga.py index cd4026f677..3f470ac25d 100644 --- a/src/scanpy/tools/_paga.py +++ b/src/scanpy/tools/_paga.py @@ -52,6 +52,8 @@ def paga( `init_pos='paga'` to get single-cell embeddings that are typically more faithful to the global topology. + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 5e8950be5e..5529b330fa 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -526,10 +526,12 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 Expects logarithmized data. + .. array-support:: np sp + .. warning:: Comparing between cells leads to highly inflated p-values, - since cells are not independent observations :cite:p`Squair2021`. + since cells are not independent observations :cite:p:`Squair2021`. Especially in single-cell data, consider instead to use more appropriate methods such as combining pseudobulking with :doc:`pydeseq2:index`. :func:`decoupler.pp.pseudobulk` or :func:`scanpy.get.aggregate` can be used to aggregate samples for pseudobulking. diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 25bba35263..484503cc8c 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -52,6 +52,8 @@ def tsne( # noqa: PLR0913 .. _multicore-tsne: https://github.com/DmitryUlyanov/Multicore-TSNE + .. array-support:: np sp + Parameters ---------- adata diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index c7a5d18173..9159574a8b 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -71,6 +71,8 @@ def umap( # noqa: PLR0913, PLR0915 .. _umap-learn: https://github.com/lmcinnes/umap + .. array-support:: np sp + Parameters ---------- adata diff --git a/tests/test_package_structure.py b/tests/test_package_structure.py index 6da3881539..6286545525 100644 --- a/tests/test_package_structure.py +++ b/tests/test_package_structure.py @@ -11,7 +11,7 @@ # CLI is locally not imported by default but on travis it is? import scanpy.cli -from scanpy._utils import descend_classes_and_funcs, import_name +from scanpy._utils import _docs, descend_classes_and_funcs, import_name if TYPE_CHECKING: from types import FunctionType @@ -167,3 +167,57 @@ def getsourcelines(obj): return getsourcelines(wrapped) return getsourcelines(obj) + + +ALL_SPS = [ + _docs.ScipySparse(fmt, c) for fmt in ("csr", "csc") for c in ("array", "matrix") +] + + +@pytest.mark.parametrize( + ("include", "exclude", "expected"), + [ + pytest.param("np", "", [_docs.Numpy()], id="np"), + pytest.param("np", "np", [], id="remove_identical"), + pytest.param( + "np da", + "", + [_docs.Numpy(), _docs.DaskArray(_docs.Numpy())], + id="dask_inherits", + ), + pytest.param( + "sp da[sp[csr]]", + [], + [ + *ALL_SPS, + *( + _docs.DaskArray(_docs.ScipySparse("csr", c)) + for c in ("array", "matrix") + ), + ], + id="include_fewer_nested", + ), + pytest.param( + "da[sp[csc]]", + [], + [_docs.DaskArray(_docs.ScipySparse("csc", c)) for c in ("array", "matrix")], + id="include_only_nested", + ), + pytest.param("da[sp[csc]]", "sp da", [], id="remove_more"), + pytest.param( + "sp da", + "sp da[sp[matrix]]", + [ + _docs.DaskArray(_docs.ScipySparse(fmt, "array")) + for fmt in ("csr", "csc") + ], + id="only_dask_subset", + ), + ], +) +def test_array_type_selector( + include: str, exclude: str, expected: list[_docs.ArrayType] +) -> None: + inc, exc = (i.split(" ") if i else [] for i in (include, exclude)) + received = list(_docs.parse(inc, exc)) + assert received == expected diff --git a/tests/test_paga.py b/tests/test_paga.py index c6e42ef653..5975f61eb4 100644 --- a/tests/test_paga.py +++ b/tests/test_paga.py @@ -26,11 +26,13 @@ pytestmark = [needs.igraph] -@pytest.fixture(scope="module") -def pbmc_session(): +@pytest.fixture(scope="module", params=["dense", "sparse"]) +def pbmc_session(request: pytest.FixtureRequest) -> sc.AnnData: pbmc = pbmc68k_reduced() - sc.tl.paga(pbmc, groups="bulk_labels") pbmc.obs["cool_feature"] = pbmc[:, "CST3"].X.squeeze().copy() + if request.param == "sparse": + pbmc.X = pbmc.raw.X.tocsr() + sc.tl.paga(pbmc, groups="bulk_labels") assert not pbmc.obs["cool_feature"].isna().all() return pbmc