Skip to content

Commit d47de41

Browse files
authored
CuPy support (#51)
1 parent 7319f89 commit d47de41

File tree

27 files changed

+338
-108
lines changed

27 files changed

+338
-108
lines changed

.cirun.yml

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
runners:
2+
- name: aws-gpu-runner
3+
cloud: aws
4+
instance_type: g4dn.xlarge
5+
machine_image: ami-067a4ba2816407ee9
6+
region: eu-north-1
7+
preemptible:
8+
- true
9+
- false
10+
labels:
11+
- cirun-aws-gpu

.editorconfig

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,6 @@ max_line_length = 88
99
indent_size = 4
1010
indent_style = space
1111

12-
[*.toml]
12+
[*.{toml,yml,yaml}]
1313
indent_size = 2
1414
max_line_length = 120

.github/workflows/ci-gpu.yml

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,71 @@
1+
name: GPU-CI
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
types:
8+
- labeled
9+
- opened
10+
- synchronize
11+
12+
env:
13+
PYTEST_ADDOPTS: "-v --color=yes"
14+
FORCE_COLOR: "1"
15+
UV_HTTP_TIMEOUT: 120
16+
17+
concurrency:
18+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
19+
cancel-in-progress: true
20+
21+
jobs:
22+
check:
23+
name: Check Label
24+
runs-on: ubuntu-latest
25+
steps:
26+
- uses: flying-sheep/check@v1
27+
with:
28+
success: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'run-gpu-ci') }}
29+
test:
30+
name: All Tests
31+
needs: check
32+
runs-on: "cirun-aws-gpu--${{ github.run_id }}"
33+
timeout-minutes: 30
34+
defaults:
35+
run:
36+
shell: bash -el {0}
37+
steps:
38+
- uses: actions/checkout@v4
39+
with:
40+
fetch-depth: 0
41+
- name: Check NVIDIA SMI
42+
run: nvidia-smi
43+
- uses: actions/setup-python@v5
44+
with:
45+
python-version: "3.12"
46+
- uses: hynek/setup-cached-uv@v2
47+
with:
48+
cache-dependency-path: pyproject.toml
49+
- name: Install package
50+
run: uv pip install --system -e .[test,full] cupy-cuda12x --extra-index-url=https://pypi.nvidia.com --index-strategy=unsafe-best-match
51+
- name: List installed packages
52+
run: uv pip list
53+
- name: Run tests
54+
run: |
55+
coverage run -m pytest -m "not benchmark"
56+
coverage report
57+
# https://github.com/codecov/codecov-cli/issues/648
58+
coverage xml
59+
rm test-data/.coverage
60+
- uses: codecov/codecov-action@v5
61+
with:
62+
name: GPU Tests
63+
fail_ci_if_error: true
64+
files: test-data/coverage.xml
65+
token: ${{ secrets.CODECOV_TOKEN }}
66+
- name: Remove “run-gpu-ci” Label
67+
if: always()
68+
uses: actions-ecosystem/action-remove-labels@v1
69+
with:
70+
labels: run-gpu-ci
71+
github_token: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/ci.yml

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
22

3-
name: Python
3+
name: CI
44

55
on:
66
push:
@@ -13,11 +13,11 @@ env:
1313

1414
jobs:
1515
test:
16+
name: Min Tests
1617
runs-on: ubuntu-latest
1718
strategy:
1819
matrix:
1920
python-version: ["3.11", "3.13"]
20-
extras: [min, full]
2121
steps:
2222
- uses: actions/checkout@v4
2323
- uses: actions/setup-python@v5
@@ -27,7 +27,7 @@ jobs:
2727
with:
2828
enable-cache: true
2929
cache-dependency-glob: pyproject.toml
30-
- run: uv pip install --system -e .[test${{ matrix.extras == 'full' && ',full' || '' }}]
30+
- run: uv pip install --system -e .[test]
3131
- run: |
3232
coverage run -m pytest -m "not benchmark"
3333
coverage report
@@ -36,10 +36,12 @@ jobs:
3636
rm test-data/.coverage
3737
- uses: codecov/codecov-action@v5
3838
with:
39+
name: Min Tests
3940
fail_ci_if_error: true
4041
files: test-data/coverage.xml
4142
token: ${{ secrets.CODECOV_TOKEN }}
4243
bench:
44+
name: CPU Benchmarks
4345
runs-on: ubuntu-latest
4446
steps:
4547
- uses: actions/checkout@v4
@@ -56,6 +58,7 @@ jobs:
5658
run: pytest -m benchmark --codspeed
5759
token: ${{ secrets.CODSPEED_TOKEN }}
5860
check:
61+
name: Static Checks
5962
runs-on: ubuntu-latest
6063
strategy:
6164
matrix:

src/fast_array_utils/conv/_to_dense.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,9 @@ def _to_dense_cs(x: types.CSBase, /, *, to_memory: bool = False) -> NDArray[Any]
3737
def _to_dense_dask(
3838
x: types.DaskArray, /, *, to_memory: bool = False
3939
) -> NDArray[Any] | types.DaskArray:
40-
import dask.array as da
41-
4240
from . import to_dense
4341

44-
x = da.map_blocks(to_dense, x)
42+
x = x.map_blocks(lambda x: to_dense(x, to_memory=to_memory))
4543
return x.compute() if to_memory else x # type: ignore[return-value]
4644

4745

@@ -56,7 +54,7 @@ def _to_dense_ooc(x: types.CSDataset, /, *, to_memory: bool = False) -> NDArray[
5654
return to_dense(cast("types.CSBase", x.to_memory()))
5755

5856

59-
@to_dense_.register(GpuArray) # type: ignore[call-overload,misc]
57+
@to_dense_.register(types.CupyArray | types.CupyCSMatrix) # type: ignore[call-overload,misc]
6058
def _to_dense_cupy(x: GpuArray, /, *, to_memory: bool = False) -> NDArray[Any] | types.CupyArray:
61-
x = x.toarray() if isinstance(x, types.CupySparseMatrix) else x
59+
x = x.toarray() if isinstance(x, types.CupyCSMatrix) else x
6260
return x.get() if to_memory else x

src/fast_array_utils/stats/__init__.py

Lines changed: 37 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Literal
1818

1919
import numpy as np
20-
from numpy.typing import ArrayLike, DTypeLike, NDArray
20+
from numpy.typing import DTypeLike, NDArray
2121
from optype.numpy import ToDType
2222

2323
from .. import types
@@ -27,16 +27,23 @@
2727

2828

2929
@overload
30-
def is_constant(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray: ...
30+
def is_constant(
31+
a: NDArray[Any] | types.CSBase | types.CupyArray, /, *, axis: None = None
32+
) -> bool: ...
33+
@overload
34+
def is_constant(a: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> NDArray[np.bool]: ...
3135
@overload
32-
def is_constant(a: CpuArray, /, *, axis: None = None) -> bool: ...
36+
def is_constant(a: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ...
3337
@overload
34-
def is_constant(a: CpuArray, /, *, axis: Literal[0, 1]) -> NDArray[np.bool]: ...
38+
def is_constant(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray: ...
3539

3640

3741
def is_constant(
38-
a: CpuArray | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
39-
) -> bool | NDArray[np.bool] | types.DaskArray:
42+
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
43+
/,
44+
*,
45+
axis: Literal[0, 1, None] = None,
46+
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray:
4047
"""Check whether values in array are constant.
4148
4249
Params
@@ -82,9 +89,13 @@ def mean(
8289
) -> np.number[Any]: ...
8390
@overload
8491
def mean(
85-
x: CpuArray | GpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
92+
x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
8693
) -> NDArray[np.number[Any]]: ...
8794
@overload
95+
def mean(
96+
x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
97+
) -> types.CupyArray: ...
98+
@overload
8899
def mean(
89100
x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None
90101
) -> types.DaskArray: ...
@@ -96,7 +107,7 @@ def mean(
96107
*,
97108
axis: Literal[0, 1, None] = None,
98109
dtype: DTypeLike | None = None,
99-
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
110+
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray:
100111
"""Mean over both or one axis.
101112
102113
Returns
@@ -115,11 +126,15 @@ def mean(
115126
@overload
116127
def mean_var(
117128
x: CpuArray | GpuArray, /, *, axis: Literal[None] = None, correction: int = 0
129+
) -> tuple[np.float64, np.float64]: ...
130+
@overload
131+
def mean_var(
132+
x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0
118133
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ...
119134
@overload
120135
def mean_var(
121-
x: CpuArray | GpuArray, /, *, axis: Literal[0, 1], correction: int = 0
122-
) -> tuple[np.float64, np.float64]: ...
136+
x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0
137+
) -> tuple[types.CupyArray, types.CupyArray]: ...
123138
@overload
124139
def mean_var(
125140
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, correction: int = 0
@@ -133,8 +148,9 @@ def mean_var(
133148
axis: Literal[0, 1, None] = None,
134149
correction: int = 0,
135150
) -> (
136-
tuple[NDArray[np.float64], NDArray[np.float64]]
137-
| tuple[np.float64, np.float64]
151+
tuple[np.float64, np.float64]
152+
| tuple[NDArray[np.float64], NDArray[np.float64]]
153+
| tuple[types.CupyArray, types.CupyArray]
138154
| tuple[types.DaskArray, types.DaskArray]
139155
):
140156
"""Mean and variance over both or one axis.
@@ -169,33 +185,29 @@ def mean_var(
169185
# https://github.com/scverse/fast-array-utils/issues/52
170186
@overload
171187
def sum(
172-
x: ArrayLike | CpuArray | GpuArray | DiskArray,
173-
/,
174-
*,
175-
axis: None = None,
176-
dtype: DTypeLike | None = None,
188+
x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None
177189
) -> np.number[Any]: ...
178190
@overload
179191
def sum(
180-
x: ArrayLike | CpuArray | GpuArray | DiskArray,
181-
/,
182-
*,
183-
axis: Literal[0, 1],
184-
dtype: DTypeLike | None = None,
192+
x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
185193
) -> NDArray[Any]: ...
186194
@overload
195+
def sum(
196+
x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
197+
) -> types.CupyArray: ...
198+
@overload
187199
def sum(
188200
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
189201
) -> types.DaskArray: ...
190202

191203

192204
def sum(
193-
x: ArrayLike | CpuArray | GpuArray | DiskArray | types.DaskArray,
205+
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
194206
/,
195207
*,
196208
axis: Literal[0, 1, None] = None,
197209
dtype: DTypeLike | None = None,
198-
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
210+
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
199211
"""Sum over both or one axis.
200212
201213
Returns
@@ -209,4 +221,4 @@ def sum(
209221
210222
"""
211223
validate_axis(axis)
212-
return sum_(x, axis=axis, dtype=dtype) # type: ignore[arg-type] # literally the same type, wtf mypy
224+
return sum_(x, axis=axis, dtype=dtype)

src/fast_array_utils/stats/_is_constant.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -20,15 +20,18 @@
2020

2121
@singledispatch
2222
def is_constant_(
23-
a: NDArray[Any] | types.CSBase | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
24-
) -> bool | NDArray[np.bool] | types.DaskArray: # pragma: no cover
23+
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
24+
/,
25+
*,
26+
axis: Literal[0, 1, None] = None,
27+
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
2528
raise NotImplementedError
2629

2730

28-
@is_constant_.register(np.ndarray)
31+
@is_constant_.register(np.ndarray | types.CupyArray) # type: ignore[call-overload,misc]
2932
def _is_constant_ndarray(
30-
a: NDArray[Any], /, *, axis: Literal[0, 1, None] = None
31-
) -> bool | NDArray[np.bool]:
33+
a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1, None] = None
34+
) -> bool | NDArray[np.bool] | types.CupyArray:
3235
# Should eventually support nd, not now.
3336
match axis:
3437
case None:
@@ -39,7 +42,7 @@ def _is_constant_ndarray(
3942
return _is_constant_rows(a)
4043

4144

42-
def _is_constant_rows(a: NDArray[Any]) -> NDArray[np.bool]:
45+
def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray:
4346
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
4447
return cast(NDArray[np.bool], (a == b).all(axis=1))
4548

src/fast_array_utils/stats/_mean_var.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ def mean_var_(
2727
correction: int = 0,
2828
) -> (
2929
tuple[NDArray[np.float64], NDArray[np.float64]]
30+
| tuple[types.CupyArray, types.CupyArray]
3031
| tuple[np.float64, np.float64]
3132
| tuple[types.DaskArray, types.DaskArray]
3233
):

src/fast_array_utils/stats/_power.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,10 @@ def _power(x: Array, n: int, /) -> Array:
3131
return x**n # type: ignore[operator]
3232

3333

34-
@_power.register(types.CSMatrix) # type: ignore[call-overload,misc]
35-
def _power_cs(x: types.CSMatrix, n: int, /) -> types.CSMatrix:
34+
@_power.register(types.CSMatrix | types.CupyCSMatrix) # type: ignore[call-overload,misc]
35+
def _power_cs(
36+
x: types.CSMatrix | types.CupyCSMatrix, n: int, /
37+
) -> types.CSMatrix | types.CupyCSMatrix:
3638
return x.power(n)
3739

3840

0 commit comments

Comments
 (0)