Skip to content

Commit 9761bee

Browse files
Add mean_var (#46)
Co-authored-by: Ilan Gold <[email protected]>
1 parent 2c7a94a commit 9761bee

File tree

8 files changed

+263
-13
lines changed

8 files changed

+263
-13
lines changed

src/fast_array_utils/stats/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from ._is_constant import is_constant
77
from ._mean import mean
8+
from ._mean_var import mean_var
89
from ._sum import sum
910

1011

11-
__all__ = ["is_constant", "mean", "sum"]
12+
__all__ = ["is_constant", "mean", "mean_var", "sum"]

src/fast_array_utils/stats/_mean.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: MPL-2.0
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, overload
55

66
import numpy as np
77

@@ -13,19 +13,34 @@
1313

1414
from numpy._typing._array_like import _ArrayLikeFloat_co as ArrayLike
1515
from numpy.typing import DTypeLike, NDArray
16+
from optype.numpy import ToDType
1617

1718
from .. import types
1819

19-
# all supported types except OutOfCoreDataset (TODO)
20-
Array = (
20+
# all supported types except Dask and OutOfCoreDataset (TODO)
21+
NonDaskArray = (
2122
NDArray[Any]
2223
| types.CSBase
2324
| types.H5Dataset
2425
| types.ZarrArray
2526
| types.CupyArray
2627
| types.CupySparseMatrix
27-
| types.DaskArray
2828
)
29+
Array = NonDaskArray | types.DaskArray
30+
31+
32+
@overload
33+
def mean(
34+
x: ArrayLike | NonDaskArray, /, *, axis: Literal[None] = None, dtype: DTypeLike | None = None
35+
) -> np.number[Any]: ...
36+
@overload
37+
def mean(
38+
x: ArrayLike | NonDaskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
39+
) -> NDArray[np.number[Any]]: ...
40+
@overload
41+
def mean(
42+
x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None
43+
) -> types.DaskArray: ...
2944

3045

3146
def mean(
@@ -34,7 +49,7 @@ def mean(
3449
*,
3550
axis: Literal[0, 1, None] = None,
3651
dtype: DTypeLike | None = None,
37-
) -> NDArray[Any] | types.DaskArray:
52+
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
3853
"""Mean over both or one axis.
3954
4055
Returns
@@ -49,7 +64,7 @@ def mean(
4964
if not hasattr(x, "shape"):
5065
raise NotImplementedError # TODO(flying-sheep): infer shape # noqa: TD003
5166
if TYPE_CHECKING:
52-
assert isinstance(x, Array)
67+
assert isinstance(x, Array) # type:ignore[unused-ignore]
5368
total = sum_(x, axis=axis, dtype=dtype)
5469
n = np.prod(x.shape) if axis is None else x.shape[axis]
5570
return total / n
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING, no_type_check, overload
5+
6+
import numba
7+
import numpy as np
8+
9+
from .. import types
10+
from ._mean import mean
11+
from ._power import power
12+
13+
14+
if TYPE_CHECKING:
15+
from typing import Any, Literal
16+
17+
from numpy.typing import NDArray
18+
19+
MemArray = NDArray[Any] | types.CSBase | types.CupyArray | types.CupySparseMatrix
20+
21+
22+
__all__ = ["mean_var"]
23+
24+
25+
@overload
26+
def mean_var(
27+
x: MemArray, /, *, axis: Literal[None] = None, correction: int = 0
28+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ...
29+
@overload
30+
def mean_var(
31+
x: MemArray, /, *, axis: Literal[0, 1], correction: int = 0
32+
) -> tuple[np.float64, np.float64]: ...
33+
@overload
34+
def mean_var(
35+
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, correction: int = 0
36+
) -> tuple[types.DaskArray, types.DaskArray]: ...
37+
38+
39+
@no_type_check # mypy is extremely confused
40+
def mean_var(
41+
x: MemArray | types.DaskArray,
42+
/,
43+
*,
44+
axis: Literal[0, 1, None] = None,
45+
correction: int = 0,
46+
) -> (
47+
tuple[NDArray[np.float64], NDArray[np.float64]]
48+
| tuple[np.float64, np.float64]
49+
| tuple[types.DaskArray, types.DaskArray]
50+
):
51+
if axis is not None and isinstance(x, types.CSBase):
52+
mean_, var = _sparse_mean_var(x, axis=axis)
53+
else:
54+
mean_ = mean(x, axis=axis, dtype=np.float64)
55+
mean_sq = mean(power(x, 2), axis=axis, dtype=np.float64)
56+
var = mean_sq - mean_**2
57+
if correction: # R convention == 1 (unbiased estimator)
58+
n = np.prod(x.shape) if axis is None else x.shape[axis]
59+
if n != 1:
60+
var *= n / (n - correction)
61+
return mean_, var
62+
63+
64+
def _sparse_mean_var(
65+
mtx: types.CSBase, /, *, axis: Literal[0, 1]
66+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
67+
"""Calculate means and variances for each row or column of a sparse matrix.
68+
69+
This code and internal functions are based on sklearns `sparsefuncs.mean_variance_axis`.
70+
71+
Modifications:
72+
- allow deciding on the output type,
73+
which can increase accuracy when calculating the mean and variance of 32bit floats.
74+
- Doesn't currently implement support for null values, but could.
75+
- Uses numba instead of Cython
76+
"""
77+
assert axis in (0, 1)
78+
if mtx.format == "csr":
79+
ax_minor = 1
80+
shape = mtx.shape
81+
elif mtx.format == "csc":
82+
ax_minor = 0
83+
shape = mtx.shape[::-1]
84+
else:
85+
msg = "This function only works on sparse csr and csc matrices"
86+
raise TypeError(msg)
87+
if len(shape) == 1:
88+
msg = "array must have 2 dimensions"
89+
raise TypeError(msg)
90+
f = sparse_mean_var_major_axis if axis == ax_minor else sparse_mean_var_minor_axis
91+
return f(
92+
mtx.data,
93+
mtx.indptr,
94+
mtx.indices,
95+
major_len=shape[0],
96+
minor_len=shape[1],
97+
n_threads=numba.get_num_threads(),
98+
)
99+
100+
101+
@numba.njit
102+
def sparse_mean_var_minor_axis(
103+
data: NDArray[np.number[Any]],
104+
indptr: NDArray[np.integer[Any]],
105+
indices: NDArray[np.integer[Any]],
106+
*,
107+
major_len: int,
108+
minor_len: int,
109+
n_threads: int,
110+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
111+
"""Compute mean and variance along the minor axis of a compressed sparse matrix."""
112+
rows = len(indptr) - 1
113+
sums = np.zeros((n_threads, minor_len))
114+
squared_sums = np.zeros((n_threads, minor_len))
115+
means = np.zeros(minor_len)
116+
variances = np.zeros(minor_len)
117+
for i in numba.prange(n_threads):
118+
for r in range(i, rows, n_threads):
119+
for j in range(indptr[r], indptr[r + 1]):
120+
minor_index = indices[j]
121+
if minor_index >= minor_len:
122+
continue
123+
value = data[j]
124+
sums[i, minor_index] += value
125+
squared_sums[i, minor_index] += value * value
126+
for c in numba.prange(minor_len):
127+
sum = sums[:, c].sum()
128+
means[c] = sum / major_len
129+
variances[c] = squared_sums[:, c].sum() / major_len - (sum / major_len) ** 2
130+
return means, variances
131+
132+
133+
@numba.njit
134+
def sparse_mean_var_major_axis(
135+
data: NDArray[np.number[Any]],
136+
indptr: NDArray[np.integer[Any]],
137+
indices: NDArray[np.integer[Any]], # noqa: ARG001
138+
*,
139+
major_len: int,
140+
minor_len: int,
141+
n_threads: int,
142+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
143+
"""Compute means and variances along the major axis of a compressed sparse matrix."""
144+
rows = len(indptr) - 1
145+
means = np.zeros(major_len)
146+
variances = np.zeros_like(means)
147+
148+
for i in numba.prange(n_threads):
149+
for r in range(i, rows, n_threads):
150+
sum_major = np.float64(0.0)
151+
squared_sum_minor = np.float64(0.0)
152+
for j in range(indptr[r], indptr[r + 1]):
153+
value = np.float64(data[j])
154+
sum_major += value
155+
squared_sum_minor += value * value
156+
means[r] = sum_major
157+
variances[r] = squared_sum_minor
158+
for c in numba.prange(major_len):
159+
mean = means[c] / minor_len
160+
means[c] = mean
161+
variances[c] = variances[c] / minor_len - mean * mean
162+
return means, variances
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from functools import singledispatch
5+
from typing import TYPE_CHECKING, cast
6+
7+
from .. import types
8+
9+
10+
if TYPE_CHECKING:
11+
from typing import Any, TypeVar
12+
13+
from numpy.typing import NDArray
14+
15+
# All supported array types except for disk ones and OutOfCoreDataset
16+
Array = NDArray[Any] | types.CSBase | types.CupyArray | types.CupySparseMatrix | types.DaskArray
17+
18+
_Arr = TypeVar("_Arr", bound=Array)
19+
20+
21+
def power(x: _Arr, n: int, /) -> _Arr:
22+
"""Take array or matrix to a power."""
23+
# This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
24+
return _power(x, n) # type: ignore[return-value]
25+
26+
27+
@singledispatch
28+
def _power(x: Array, n: int, /) -> Array:
29+
if TYPE_CHECKING:
30+
assert not isinstance(x, types.DaskArray | types.CSMatrix)
31+
return x**n # type: ignore[operator]
32+
33+
34+
@_power.register(types.CSMatrix) # type: ignore[call-overload,misc]
35+
def _power_cs(x: types.CSMatrix, n: int, /) -> types.CSMatrix:
36+
return x.power(n)
37+
38+
39+
@_power.register(types.DaskArray)
40+
def _power_dask(x: types.DaskArray, n: int, /) -> types.DaskArray:
41+
import dask.array as da
42+
43+
return cast(types.DaskArray, da.map_blocks(power, x, n)) # type: ignore[no-untyped-call]

stubs/cupy.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# SPDX-License-Identifier: MPL-2.0
2-
from typing import Any, Literal
2+
from typing import Any, Literal, Self
33

44
import numpy as np
55
from numpy.typing import ArrayLike, DTypeLike, NDArray
@@ -8,6 +8,7 @@ class ndarray:
88
dtype: np.dtype[Any]
99
shape: tuple[int, ...]
1010
def get(self) -> NDArray[Any]: ...
11+
def __power__(self, other: int) -> Self: ...
1112

1213
def asarray(
1314
a: ArrayLike,

stubs/cupyx/scipy/sparse.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# SPDX-License-Identifier: MPL-2.0
2-
from typing import Any, Literal
2+
from typing import Any, Literal, Self
33

44
import cupy
55
import numpy as np
@@ -8,3 +8,4 @@ class spmatrix:
88
dtype: np.dtype[Any]
99
shape: tuple[int, int]
1010
def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ...
11+
def __power__(self, other: int) -> Self: ...

stubs/numba/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ def prange(stop: SupportsIndex, /) -> Iterable[int]: ...
3131
def prange(
3232
start: SupportsIndex, stop: SupportsIndex, step: SupportsIndex = ..., /
3333
) -> Iterable[int]: ...
34+
def get_num_threads() -> int: ...

tests/test_stats.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,41 @@ def test_sum(
9090

9191
@pytest.mark.parametrize(("axis", "expected"), [(None, 3.5), (0, [2.5, 3.5, 4.5]), (1, [2.0, 5.0])])
9292
def test_mean(
93-
array_type: ArrayType[Array], axis: Literal[0, 1, None], expected: list[float]
93+
array_type: ArrayType[Array], axis: Literal[0, 1, None], expected: float | list[float]
9494
) -> None:
95-
arr = array_type(np.array([[1, 2, 3], [4, 5, 6]]))
95+
np_arr = np.array([[1, 2, 3], [4, 5, 6]])
96+
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected)
97+
98+
arr = array_type(np_arr)
9699
result = stats.mean(arr, axis=axis)
97100
if isinstance(result, types.DaskArray):
98-
result = result.compute() # type: ignore[no-untyped-call]
101+
result = result.compute()
99102
np.testing.assert_array_equal(result, expected)
100103

101104

105+
@pytest.mark.array_type(skip=Flags.Disk)
106+
@pytest.mark.parametrize(
107+
("axis", "mean_expected", "var_expected"),
108+
[(None, 3.5, 3.5), (0, [2.5, 3.5, 4.5], [4.5, 4.5, 4.5]), (1, [2.0, 5.0], [1.0, 1.0])],
109+
)
110+
def test_mean_var(
111+
array_type: ArrayType[
112+
NDArray[Any] | types.CSBase | types.CupyArray | types.CupySparseMatrix | types.DaskArray
113+
],
114+
axis: Literal[0, 1, None],
115+
mean_expected: float | list[float],
116+
var_expected: float | list[float],
117+
) -> None:
118+
np_arr = np.array([[1, 2, 3], [4, 5, 6]])
119+
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
120+
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)
121+
122+
arr = array_type(np_arr)
123+
mean, var = stats.mean_var(arr, axis=axis, correction=1)
124+
np.testing.assert_array_equal(mean, mean_expected)
125+
np.testing.assert_array_almost_equal_nulp(var, var_expected, nulp=8)
126+
127+
102128
# TODO(flying-sheep): enable for GPU # noqa: TD003
103129
@pytest.mark.array_type(skip=Flags.Disk | Flags.Gpu)
104130
@pytest.mark.parametrize(
@@ -149,7 +175,7 @@ def test_dask_constant_blocks(
149175

150176
@pytest.mark.benchmark
151177
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
152-
@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.is_constant])
178+
@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant])
153179
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float
154180
def test_stats_benchmark(
155181
benchmark: BenchmarkFixture,

0 commit comments

Comments
 (0)