Skip to content

Commit 9df6b5a

Browse files
authored
fix: Have sum operate on the passed dtype. (#125)
1 parent 49bdfac commit 9df6b5a

File tree

5 files changed

+55
-24
lines changed

5 files changed

+55
-24
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ repos:
2525
rev: v1.18.2
2626
hooks:
2727
- id: mypy
28-
args: [--config-file=pyproject.toml]
28+
args: [--config-file=pyproject.toml, .]
29+
pass_filenames: false
2930
additional_dependencies:
3031
- pytest
3132
- pytest-codspeed!=4.0.0 # https://github.com/CodSpeedHQ/pytest-codspeed/pull/84

src/fast_array_utils/stats/_sum.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,18 @@ def _sum_cs(
6363
del keep_cupy_as_array
6464
import scipy.sparse as sp
6565

66-
if isinstance(x, types.CSMatrix):
67-
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
66+
# TODO(flying-sheep): once scipy fixes this issue, instead of all this,
67+
# just convert to sparse array, then `return x.sum(dtype=dtype)`
68+
# https://github.com/scipy/scipy/issues/23768
6869

6970
if axis is None:
70-
return cast("np.number[Any]", x.data.sum(dtype=dtype))
71-
return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype))
71+
return cast("NDArray[Any] | np.number[Any]", x.data.sum(dtype=dtype))
72+
73+
if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
74+
assert isinstance(dtype, np.dtype | type | None)
75+
# convert to array so dimensions collapse as expected
76+
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype)
77+
return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis))
7278

7379

7480
@sum_.register(types.DaskArray)
@@ -92,7 +98,7 @@ def _sum_dask(
9298

9399
rv = da.reduction(
94100
x,
95-
sum_dask_inner, # type: ignore[arg-type]
101+
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
96102
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
97103
axis=axis,
98104
dtype=dtype,

tests/test_stats.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,42 @@ def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTyp
9090
return dtype
9191

9292

93-
@pytest.fixture(scope="session", params=[np.float32, np.float64, None])
93+
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int64, None])
9494
def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
9595
return cast("type[DTypeOut] | None", request.param)
9696

9797

9898
@pytest.fixture
9999
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
100100
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in))
101+
if np.dtype(dtype_in).kind == "f":
102+
np_arr /= 4 # type: ignore[misc]
101103
np_arr.flags.writeable = False
102104
if ndim == 1:
103105
np_arr = np_arr.flatten()
104106
return np_arr
105107

106108

109+
def to_np_dense_checked(
110+
stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray
111+
) -> NDArray[DTypeOut] | np.number[Any]:
112+
match axis, arr:
113+
case _, types.DaskArray():
114+
assert isinstance(stat, types.DaskArray), type(stat)
115+
stat = stat.compute() # type: ignore[assignment]
116+
return to_np_dense_checked(stat, axis, arr.compute())
117+
case None, _:
118+
assert isinstance(stat, np.floating | np.integer), type(stat)
119+
case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix() | types.CupyCOOMatrix():
120+
assert isinstance(stat, types.CupyArray), type(stat)
121+
return to_np_dense_checked(stat.get(), axis, arr.get())
122+
case 0 | 1, _:
123+
assert isinstance(stat, np.ndarray), type(stat)
124+
case _:
125+
pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(stat)}")
126+
return stat
127+
128+
107129
@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
108130
@pytest.mark.parametrize("func", STAT_FUNCS)
109131
@pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"])
@@ -127,26 +149,13 @@ def test_sum(
127149
axis: Literal[0, 1] | None,
128150
np_arr: NDArray[DTypeIn],
129151
) -> None:
152+
if np.dtype(dtype_arg).kind in "iu" and (array_type.flags & Flags.Gpu) and (array_type.flags & Flags.Sparse):
153+
pytest.skip("GPU sparse matrices don’t support int dtypes")
130154
arr = array_type(np_arr.copy())
131155
assert arr.dtype == dtype_in
132156

133157
sum_ = stats.sum(arr, axis=axis, dtype=dtype_arg)
134-
135-
match axis, arr:
136-
case _, types.DaskArray():
137-
assert isinstance(sum_, types.DaskArray), type(sum_)
138-
sum_ = sum_.compute() # type: ignore[assignment]
139-
if isinstance(sum_, types.CupyArray):
140-
sum_ = sum_.get()
141-
case None, _:
142-
assert isinstance(sum_, np.floating | np.integer), type(sum_)
143-
case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix():
144-
assert isinstance(sum_, types.CupyArray), type(sum_)
145-
sum_ = sum_.get()
146-
case 0 | 1, _:
147-
assert isinstance(sum_, np.ndarray), type(sum_)
148-
case _:
149-
pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(sum_)}")
158+
sum_ = to_np_dense_checked(sum_, axis, arr) # type: ignore[arg-type]
150159

151160
assert sum_.shape == () if axis is None else arr.shape[axis], (sum_.shape, arr.shape)
152161

@@ -161,6 +170,19 @@ def test_sum(
161170
np.testing.assert_array_equal(sum_, expected)
162171

163172

173+
@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Gpu})
174+
def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None) -> None:
175+
rng = np.random.default_rng(0)
176+
np_arr = rng.random((100, 100))
177+
arr = array_type(np_arr)
178+
179+
sum_ = stats.sum(arr, axis=axis, dtype=np.int64)
180+
sum_ = to_np_dense_checked(sum_, axis, arr)
181+
182+
expected = np.zeros(() if axis is None else arr.shape[axis], dtype=np.int64)
183+
np.testing.assert_array_equal(sum_, expected)
184+
185+
164186
@pytest.mark.parametrize(
165187
"data",
166188
[

typings/cupy/_core/core.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ from typing import Any, Literal, Self, overload
55
import numpy as np
66
from cupy.cuda import Stream
77
from numpy._core.multiarray import flagsobj
8-
from numpy.typing import NDArray
8+
from numpy.typing import DTypeLike, NDArray
99

1010
class ndarray:
1111
dtype: np.dtype[Any]
@@ -41,6 +41,7 @@ class ndarray:
4141
def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...
4242
@property
4343
def flat(self) -> _FlatIter: ...
44+
def sum(self, axis: int | None = None, dtype: DTypeLike | None = None, out: ndarray | None = None, keepdims: bool = False) -> ndarray: ...
4445

4546
class _FlatIter:
4647
def __next__(self) -> np.float32 | np.float64: ...

typings/cupyx/scipy/sparse/_compressed.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ class _compressed_sparse_matrix(spmatrix):
2020

2121
# methods
2222
def power(self, n: int, dtype: DTypeLike | None = None) -> Self: ...
23+
def sum(self, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, out: Self | None = None) -> ndarray: ...

0 commit comments

Comments
 (0)