Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ lint.ignore = [
]
lint.per-file-ignores."docs/**/*.py" = [ "INP001" ] # No __init__.py in docs
lint.per-file-ignores."src/**/stats/*.py" = [ "A001", "A004" ] # Shadows builtins like `sum`
lint.per-file-ignores."src/fast_array_utils/types.py" = [ "N806" ] # We have variables that are classes here
lint.per-file-ignores."stubs/**/*.pyi" = [ "F403", "F405", "N801" ] # Stubs don’t follow name conventions
lint.per-file-ignores."tests/**/test_*.py" = [
"D100", # tests need no module docstrings
Expand Down
4 changes: 1 addition & 3 deletions src/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

from __future__ import annotations

from . import _patches, conv, stats, types
from . import conv, stats, types


__all__ = ["conv", "stats", "types"]

_patches.patch_dask()
94 changes: 94 additions & 0 deletions src/fast_array_utils/_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from dataclasses import dataclass, field
from functools import cache
from types import UnionType
from typing import TYPE_CHECKING, Generic, ParamSpec, TypeVar, cast, overload


if TYPE_CHECKING:
from collections.abc import Callable

P = ParamSpec("P")
R = TypeVar("R")


__all__ = ["import_by_qualname", "lazy_singledispatch"]


def import_by_qualname(qualname: str) -> object:
from importlib import import_module

mod_path, obj_path = qualname.split(":")

mod = import_module(mod_path)

if mod_path == "dask" or mod_path.startswith("dask."):
from ._patches import patch_dask

patch_dask()

# get object
obj = mod
for name in obj_path.split("."):
try:
obj = getattr(obj, name)
except AttributeError as e:
msg = f"Could not import {'.'.join(obj_path)} from {'.'.join(mod_path)} "
raise ImportError(msg) from e

Check warning on line 39 in src/fast_array_utils/_import.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/_import.py#L37-L39

Added lines #L37 - L39 were not covered by tests
return obj


@dataclass
class lazy_singledispatch(Generic[P, R]): # noqa: N801
fallback: Callable[P, R]

_lazy: dict[tuple[str, str], Callable[..., R]] = field(init=False, default_factory=dict)
_eager: dict[type | UnionType, Callable[..., R]] = field(init=False, default_factory=dict)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
fn = self.dispatch(type(args[0])) # type: ignore[arg-type] # https://github.com/python/mypy/issues/11470
return fn(*args, **kwargs)

def __hash__(self) -> int:
return hash(self.fallback)

@cache # noqa: B019
def dispatch(self, typ: type) -> Callable[P, R]:
for cls_reg, fn in self._eager.items():
if issubclass(typ, cls_reg):
return fn

Check warning on line 61 in src/fast_array_utils/_import.py

View check run for this annotation

Codecov / codecov/patch

src/fast_array_utils/_import.py#L61

Added line #L61 was not covered by tests
for (import_qualname, host_mod_name), fn in self._lazy.items():
for cls in typ.mro():
if cls.__module__.startswith(host_mod_name): # can be deeper
cls_reg = cast(type, import_by_qualname(import_qualname))
if issubclass(typ, cls_reg):
return fn
return self.fallback

@overload
def register(
self, qualname_or_type: str, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...
@overload
def register(
self, qualname_or_type: type | UnionType, /, host_mod_name: None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]: ...

def register(
self, qualname_or_type: str | type | UnionType, /, host_mod_name: str | None = None
) -> Callable[[Callable[..., R]], lazy_singledispatch[P, R]]:
def decorator(fn: Callable[..., R]) -> lazy_singledispatch[P, R]:
match qualname_or_type, host_mod_name:
case str(), _:
hmn = qualname_or_type.split(":")[0] if host_mod_name is None else host_mod_name
self._lazy[(qualname_or_type, hmn)] = fn
case type() | UnionType(), None:
self._eager[qualname_or_type] = fn
case _:
msg = f"name_or_type {qualname_or_type!r} must be a str, type, or UnionType"
raise TypeError(msg)
return self

return decorator
18 changes: 10 additions & 8 deletions src/fast_array_utils/conv/_asarray.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING, Any, cast

import numpy as np
from numpy.typing import NDArray

from .. import types
from .._import import lazy_singledispatch
from ..types import OutOfCoreDataset


if TYPE_CHECKING:
from numpy.typing import ArrayLike

from .. import types


__all__ = ["asarray"]


# fallback’s arg0 type has to include types of registered functions
@singledispatch
@lazy_singledispatch
def asarray(
x: ArrayLike
| types.CSBase
Expand All @@ -44,28 +46,28 @@ def asarray(
return np.asarray(x)


@asarray.register(types.CSBase)
@asarray.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(x: types.CSBase) -> NDArray[Any]:
from .scipy import to_dense

return to_dense(x)


@asarray.register(types.DaskArray)
@asarray.register("dask.array:Array")
def _(x: types.DaskArray) -> NDArray[Any]:
return asarray(x.compute()) # type: ignore[no-untyped-call]


@asarray.register(types.OutOfCoreDataset)
@asarray.register(OutOfCoreDataset)
def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
return asarray(x.to_memory())


@asarray.register(types.CupyArray)
@asarray.register("cupy:ndarray")
def _(x: types.CupyArray) -> NDArray[Any]:
return cast(NDArray[Any], x.get())


@asarray.register(types.CupySparseMatrix)
@asarray.register("cupyx.scipy.sparse:spmatrix")
def _(x: types.CupySparseMatrix) -> NDArray[Any]:
return cast(NDArray[Any], x.toarray().get())
14 changes: 8 additions & 6 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import partial, singledispatch
from functools import partial
from typing import TYPE_CHECKING, Any, cast, overload

import numpy as np
from numpy.typing import NDArray

from .. import types
from .._import import lazy_singledispatch
from .._validation import validate_axis


Expand Down Expand Up @@ -54,30 +55,31 @@ def sum(
return _sum(x, axis=axis, dtype=dtype)


@singledispatch
@lazy_singledispatch
def _sum(
x: ArrayLike | types.CSBase | types.DaskArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
assert not isinstance(x, types.CSBase | types.DaskArray)
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))


@_sum.register(types.CSBase)
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> NDArray[Any] | np.number[Any]:
import scipy.sparse as sp

if isinstance(x, types.CSMatrix):
from ..types import CSMatrix

if isinstance(x, CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
return cast(NDArray[Any] | np.number[Any], np.sum(x, axis=axis, dtype=dtype))


@_sum.register(types.DaskArray)
@_sum.register("dask.array:Array")
def _(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray:
Expand Down
100 changes: 69 additions & 31 deletions src/fast_array_utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable
from functools import cache
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast, runtime_checkable

from ._import import import_by_qualname


if TYPE_CHECKING:
from collections.abc import Callable
from types import UnionType


__all__ = [
Expand All @@ -20,61 +27,92 @@
T_co = TypeVar("T_co", covariant=True)


# scipy sparse
# registry for lazy exports:


_REGISTRY: dict[str, str | Callable[[], UnionType]] = {}


def _register(name: str) -> Callable[[Callable[[], UnionType]], Callable[[], UnionType]]:
def _decorator(fn: Callable[[], UnionType]) -> Callable[[], UnionType]:
_REGISTRY[name] = fn
return fn

return _decorator


@cache
def __getattr__(name: str) -> type | UnionType:
if (source := _REGISTRY.get(name)) is None:
# A name we don’t know about
raise AttributeError(name) from None

try:
if callable(source):
return source()

return cast(type, import_by_qualname(source))
except ImportError: # A name we can’t import
return type(name, (), {})


# lazy exports:


if TYPE_CHECKING:
from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix

CSArray = csr_array | csc_array
CSMatrix = csr_matrix | csc_matrix
CSBase = CSMatrix | CSArray
else:
try: # cs?_array isn’t available in older scipy versions
from scipy.sparse import csc_array, csr_array
# cs?_array isn’t available in older scipy versions,
# so we import them separately

CSArray = csr_array | csc_array
except ImportError: # pragma: no cover
CSArray = type("CSArray", (), {})

try: # cs?_matrix is available when scipy is installed
@_register("CSMatrix")
def _get_cs_matrix() -> UnionType:
from scipy.sparse import csc_matrix, csr_matrix

CSMatrix = csr_matrix | csc_matrix
except ImportError: # pragma: no cover
CSMatrix = type("CSMatrix", (), {})
return csr_matrix | csc_matrix

CSBase = CSMatrix | CSArray
@_register("CSArray")
def _get_cs_array() -> UnionType:
from scipy.sparse import csc_array, csr_array

return csr_array | csc_array

if TYPE_CHECKING or find_spec("cupy"):
from cupy import ndarray as CupyArray
else: # pragma: no cover
CupyArray = type("ndarray", (), {})
@_register("CSBase")
def _get_cs_base() -> UnionType:
return __getattr__("CSMatrix") | __getattr__("CSArray")


if TYPE_CHECKING or find_spec("cupyx"):
if TYPE_CHECKING:
from cupy import ndarray as CupyArray
from cupyx.scipy.sparse import spmatrix as CupySparseMatrix
else: # pragma: no cover
CupySparseMatrix = type("spmatrix", (), {})
else:
_REGISTRY["CupyArray"] = "cupy:ndarray"
_REGISTRY["CupySparseMatrix"] = "cupyx.scipy.sparse:spmatrix"


if TYPE_CHECKING: # https://github.com/dask/dask/issues/8853
from dask.array.core import Array as DaskArray
elif find_spec("dask"):
from dask.array import Array as DaskArray
else: # pragma: no cover
DaskArray = type("array", (), {})
else:
_REGISTRY["DaskArray"] = "dask.array:Array"


if TYPE_CHECKING or find_spec("h5py"):
if TYPE_CHECKING:
from h5py import Dataset as H5Dataset
else: # pragma: no cover
H5Dataset = type("Dataset", (), {})
else:
_REGISTRY["H5Dataset"] = "h5py:Dataset"


if TYPE_CHECKING or find_spec("zarr"):
if TYPE_CHECKING:
from zarr import Array as ZarrArray
else: # pragma: no cover
ZarrArray = type("Array", (), {})
else:
_REGISTRY["ZarrArray"] = "zarr:Array"


# protocols:


@runtime_checkable
Expand Down
Loading