Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
import shutil
import sys
from datetime import datetime
from functools import partial
Expand All @@ -13,6 +14,7 @@
import matplotlib # noqa
from docutils import nodes
from packaging.version import Version
from sphinxcontrib.katex import NODEJS_BINARY

# Don’t use tkinter agg when importing scanpy → … → matplotlib
matplotlib.use("agg")
Expand Down Expand Up @@ -52,7 +54,6 @@
bibtex_bibfiles = ["references.bib"]
bibtex_reference_style = "author_year"


# default settings
templates_path = ["_templates"]
master_doc = "index"
Expand All @@ -73,10 +74,10 @@
"sphinx.ext.intersphinx",
"sphinx.ext.doctest",
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.autosummary",
"sphinxcontrib.bibtex",
"sphinxcontrib.katex",
"matplotlib.sphinxext.plot_directive",
"sphinx_autodoc_typehints", # needs to be after napoleon
"git_ref", # needs to be before scanpydoc.rtd_github_links
Expand Down Expand Up @@ -129,6 +130,8 @@
pygments_style = "default"
pygments_dark_style = "native"

katex_prerender = shutil.which(NODEJS_BINARY) is not None

intersphinx_mapping = dict(
anndata=("https://anndata.readthedocs.io/en/stable/", None),
bbknn=("https://bbknn.readthedocs.io/en/latest/", None),
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/3898.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow specifying graphs in {mod}`scanpy.metrics` functions {smaller}`P Angerer`
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ doc = [
"nbsphinx>=0.9",
"ipython>=7.20", # for nbsphinx code highlighting
"sphinxcontrib-bibtex",
"sphinxcontrib-katex",
# TODO: remove necessity for being able to import doc-linked classes
"scanpy[paga,dask-ml,leiden]",
"sam-algorithm",
Expand Down
22 changes: 14 additions & 8 deletions src/scanpy/metrics/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd

from .._compat import CSRBase, DaskArray, SpBase, fullname, warn
from .._utils import NeighborsView

if TYPE_CHECKING:
from typing import NoReturn
Expand Down Expand Up @@ -65,17 +66,22 @@ def __call__(self) -> np.ndarray:
raise NotImplementedError(msg)


def _get_graph(adata: AnnData, *, use_graph: str | None = None) -> CSRBase:
def _get_graph(
adata: AnnData,
*,
use_graph: str | None = None,
neighbors_key: str | None = None,
) -> CSRBase:
if use_graph is not None:
raise NotImplementedError()
# Fix for anndata<0.7
if hasattr(adata, "obsp") and "connectivities" in adata.obsp:
return adata.obsp["connectivities"]
elif "neighbors" in adata.uns:
return adata.uns["neighbors"]["connectivities"]
else:
if neighbors_key is not None:
msg = "Cannot specify both `use_graph` and `neighbors_key`."
raise TypeError(msg)
return adata.obsp[use_graph]
nv = NeighborsView(adata, neighbors_key)
if "connectivities" not in nv:
msg = "Must run neighbors first."
raise ValueError(msg)
return nv["connectivities"]


@overload
Expand Down
22 changes: 14 additions & 8 deletions src/scanpy/metrics/_gearys_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import numpy as np

from .._compat import CSRBase, njit
from .._utils import _doc_params
from ..get import _get_obs_rep
from ..neighbors._doc import doc_neighbors_key
from ._common import _get_graph, _SparseMetric

if TYPE_CHECKING:
Expand All @@ -20,12 +22,14 @@


@singledispatch
@_doc_params(neighbors_key=doc_neighbors_key)
def gearys_c(
adata_or_graph: AnnData | CSRBase,
/,
vals: _Vals | None = None,
*,
use_graph: str | None = None,
neighbors_key: str | None = None,
layer: str | None = None,
obsm: str | None = None,
obsp: str | None = None,
Expand All @@ -42,11 +46,11 @@ def gearys_c(
.. math::

C =
\frac{
(N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2
}{
2W \sum_i (x_i - \bar{x})^2
}
\frac{{
(N - 1)\sum_{{i,j}} w_{{i,j}} (x_i - x_j)^2
}}{{
2W \sum_i (x_i - \bar{{x}})^2
}}

Params
------
Expand All @@ -60,8 +64,10 @@ def gearys_c(
object by using keyword arguments: `layer`, `obsm`, `obsp`, or
`use_raw`.
use_graph
Key to use for graph in anndata object. If not provided, default
neighbors connectivities will be used instead.
Key to use for graph in anndata object.
If not provided, default neighbors connectivities will be used instead.
(See ``neighbors_key`` below.)
{neighbors_key}
layer
Key for `adata.layers` to choose `vals`.
obsm
Expand Down Expand Up @@ -96,7 +102,7 @@ def gearys_c(

"""
adata = cast("AnnData", adata_or_graph)
g = _get_graph(adata, use_graph=use_graph)
g = _get_graph(adata, use_graph=use_graph, neighbors_key=neighbors_key)
if vals is None:
vals = _get_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T
return gearys_c(g, vals)
Expand Down
22 changes: 14 additions & 8 deletions src/scanpy/metrics/_morans_i.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import numpy as np

from .._compat import CSRBase, njit
from .._utils import _doc_params
from ..get import _get_obs_rep
from ..neighbors._doc import doc_neighbors_key
from ._common import _get_graph, _SparseMetric

if TYPE_CHECKING:
Expand All @@ -20,12 +22,14 @@


@singledispatch
@_doc_params(neighbors_key=doc_neighbors_key)
def morans_i(
adata_or_graph: AnnData | CSRBase,
/,
vals: _Vals | None = None,
*,
use_graph: str | None = None,
neighbors_key: str | None = None,
layer: str | None = None,
obsm: str | None = None,
obsp: str | None = None,
Expand All @@ -40,11 +44,11 @@ def morans_i(
.. math::

I =
\frac{
N \sum_{i, j} w_{i, j} z_{i} z_{j}
}{
S_{0} \sum_{i} z_{i}^{2}
}
\frac{{
N \sum_{{i,j}} w_{{i,j}} z_{{i}} z_{{j}}
}}{{
S_{{0}} \sum_{{i}} z_{{i}}^{{2}}
}}

Params
------
Expand All @@ -58,8 +62,10 @@ def morans_i(
object by using keyword arguments: `layer`, `obsm`, `obsp`, or
`use_raw`.
use_graph
Key to use for graph in anndata object. If not provided, default
neighbors connectivities will be used instead.
Key to use for graph in anndata object.
If not provided, default neighbors connectivities will be used instead.
(See ``neighbors_key`` below.)
{neighbors_key}
layer
Key for `adata.layers` to choose `vals`.
obsm
Expand Down Expand Up @@ -94,7 +100,7 @@ def morans_i(

"""
adata = cast("AnnData", adata_or_graph)
g = _get_graph(adata, use_graph=use_graph)
g = _get_graph(adata, use_graph=use_graph, neighbors_key=neighbors_key)
if vals is None:
vals = _get_obs_rep(adata, use_raw=use_raw, layer=layer, obsm=obsm, obsp=obsp).T
return morans_i(g, vals)
Expand Down
9 changes: 9 additions & 0 deletions src/scanpy/neighbors/_doc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations

doc_neighbors_key = """\
neighbors_key
Where to look for neighbors connectivities.
If not specified, this retrieves ``.obsp['connectivities']`` for connectivities
(default storage place for :func:`~scanpy.pp.neighbors`).
If specified, this retrieves
``.obsp[.uns[neighbors_key]['connectivities_key']]`` for connectivities.
"""

doc_use_rep = """\
use_rep
Use the indicated representation. `'X'` or any key for `.obsm` is valid.
Expand Down
11 changes: 4 additions & 7 deletions src/scanpy/plotting/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from ..neighbors._doc import doc_neighbors_key

doc_adata_color_etc = """\
adata
Annotated data matrix.
Expand All @@ -22,19 +24,14 @@
takes precedence over `use_raw`.\
"""

doc_edges_arrows = """\
doc_edges_arrows = f"""\
edges
Show edges.
edges_width
Width of edges.
edges_color
Color of edges. See :func:`~networkx.drawing.nx_pylab.draw_networkx_edges`.
neighbors_key
Where to look for neighbors connectivities.
If not specified, this looks .obsp['connectivities'] for connectivities
(default storage place for pp.neighbors).
If specified, this looks
.obsp[.uns[neighbors_key]['connectivities_key']] for connectivities.
{doc_neighbors_key}
arrows
Show arrows (deprecated in favour of `scvelo.pl.velocity_embedding`).
arrows_kwds
Expand Down
4 changes: 0 additions & 4 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,6 @@ def _highly_variable_genes_single_batch(
if n_removed:
x = x[:, filt].copy()

if hasattr(x, "_view_args"): # AnnData array view
# For compatibility with anndata<0.9
x = x.copy() # Doesn't actually copy memory, just removes View class wrapper

if flavor == "seurat":
x = x.copy()
if (base := adata.uns.get("log1p", {}).get("base")) is not None:
Expand Down
61 changes: 51 additions & 10 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
import pytest
import threadpoolctl
from anndata import AnnData
from scipy import sparse

import scanpy as sc
Expand Down Expand Up @@ -79,10 +80,12 @@ def test_consistency(metric) -> None:
pytest.param(sc.metrics.morans_i, 50, 1.0, id="morans_i"),
],
)
def test_correctness(metric, size, expected):
def test_correctness(metric, size, expected) -> None:
rng = np.random.default_rng()

# Test case with perfectly separated groups
connected = np.zeros(100)
connected[np.random.choice(100, size=size, replace=False)] = 1
connected[rng.choice(100, size=size, replace=False)] = 1
graph = np.zeros((100, 100))
graph[np.ix_(connected.astype(bool), connected.astype(bool))] = 1
graph[np.ix_(~connected.astype(bool), ~connected.astype(bool))] = 1
Expand All @@ -93,9 +96,6 @@ def test_correctness(metric, size, expected):
metric(graph, connected),
metric(graph, sparse.csr_matrix(connected)), # noqa: TID251
)
# Checking that obsp works
adata = sc.AnnData(sparse.csr_matrix((100, 100)), obsp={"connectivities": graph}) # noqa: TID251
np.testing.assert_equal(metric(adata, vals=connected), expected)


@pytest.mark.usefixtures("_threading")
Expand All @@ -104,18 +104,20 @@ def test_correctness(metric, size, expected):
)
def test_graph_metrics_w_constant_values(
request: pytest.FixtureRequest, metric, array_type
):
) -> None:
if "dask" in array_type.__name__:
reason = "DaskArray not yet supported"
request.applymarker(pytest.mark.xfail(reason=reason))

rng = np.random.default_rng()

# https://github.com/scverse/scanpy/issues/1806
pbmc = pbmc68k_reduced()
x_t = pbmc.raw.X.T.copy()
g = pbmc.obsp["connectivities"].copy()
equality_check = partial(np.testing.assert_allclose, atol=1e-11)

const_inds = np.random.choice(x_t.shape[0], 10, replace=False)
const_inds = rng.choice(x_t.shape[0], 10, replace=False)
with warnings.catch_warnings():
warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
x_t_zero_vals = x_t.copy()
Expand Down Expand Up @@ -145,6 +147,43 @@ def test_graph_metrics_w_constant_values(
equality_check(results_full[non_const_mask], results_const_zeros[non_const_mask])


@pytest.mark.parametrize(
("neigh_params", "metric_params"),
[
pytest.param(
dict(key_added="foo"), dict(use_graph="foo_connectivities"), id="use_graph"
),
pytest.param(
dict(key_added="bar"), dict(neighbors_key="bar"), id="neighbors_key"
),
],
)
def test_metrics_graph_params(metric, neigh_params, metric_params) -> None:
rng = np.random.default_rng()
adata = AnnData(rng.normal(size=(10, 20)))
sc.pp.neighbors(adata, **neigh_params)
if "use_graph" in metric_params: # make sure no extra stuff is there
adata = AnnData(adata.X, obsp=adata.obsp)
metric(adata, **metric_params)


@pytest.mark.parametrize(
("params", "err_cls", "pattern"),
[
pytest.param(
dict(use_graph="foo", neighbors_key="bar"), TypeError, r"both", id="both"
),
pytest.param(dict(use_graph="foo"), KeyError, r"foo", id="no_graph"),
pytest.param(dict(neighbors_key="bar"), KeyError, r"bar", id="no_key"),
pytest.param({}, KeyError, r"neighbors.*uns", id="nothing"),
],
)
def test_metrics_graph_params_errors(metric, params, err_cls, pattern) -> None:
adata = AnnData(shape=(10, 20))
with pytest.raises(err_cls, match=pattern):
metric(adata, **params)


def test_confusion_matrix():
mtx = sc.metrics.confusion_matrix(["a", "b"], ["c", "d"], normalize=False)
assert mtx.loc["a", "c"] == 1
Expand Down Expand Up @@ -184,10 +223,12 @@ def test_confusion_matrix_randomized() -> None:
)


def test_confusion_matrix_api():
def test_confusion_matrix_api() -> None:
rng = np.random.default_rng()

data = pd.DataFrame({
"a": np.random.randint(5, size=100),
"b": np.random.randint(5, size=100),
"a": rng.integers(5, size=100),
"b": rng.integers(5, size=100),
})
expected = sc.metrics.confusion_matrix(data["a"], data["b"])

Expand Down
Loading