|
| 1 | +# SPDX-License-Identifier: MPL-2.0 |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +from typing import TYPE_CHECKING, no_type_check, overload |
| 5 | + |
| 6 | +import numba |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from .. import types |
| 10 | +from ._mean import mean |
| 11 | +from ._power import power |
| 12 | + |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from typing import Any, Literal |
| 16 | + |
| 17 | + from numpy.typing import NDArray |
| 18 | + |
| 19 | + MemArray = NDArray[Any] | types.CSBase | types.CupyArray | types.CupySparseMatrix |
| 20 | + |
| 21 | + |
| 22 | +__all__ = ["mean_var"] |
| 23 | + |
| 24 | + |
@overload
def mean_var(
    x: MemArray, /, *, axis: Literal[None] = None, correction: int = 0
) -> tuple[np.float64, np.float64]: ...
@overload
def mean_var(
    x: MemArray, /, *, axis: Literal[0, 1], correction: int = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ...
@overload
def mean_var(
    x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, correction: int = 0
) -> tuple[types.DaskArray, types.DaskArray]: ...


@no_type_check  # mypy is extremely confused
def mean_var(
    x: MemArray | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    correction: int = 0,
) -> (
    tuple[np.float64, np.float64]
    | tuple[NDArray[np.float64], NDArray[np.float64]]
    | tuple[types.DaskArray, types.DaskArray]
):
    """Compute mean and variance of ``x``, together, in a single pass.

    Parameters
    ----------
    x
        The array to reduce.
    axis
        ``None`` reduces over all elements (scalar results);
        ``0`` or ``1`` reduces along that axis (array results).
    correction
        Degrees-of-freedom correction; ``1`` gives the unbiased
        (sample) variance, ``0`` (the default) the population variance.

    Returns
    -------
    A ``(mean, variance)`` pair: scalars for ``axis=None``,
    arrays for an integer ``axis``, dask arrays for dask input.
    """
    if axis is not None and isinstance(x, types.CSBase):
        # Axis-wise reduction of CSR/CSC input has a dedicated numba path.
        mean_, var = _sparse_mean_var(x, axis=axis)
    else:
        # Single-pass formula: Var(x) = E[x²] − E[x]².
        # Accumulating in float64 keeps 32-bit inputs accurate.
        mean_ = mean(x, axis=axis, dtype=np.float64)
        mean_sq = mean(power(x, 2), axis=axis, dtype=np.float64)
        var = mean_sq - mean_**2
    if correction:  # R convention == 1 (unbiased estimator)
        # n is the number of elements reduced over; rescale the
        # population variance to the corrected one, unless n == 1
        # (which would divide by zero for correction == 1).
        n = np.prod(x.shape) if axis is None else x.shape[axis]
        if n != 1:
            var *= n / (n - correction)
    return mean_, var
| 62 | + |
| 63 | + |
def _sparse_mean_var(
    mtx: types.CSBase, /, *, axis: Literal[0, 1]
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Calculate means and variances for each row or column of a sparse matrix.

    This code and internal functions are based on sklearns `sparsefuncs.mean_variance_axis`.

    Modifications:
    - allow deciding on the output type,
      which can increase accuracy when calculating the mean and variance of 32bit floats.
    - Doesn't currently implement support for null values, but could.
    - Uses numba instead of Cython
    """
    assert axis in (0, 1)
    fmt = mtx.format
    if fmt == "csr":
        # CSR: rows are the compressed (major) axis, columns the minor one.
        ax_minor, shape = 1, mtx.shape
    elif fmt == "csc":
        # CSC: flip the shape so it is always (major, minor).
        ax_minor, shape = 0, mtx.shape[::-1]
    else:
        msg = "This function only works on sparse csr and csc matrices"
        raise TypeError(msg)
    if len(shape) == 1:
        msg = "array must have 2 dimensions"
        raise TypeError(msg)
    # Reducing along the minor axis yields one value per major index and
    # vice versa — pick the kernel accordingly.
    if axis == ax_minor:
        kernel = sparse_mean_var_major_axis
    else:
        kernel = sparse_mean_var_minor_axis
    return kernel(
        mtx.data,
        mtx.indptr,
        mtx.indices,
        major_len=shape[0],
        minor_len=shape[1],
        n_threads=numba.get_num_threads(),
    )
| 99 | + |
| 100 | + |
@numba.njit
def sparse_mean_var_minor_axis(
    data: NDArray[np.number[Any]],
    indptr: NDArray[np.integer[Any]],
    indices: NDArray[np.integer[Any]],
    *,
    major_len: int,
    minor_len: int,
    n_threads: int,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute mean and variance along the minor axis of a compressed sparse matrix.

    ``data``/``indptr``/``indices`` are the raw CSR/CSC buffers;
    returns ``(means, variances)`` arrays of length ``minor_len``.
    """
    # NOTE(review): prange only parallelizes under numba.njit(parallel=True),
    # which this decorator does not set — confirm whether that is intended.
    rows = len(indptr) - 1
    # One accumulator row per thread avoids write races between workers.
    sums = np.zeros((n_threads, minor_len))
    squared_sums = np.zeros((n_threads, minor_len))
    means = np.zeros(minor_len)
    variances = np.zeros(minor_len)
    for i in numba.prange(n_threads):
        # Rows are dealt out round-robin: thread i gets i, i+n_threads, …
        for r in range(i, rows, n_threads):
            for j in range(indptr[r], indptr[r + 1]):
                minor_index = indices[j]
                if minor_index >= minor_len:
                    continue  # skip entries whose minor index is out of range
                value = data[j]
                sums[i, minor_index] += value
                squared_sums[i, minor_index] += value * value
    for c in numba.prange(minor_len):
        # `col_sum` (not `sum`) to avoid shadowing the builtin; compute the
        # mean once and reuse it instead of re-dividing by major_len.
        col_sum = sums[:, c].sum()
        means[c] = col_sum / major_len
        # Var(x) = E[x²] − E[x]² (population variance).
        variances[c] = squared_sums[:, c].sum() / major_len - means[c] ** 2
    return means, variances
| 131 | + |
| 132 | + |
@numba.njit
def sparse_mean_var_major_axis(
    data: NDArray[np.number[Any]],
    indptr: NDArray[np.integer[Any]],
    indices: NDArray[np.integer[Any]],  # noqa: ARG001
    *,
    major_len: int,
    minor_len: int,
    n_threads: int,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute means and variances along the major axis of a compressed sparse matrix."""
    n_rows = len(indptr) - 1
    means = np.zeros(major_len)
    variances = np.zeros_like(means)

    for tid in numba.prange(n_threads):
        # Each thread takes every n_threads-th row, starting at its own id.
        for row in range(tid, n_rows, n_threads):
            acc = np.float64(0.0)
            acc_sq = np.float64(0.0)
            for k in range(indptr[row], indptr[row + 1]):
                v = np.float64(data[k])
                acc += v
                acc_sq += v * v
            # Stash raw sums first; they are normalized in the next pass.
            means[row] = acc
            variances[row] = acc_sq
    for row in numba.prange(major_len):
        # Turn the accumulated sums into mean and population variance.
        m = means[row] / minor_len
        means[row] = m
        variances[row] = variances[row] / minor_len - m * m
    return means, variances