Skip to content

Commit a33cf9a

Browse files
authored
Add sum (#13)
1 parent f241148 commit a33cf9a

File tree

12 files changed

+368
-108
lines changed

12 files changed

+368
-108
lines changed

docs/conf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,23 @@
5151
)
5252
# Try overriding type paths
5353
qualname_overrides = autodoc_type_aliases = {
54+
"np.dtype": "numpy.dtype",
5455
"ArrayLike": "numpy.typing.ArrayLike",
56+
"DTypeLike": "numpy.typing.DTypeLike",
57+
"NDArray": "numpy.typing.NDArray",
5558
"CSBase": "scipy.sparse.spmatrix",
5659
"CupyArray": "cupy.ndarray",
5760
"CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
5861
"DaskArray": "dask.array.Array",
5962
"H5Dataset": "h5py.Dataset",
60-
"NDArray": "numpy.typing.NDArray",
6163
}
6264
# If that doesn’t work, ignore them
6365
nitpick_ignore = {
6466
("py:class", "DT_co"),
6567
("py:class", "fast_array_utils.types.T_co"),
66-
# sphinx bugs, should be covered by `autodoc_type_aliases` below
68+
# sphinx bugs, should be covered by `autodoc_type_aliases` above
6769
("py:class", "ArrayLike"),
70+
("py:class", "DTypeLike"),
6871
("py:class", "NDArray"),
6972
}
7073

docs/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
:members:
1313

1414

15+
``fast_array_utils.stats``
16+
--------------------------
17+
18+
.. automodule:: fast_array_utils.stats
19+
:members:
20+
21+
1522
``fast_array_utils.types``
1623
--------------------------
1724

src/fast_array_utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from __future__ import annotations
55

6-
from . import _patches, conv, types
6+
from . import _patches, conv, stats, types
77

88

9-
__all__ = ["conv", "types"]
9+
__all__ = ["conv", "stats", "types"]
1010

1111
_patches.patch_dask()
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
"""Statistics utilities."""
3+
4+
from __future__ import annotations
5+
6+
from ._sum import sum
7+
8+
9+
__all__ = ["sum"]

src/fast_array_utils/stats/_sum.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from functools import partial, singledispatch
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
9+
from ..types import CSBase, CSMatrix, DaskArray
10+
11+
12+
if TYPE_CHECKING:
13+
from typing import Literal, TypeVar
14+
15+
from numpy.typing import ArrayLike, DTypeLike, NDArray
16+
17+
DT_co = TypeVar("DT_co", covariant=True, bound=np.generic)
18+
19+
20+
# TODO(flying-sheep): overload so axis=None returns np.floating # noqa: TD003
21+
22+
23+
@singledispatch
24+
def sum(
25+
x: ArrayLike,
26+
*,
27+
axis: Literal[0, 1, None] = None,
28+
dtype: DTypeLike | np.dtype[DT_co] | None = None,
29+
) -> NDArray[DT_co]:
30+
"""Sum over both or one axis."""
31+
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]
32+
33+
34+
@sum.register(CSBase) # type: ignore[misc,call-overload]
35+
def _(
36+
x: CSBase[DT_co], *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
37+
) -> NDArray[DT_co]:
38+
import scipy.sparse as sp
39+
40+
if isinstance(x, CSMatrix):
41+
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
42+
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[call-overload,no-any-return]
43+
44+
45+
@sum.register(DaskArray)
46+
def _(
47+
x: DaskArray, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
48+
) -> DaskArray:
49+
if TYPE_CHECKING:
50+
from dask.array.reductions import reduction
51+
else:
52+
from dask.array import reduction
53+
54+
if isinstance(x._meta, np.matrix): # noqa: SLF001
55+
msg = "sum does not support numpy matrices"
56+
raise TypeError(msg)
57+
58+
def sum_drop_keepdims(
59+
a: NDArray[DT_co] | CSBase[DT_co],
60+
*,
61+
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
62+
dtype: np.dtype[DT_co] | None = None,
63+
keepdims: bool = False,
64+
) -> NDArray[DT_co]:
65+
del keepdims
66+
match axis:
67+
case (0 | 1 as n,):
68+
axis = n
69+
case (0, 1) | (1, 0):
70+
axis = None
71+
case tuple():
72+
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
73+
raise ValueError(msg)
74+
rv: NDArray[DT_co] | DT_co = sum(a, axis=axis, dtype=dtype) # type: ignore[arg-type]
75+
rv = np.array(rv, ndmin=1) # make sure rv is at least 1D
76+
return rv.reshape((1, len(rv)))
77+
78+
if dtype is None:
79+
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
80+
dtype = np.zeros(1, dtype=x.dtype).sum().dtype
81+
82+
return reduction( # type: ignore[no-any-return,no-untyped-call]
83+
x,
84+
sum_drop_keepdims,
85+
partial(np.sum, dtype=dtype),
86+
axis=axis,
87+
dtype=dtype,
88+
meta=np.array([], dtype=dtype),
89+
)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
"""Testing utilities."""
3+
4+
from __future__ import annotations
5+
6+
import re
7+
from typing import TYPE_CHECKING
8+
9+
import numpy as np
10+
11+
12+
if TYPE_CHECKING:
13+
from typing import Any, Generic, Literal, Protocol, SupportsFloat, TypeAlias, TypeVar
14+
15+
from numpy.typing import ArrayLike, NDArray
16+
17+
from fast_array_utils import types
18+
from fast_array_utils.types import CSBase
19+
20+
_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
21+
_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic)
22+
_SCT_float = TypeVar("_SCT_float", np.float32, np.float64)
23+
24+
Array: TypeAlias = (
25+
NDArray[_SCT_co]
26+
| types.CSBase[_SCT_co]
27+
| types.CupyArray[_SCT_co]
28+
| types.DaskArray
29+
| types.H5Dataset
30+
| types.ZarrArray
31+
)
32+
33+
class ToArray(Protocol, Generic[_SCT_contra]):
34+
"""Convert to a supported array."""
35+
36+
def __call__( # noqa: D102
37+
self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None
38+
) -> Array[_SCT_contra]: ...
39+
40+
41+
RE_ARRAY_QUAL = re.compile(r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?")
42+
43+
44+
def get_array_cls(qualname: str) -> type[Array[Any]]: # noqa: PLR0911
45+
"""Get a supported array class by qualname."""
46+
m = RE_ARRAY_QUAL.fullmatch(qualname)
47+
assert m
48+
match m["mod"], m["name"], m["inner"]:
49+
case "numpy", "ndarray", None:
50+
return np.ndarray
51+
case "scipy.sparse", (
52+
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
53+
) as cls_name, None:
54+
import scipy.sparse
55+
56+
return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return]
57+
case "cupy", "ndarray", None:
58+
import cupy as cp
59+
60+
return cp.ndarray # type: ignore[no-any-return]
61+
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
62+
import cupyx.scipy.sparse as cu_sparse
63+
64+
return getattr(cu_sparse, cls_name) # type: ignore[no-any-return]
65+
case "dask.array", cls_name, _:
66+
if TYPE_CHECKING:
67+
from dask.array.core import Array as DaskArray
68+
else:
69+
from dask.array import Array as DaskArray
70+
71+
return DaskArray
72+
case "h5py", "Dataset", _:
73+
import h5py
74+
75+
return h5py.Dataset # type: ignore[no-any-return]
76+
case "zarr", "Array", _:
77+
import zarr
78+
79+
return zarr.Array
80+
case _:
81+
msg = f"Unknown array class: {qualname}"
82+
raise ValueError(msg)
83+
84+
85+
def random_mat(
86+
shape: tuple[int, int],
87+
*,
88+
density: SupportsFloat = 0.01,
89+
format: Literal["csr", "csc"] = "csr", # noqa: A002
90+
dtype: np.dtype[_SCT_float] | type[_SCT_float] | None = None,
91+
container: Literal["array", "matrix"] = "array",
92+
gen: np.random.Generator | None = None,
93+
) -> CSBase[_SCT_float]:
94+
"""Create a random matrix."""
95+
from scipy.sparse import random as random_spmat
96+
from scipy.sparse import random_array as random_sparr
97+
98+
m, n = shape
99+
return (
100+
random_spmat(m, n, density=density, format=format, dtype=dtype, random_state=gen)
101+
if container == "matrix"
102+
else random_sparr(shape, density=density, format=format, dtype=dtype, random_state=gen)
103+
)
104+
105+
106+
def random_array(
107+
qualname: str,
108+
shape: tuple[int, int],
109+
*,
110+
dtype: np.dtype[_SCT_float] | type[_SCT_float] | None,
111+
gen: np.random.Generator | None = None,
112+
) -> Array[_SCT_float]:
113+
"""Create a random array."""
114+
gen = np.random.default_rng(gen)
115+
116+
m = RE_ARRAY_QUAL.fullmatch(qualname)
117+
assert m
118+
match m["mod"], m["name"], m["inner"]:
119+
case "numpy", "ndarray", None:
120+
return gen.random(shape, dtype=dtype or np.float64)
121+
case "scipy.sparse", (
122+
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
123+
) as cls_name, None:
124+
fmt, container = cls_name.split("_")
125+
return random_mat(shape, format=fmt, container=container, dtype=dtype) # type: ignore[arg-type]
126+
case "cupy", "ndarray", None:
127+
raise NotImplementedError
128+
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
129+
raise NotImplementedError
130+
case "dask.array", cls_name, _:
131+
raise NotImplementedError
132+
case "h5py", "Dataset", _:
133+
raise NotImplementedError
134+
case "zarr", "Array", _:
135+
raise NotImplementedError
136+
case _:
137+
msg = f"Unknown array class: {qualname}"
138+
raise ValueError(msg)

src/testing/fast_array_utils/__init__.pyi

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)