Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,23 @@
)
# Try overriding type paths
qualname_overrides = autodoc_type_aliases = {
"np.dtype": "numpy.dtype",
"ArrayLike": "numpy.typing.ArrayLike",
"DTypeLike": "numpy.typing.DTypeLike",
"NDArray": "numpy.typing.NDArray",
"CSBase": "scipy.sparse.spmatrix",
"CupyArray": "cupy.ndarray",
"CupySparseMatrix": "cupyx.scipy.sparse.spmatrix",
"DaskArray": "dask.array.Array",
"H5Dataset": "h5py.Dataset",
"NDArray": "numpy.typing.NDArray",
}
# If that doesn’t work, ignore them
nitpick_ignore = {
("py:class", "DT_co"),
("py:class", "fast_array_utils.types.T_co"),
# sphinx bugs, should be covered by `autodoc_type_aliases` below
# sphinx bugs, should be covered by `autodoc_type_aliases` above
("py:class", "ArrayLike"),
("py:class", "DTypeLike"),
("py:class", "NDArray"),
}

Expand Down
7 changes: 7 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
:members:


``fast_array_utils.stats``
--------------------------

.. automodule:: fast_array_utils.stats
:members:


``fast_array_utils.types``
--------------------------

Expand Down
4 changes: 2 additions & 2 deletions src/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from __future__ import annotations

from . import _patches, conv, types
from . import _patches, conv, stats, types


__all__ = ["conv", "types"]
__all__ = ["conv", "stats", "types"]

_patches.patch_dask()
9 changes: 9 additions & 0 deletions src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: MPL-2.0
"""Statistics utilities."""

from __future__ import annotations

from ._sum import sum


__all__ = ["sum"]
89 changes: 89 additions & 0 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING

import numpy as np

from ..types import CSBase, CSMatrix, DaskArray


if TYPE_CHECKING:
from typing import Literal, TypeVar

Check warning on line 13 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L13

Added line #L13 was not covered by tests

from numpy.typing import ArrayLike, DTypeLike, NDArray

Check warning on line 15 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L15

Added line #L15 was not covered by tests

DT_co = TypeVar("DT_co", covariant=True, bound=np.generic)

Check warning on line 17 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L17

Added line #L17 was not covered by tests


# TODO(flying-sheep): overload so axis=None returns np.floating # noqa: TD003


@singledispatch
def sum(
x: ArrayLike,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | np.dtype[DT_co] | None = None,
) -> NDArray[DT_co]:
"""Sum over both or one axis."""
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[no-any-return]


@sum.register(CSBase) # type: ignore[misc,call-overload]
def _(
x: CSBase[DT_co], *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[DT_co]:
import scipy.sparse as sp

if isinstance(x, CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
return np.sum(x, axis=axis, dtype=dtype) # type: ignore[call-overload,no-any-return]


@sum.register(DaskArray)
def _(
x: DaskArray, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> DaskArray:
if TYPE_CHECKING:
from dask.array.reductions import reduction

Check warning on line 50 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L50

Added line #L50 was not covered by tests
else:
from dask.array import reduction

if isinstance(x._meta, np.matrix): # noqa: SLF001
msg = "sum does not support numpy matrices"
raise TypeError(msg)

Check warning on line 56 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L55-L56

Added lines #L55 - L56 were not covered by tests

def sum_drop_keepdims(
a: NDArray[DT_co] | CSBase[DT_co],
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
dtype: np.dtype[DT_co] | None = None,
keepdims: bool = False,
) -> NDArray[DT_co]:
del keepdims
match axis:
case (0 | 1 as n,):
axis = n
case (0, 1) | (1, 0):
axis = None
case tuple():
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)

Check warning on line 73 in src/fast_array_utils/stats/_sum.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/stats/_sum.py#L71-L73

Added lines #L71 - L73 were not covered by tests
rv: NDArray[DT_co] | DT_co = sum(a, axis=axis, dtype=dtype) # type: ignore[arg-type]
rv = np.array(rv, ndmin=1) # make sure rv is at least 1D
return rv.reshape((1, len(rv)))

if dtype is None:
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
dtype = np.zeros(1, dtype=x.dtype).sum().dtype

return reduction( # type: ignore[no-any-return,no-untyped-call]
x,
sum_drop_keepdims,
partial(np.sum, dtype=dtype),
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
)
138 changes: 138 additions & 0 deletions src/testing/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: MPL-2.0
"""Testing utilities."""

from __future__ import annotations

import re
from typing import TYPE_CHECKING

import numpy as np


if TYPE_CHECKING:
from typing import Any, Generic, Literal, Protocol, SupportsFloat, TypeAlias, TypeVar

from numpy.typing import ArrayLike, NDArray

from fast_array_utils import types
from fast_array_utils.types import CSBase

_SCT_co = TypeVar("_SCT_co", covariant=True, bound=np.generic)
_SCT_contra = TypeVar("_SCT_contra", contravariant=True, bound=np.generic)
_SCT_float = TypeVar("_SCT_float", np.float32, np.float64)

Array: TypeAlias = (
NDArray[_SCT_co]
| types.CSBase[_SCT_co]
| types.CupyArray[_SCT_co]
| types.DaskArray
| types.H5Dataset
| types.ZarrArray
)

class ToArray(Protocol, Generic[_SCT_contra]):
"""Convert to a supported array."""

def __call__( # noqa: D102
self, data: ArrayLike, /, *, dtype: _SCT_contra | None = None
) -> Array[_SCT_contra]: ...


RE_ARRAY_QUAL = re.compile(r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?")


def get_array_cls(qualname: str) -> type[Array[Any]]: # noqa: PLR0911
"""Get a supported array class by qualname."""
m = RE_ARRAY_QUAL.fullmatch(qualname)
assert m
match m["mod"], m["name"], m["inner"]:
case "numpy", "ndarray", None:
return np.ndarray
case "scipy.sparse", (
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
) as cls_name, None:
import scipy.sparse

return getattr(scipy.sparse, cls_name) # type: ignore[no-any-return]
case "cupy", "ndarray", None:
import cupy as cp

return cp.ndarray # type: ignore[no-any-return]
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
import cupyx.scipy.sparse as cu_sparse

return getattr(cu_sparse, cls_name) # type: ignore[no-any-return]
case "dask.array", cls_name, _:
if TYPE_CHECKING:
from dask.array.core import Array as DaskArray
else:
from dask.array import Array as DaskArray

return DaskArray
case "h5py", "Dataset", _:
import h5py

return h5py.Dataset # type: ignore[no-any-return]
case "zarr", "Array", _:
import zarr

return zarr.Array
case _:
msg = f"Unknown array class: {qualname}"
raise ValueError(msg)


def random_mat(
shape: tuple[int, int],
*,
density: SupportsFloat = 0.01,
format: Literal["csr", "csc"] = "csr", # noqa: A002
dtype: np.dtype[_SCT_float] | type[_SCT_float] | None = None,
container: Literal["array", "matrix"] = "array",
gen: np.random.Generator | None = None,
) -> CSBase[_SCT_float]:
"""Create a random matrix."""
from scipy.sparse import random as random_spmat
from scipy.sparse import random_array as random_sparr

m, n = shape
return (
random_spmat(m, n, density=density, format=format, dtype=dtype, random_state=gen)
if container == "matrix"
else random_sparr(shape, density=density, format=format, dtype=dtype, random_state=gen)
)


def random_array(
qualname: str,
shape: tuple[int, int],
*,
dtype: np.dtype[_SCT_float] | type[_SCT_float] | None,
gen: np.random.Generator | None = None,
) -> Array[_SCT_float]:
"""Create a random array."""
gen = np.random.default_rng(gen)

m = RE_ARRAY_QUAL.fullmatch(qualname)
assert m
match m["mod"], m["name"], m["inner"]:
case "numpy", "ndarray", None:
return gen.random(shape, dtype=dtype or np.float64)
case "scipy.sparse", (
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
) as cls_name, None:
fmt, container = cls_name.split("_")
return random_mat(shape, format=fmt, container=container, dtype=dtype) # type: ignore[arg-type]
case "cupy", "ndarray", None:
raise NotImplementedError
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
raise NotImplementedError
case "dask.array", cls_name, _:
raise NotImplementedError
case "h5py", "Dataset", _:
raise NotImplementedError
case "zarr", "Array", _:
raise NotImplementedError
case _:
msg = f"Unknown array class: {qualname}"
raise ValueError(msg)
26 changes: 0 additions & 26 deletions src/testing/fast_array_utils/__init__.pyi

This file was deleted.

Loading