Skip to content

Commit 940727f

Browse files
fix: improve variance precision (#127)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9df6b5a commit 940727f

File tree

10 files changed

+72
-15
lines changed

10 files changed

+72
-15
lines changed

.github/workflows/ci.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,16 @@ jobs:
4242
test:
4343
name: Tests
4444
needs: get-environments
45-
runs-on: ubuntu-latest
45+
runs-on: ${{ matrix.os }}
4646
strategy:
4747
matrix:
4848
env: ${{ fromJSON(needs.get-environments.outputs.envs) }}
49+
os: [ubuntu-latest]
50+
include:
51+
- env:
52+
name: hatch-test.py3.13-full
53+
python: "3.13"
54+
os: macos-latest
4955
steps:
5056
- uses: actions/checkout@v4
5157
with:

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
"source.organizeImports": "explicit",
1616
},
1717
},
18-
"python.testing.pytestArgs": ["-vv", "--color=yes"],
18+
"python.testing.pytestArgs": ["-vv", "--color=yes", "-m", "not benchmark"],
1919
"python.testing.pytestEnabled": true,
2020
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ overrides.matrix.resolution.features = [
100100
{ if = [ "lowest" ], value = "min-reqs" }, # feature added by hatch-min-requirements
101101
]
102102
overrides.matrix.resolution.dependencies = [
103-
# TODO: move to min dep once this is fixed: https://github.com/tlambert03/hatch-min-requirements/issues/5
103+
# TODO: move to min dep once this is fixed: https://github.com/tlambert03/hatch-min-requirements/issues/11
104104
{ if = [ "lowest" ], value = "dask==2023.6.1" },
105105
]
106106

src/fast_array_utils/stats/_mean.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def mean_(
2626
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
2727
total = sum_(x, axis=axis, dtype=dtype)
2828
n = np.prod(x.shape) if axis is None else x.shape[axis]
29-
return total / n # type: ignore[call-overload,operator,return-value]
29+
return total / n # type: ignore[operator,return-value]

src/fast_array_utils/stats/_mean_var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def mean_var_(
3737
mean_, var = _sparse_mean_var(x, axis=axis)
3838
else:
3939
mean_ = mean(x, axis=axis, dtype=np.float64)
40-
mean_sq = mean(power(x, 2), axis=axis, dtype=np.float64)
40+
mean_sq = mean(power(x, 2, dtype=np.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=np.float64)
4141
var = mean_sq - mean_**2
4242
if correction: # R convention == 1 (unbiased estimator)
4343
n = np.prod(x.shape) if axis is None else x.shape[axis]

src/fast_array_utils/stats/_power.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,45 @@
44
from functools import singledispatch
55
from typing import TYPE_CHECKING
66

7+
import numpy as np
8+
79
from .. import types
810

911

1012
if TYPE_CHECKING:
1113
from typing import TypeAlias, TypeVar
1214

15+
from numpy.typing import DTypeLike
16+
1317
from fast_array_utils.typing import CpuArray, GpuArray
1418

1519
# All supported array types except for disk ones and CSDataset
1620
Array: TypeAlias = CpuArray | GpuArray | types.DaskArray
1721

1822
_Arr = TypeVar("_Arr", bound=Array)
23+
_Mat = TypeVar("_Mat", bound=types.CSBase | types.CupyCSMatrix)
1924

2025

21-
def power(x: _Arr, n: int, /) -> _Arr:
26+
def power(x: _Arr, n: int, /, dtype: DTypeLike | None = None) -> _Arr:
2227
"""Take array or matrix to a power."""
2328
# This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
24-
return _power(x, n) # type: ignore[return-value]
29+
return _power(x, n, dtype=dtype) # type: ignore[return-value]
2530

2631

2732
@singledispatch
28-
def _power(x: Array, n: int, /) -> Array:
33+
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
2934
if TYPE_CHECKING:
30-
assert not isinstance(x, types.DaskArray | types.CSMatrix)
31-
return x**n # type: ignore[operator]
35+
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
36+
return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]
3237

3338

34-
@_power.register(types.CSMatrix | types.CupyCSMatrix)
35-
def _power_cs(x: types.CSMatrix | types.CupyCSMatrix, n: int, /) -> types.CSMatrix | types.CupyCSMatrix:
36-
return x.power(n)
39+
@_power.register(types.CSBase | types.CupyCSMatrix)
40+
def _power_cs(x: _Mat, n: int, /, dtype: DTypeLike | None = None) -> _Mat:
41+
new_data = power(x.data, n, dtype=dtype)
42+
return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype) # type: ignore[call-overload,return-value]
3743

3844

3945
@_power.register(types.DaskArray)
40-
def _power_dask(x: types.DaskArray, n: int, /) -> types.DaskArray:
41-
return x.map_blocks(lambda c: power(c, n)) # type: ignore[type-var]
46+
def _power_dask(x: types.DaskArray, n: int, /, dtype: DTypeLike | None = None) -> types.DaskArray:
47+
meta = x._meta.astype(dtype or x.dtype) # noqa: SLF001
48+
return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta) # type: ignore[type-var,arg-type]
278 KB
Binary file not shown.

tests/test_stats.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,21 @@
22
from __future__ import annotations
33

44
from importlib.util import find_spec
5+
from pathlib import Path
56
from typing import TYPE_CHECKING, cast
67

78
import numpy as np
89
import pytest
10+
import scipy.sparse as sps
911
from numpy.exceptions import AxisError
1012

1113
from fast_array_utils import stats, types
1214
from testing.fast_array_utils import SUPPORTED_TYPES, Flags
1315

1416

17+
DATA_DIR = Path(__file__).parent / "data"
18+
19+
1520
if TYPE_CHECKING:
1621
from collections.abc import Callable
1722
from typing import Any, Literal, Protocol, TypeAlias
@@ -126,6 +131,21 @@ def to_np_dense_checked(
126131
return stat
127132

128133

134+
@pytest.fixture(scope="session")
135+
def pbmc64k_reduced_raw() -> sps.csr_array[np.float32]:
136+
"""Scanpy’s pbmc68k_reduced raw data.
137+
138+
Data was created using:
139+
>>> if not find_spec("scanpy"):
140+
... pytest.skip()
141+
>>> import scanpy as sc
142+
>>> import scipy.sparse as sps
143+
>>> arr = sps.csr_array(sc.datasets.pbmc68k_reduced().raw.X)
144+
>>> sps.save_npz("pbmc68k_reduced_raw_csr.npz", arr)
145+
"""
146+
return cast("sps.csr_array[np.float32]", sps.load_npz(DATA_DIR / "pbmc68k_reduced_raw_csr.npz"))
147+
148+
129149
@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
130150
@pytest.mark.parametrize("func", STAT_FUNCS)
131151
@pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"])
@@ -273,6 +293,23 @@ def test_mean_var_sparse_32(array_type: ArrayType[types.CSArray]) -> None:
273293
assert resid_fau < resid_skl
274294

275295

296+
@pytest.mark.array_type({at for at in SUPPORTED_TYPES if at.flags & Flags.Sparse and at.flags & Flags.Dask})
297+
def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], pbmc64k_reduced_raw: sps.csr_array[np.float32]) -> None:
298+
"""Test float32 precision for bigger data.
299+
300+
This test is flaky for sparse-in-dask for some reason.
301+
"""
302+
mat = pbmc64k_reduced_raw
303+
arr = array_type(mat)
304+
305+
mean_mat, var_mat = stats.mean_var(mat, axis=0, correction=1)
306+
mean_arr, var_arr = (to_np_dense_checked(a, 0, arr) for a in stats.mean_var(arr, axis=0, correction=1))
307+
308+
rtol = 1.0e-5 if array_type.flags & Flags.Gpu else 1.0e-7
309+
np.testing.assert_allclose(mean_arr, mean_mat, rtol=rtol)
310+
np.testing.assert_allclose(var_arr, var_mat, rtol=rtol)
311+
312+
276313
@pytest.mark.array_type(skip={Flags.Disk, *ATS_CUPY_SPARSE})
277314
@pytest.mark.parametrize(
278315
("axis", "expected"),

typings/cupy/_core/core.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class ndarray:
2929
def __power__(self, other: int) -> Self: ...
3030

3131
# methods
32+
def astype(
33+
self, dtype: DTypeLike | None, order: Literal["C", "F", "A", "K"] = "K", casting: None = None, subok: None = None, copy: bool = True
34+
) -> Self: ...
3235
@property
3336
def T(self) -> Self: ... # noqa: N802
3437
@overload

typings/cupyx/scipy/sparse/_compressed.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ from ._base import spmatrix
88

99
class _compressed_sparse_matrix(spmatrix):
1010
format: Literal["csr", "csc"]
11+
data: ndarray
12+
indices: ndarray
13+
indptr: ndarray
1114

1215
@overload
1316
def __init__(self, arg1: ndarray | spmatrix) -> None: ...
@@ -19,5 +22,6 @@ class _compressed_sparse_matrix(spmatrix):
1922
def __init__(self, arg1: tuple[ndarray, ndarray, ndarray], shape: tuple[int, int] | None = None) -> None: ...
2023

2124
# methods
25+
def astype(self, dtype: DTypeLike | None) -> Self: ...
2226
def power(self, n: int, dtype: DTypeLike | None = None) -> Self: ...
2327
def sum(self, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, out: Self | None = None) -> ndarray: ...

0 commit comments

Comments
 (0)