22from __future__ import annotations
33
44from functools import partial , singledispatch
5- from typing import TYPE_CHECKING , cast
5+ from typing import TYPE_CHECKING , Literal , cast
66
77import numpy as np
8+ from numpy .exceptions import AxisError
89
910from .. import types
1011
1112
1213if TYPE_CHECKING :
13- from typing import Any , Literal
14+ from typing import Any , Literal , TypeAlias
1415
1516 from numpy .typing import DTypeLike , NDArray
1617
1718 from ..typing import CpuArray , DiskArray , GpuArray
1819
20+ ComplexAxis : TypeAlias = (
21+ tuple [Literal [0 ], Literal [1 ]] | tuple [Literal [0 , 1 ]] | Literal [0 , 1 , None ]
22+ )
23+
1924
2025@singledispatch
2126def sum_ (
@@ -24,7 +29,9 @@ def sum_(
2429 * ,
2530 axis : Literal [0 , 1 , None ] = None ,
2631 dtype : DTypeLike | None = None ,
32+ keep_cupy_as_array : bool = False ,
2733) -> NDArray [Any ] | np .number [Any ] | types .CupyArray | types .DaskArray :
34+ del keep_cupy_as_array
2835 if TYPE_CHECKING :
2936 # these are never passed to this fallback function, but `singledispatch` wants them
3037 assert not isinstance (
@@ -37,16 +44,31 @@ def sum_(
3744
@sum_.register(types.CupyArray | types.CupyCSMatrix)  # type: ignore[call-overload,misc]
def _sum_cupy(
    x: GpuArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
    """Sum a GPU array along ``axis``.

    Parameters
    ----------
    x
        The GPU array (dense or sparse) to reduce.
    axis
        Axis to sum over; ``None`` sums over all elements.
    dtype
        Accumulator/result dtype forwarded to :func:`numpy.sum`.
    keep_cupy_as_array
        If ``True``, a full reduction (``axis=None``) is returned as a
        squeezed array instead of being fetched as a host scalar.

    Returns
    -------
    A host scalar when ``axis is None`` and ``keep_cupy_as_array`` is false,
    otherwise the squeezed result array.
    """
    summed = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype))
    if keep_cupy_as_array or axis is not None:
        # Stay on the array side: only drop length-1 dimensions.
        return summed.squeeze()
    # Full reduction: `.get()` fetches the result, `[()]` unwraps the 0-d array.
    return cast("np.number[Any]", summed.get()[()])
4460
4561
4662@sum_ .register (types .CSBase ) # type: ignore[call-overload,misc]
4763def _sum_cs (
48- x : types .CSBase , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
64+ x : types .CSBase ,
65+ / ,
66+ * ,
67+ axis : Literal [0 , 1 , None ] = None ,
68+ dtype : DTypeLike | None = None ,
69+ keep_cupy_as_array : bool = False ,
4970) -> NDArray [Any ] | np .number [Any ]:
71+ del keep_cupy_as_array
5072 import scipy .sparse as sp
5173
5274 if isinstance (x , types .CSMatrix ):
@@ -59,49 +81,92 @@ def _sum_cs(
5981
@sum_.register(types.DaskArray)
def _sum_dask(
    x: types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.DaskArray:
    """Sum a Dask array along ``axis`` using a Dask reduction.

    Parameters
    ----------
    x
        The Dask array to reduce.
    axis
        Axis to sum over; ``None`` sums over all elements.
    dtype
        Result dtype. When ``None``, the dtype NumPy would produce when
        summing an array of ``x.dtype`` is used.
    keep_cupy_as_array
        If ``True`` and the reduced array's meta is a CuPy array, a full
        reduction is returned as-is instead of being converted to a scalar.

    Returns
    -------
    A lazy Dask array. For ``axis=None`` (unless kept as a CuPy array) its
    single block is converted to a scalar of the result dtype.

    Raises
    ------
    TypeError
        If the array's meta is a :class:`numpy.matrix`.
    """
    import dask.array as da

    if isinstance(x._meta, np.matrix):  # pragma: no cover  # noqa: SLF001
        msg = "sum does not support numpy matrices"
        raise TypeError(msg)

    if dtype is None:
        # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
        dtype = np.zeros(1, dtype=x.dtype).sum().dtype

    # The same inner function serves as per-chunk and aggregate step;
    # the aggregate step additionally has the result dtype bound.
    rv = da.reduction(
        x,
        sum_dask_inner,  # type: ignore[arg-type]
        partial(sum_dask_inner, dtype=dtype),  # pyright: ignore[reportArgumentType]
        axis=axis,
        dtype=dtype,
        meta=np.array([], dtype=dtype),
    )

    if axis is not None or (
        isinstance(rv._meta, types.CupyArray)  # noqa: SLF001
        and keep_cupy_as_array
    ):
        return rv

    def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]:
        """Convert a single-element block into a host scalar."""
        if isinstance(a, types.CupyArray):
            a = a.get()
        return a.reshape(())[()]  # type: ignore[return-value]

    # The blocks hold values of the *result* dtype, which can differ from
    # `x.dtype` (e.g. bool -> int64, or a user-supplied `dtype`), so the meta
    # has to be built from `dtype` to describe what `to_scalar` returns.
    return rv.map_blocks(to_scalar, meta=np.dtype(dtype).type(0))  # type: ignore[arg-type]
122+
123+
def sum_dask_inner(
    a: CpuArray | GpuArray,
    /,
    *,
    axis: ComplexAxis = None,
    dtype: DTypeLike | None = None,
    keepdims: bool = False,
) -> NDArray[Any] | types.CupyArray:
    """Sum one in-memory array and reshape the result to the expected shape.

    Parameters
    ----------
    a
        The array to sum.
    axis
        Axis specification; brought into the supported form by
        :func:`normalize_axis` before summing.
    dtype
        Result dtype forwarded to the package's ``sum``.
    keepdims
        Whether the returned array should retain the summed-over dimensions
        (see :func:`get_shape`).

    Returns
    -------
    The summed data, reshaped to the shape computed by :func:`get_shape`.
    """
    from . import sum as sum_any

    norm_axis = normalize_axis(axis, a.ndim)
    # GPU results are kept as arrays so they can be reshaped below.
    partial_sum = sum_any(a, axis=norm_axis, dtype=dtype, keep_cupy_as_array=True)  # type: ignore[misc,arg-type]
    target_shape = get_shape(partial_sum, axis=norm_axis, keepdims=keepdims)
    return cast("NDArray[Any] | types.CupyArray", partial_sum.reshape(target_shape))
138+
139+
140+ def normalize_axis (axis : ComplexAxis , ndim : int ) -> Literal [0 , 1 , None ]:
141+ """Adapt `axis` parameter passed by Dask to what we support."""
142+ match axis :
143+ case int () | None :
144+ pass
145+ case (0 | 1 ,):
146+ axis = axis [0 ]
147+ case (0 , 1 ) | (1 , 0 ):
148+ axis = None
149+ case _: # pragma: no cover
150+ raise AxisError (axis , ndim ) # type: ignore[call-overload]
151+ if axis == 0 and ndim == 1 :
152+ return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
153+ return axis
154+
155+
def get_shape(
    a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool
) -> tuple[int] | tuple[int, int]:
    """Get the output shape of an axis-flattening operation.

    Parameters
    ----------
    a
        The already reduced value: 0-dimensional or 1-dimensional.
    axis
        The axis that was reduced over. Only consulted (and required to be
        not ``None``) when ``a`` is 1-dimensional and ``keepdims`` is true.
    keepdims
        Whether the shape should have two dimensions instead of one.

    Returns
    -------
    A 1-tuple when ``keepdims`` is false, a 2-tuple when it is true.

    Raises
    ------
    AssertionError
        If the combination of ``keepdims`` and ``a.ndim`` is not handled.
    """
    n_dims = a.ndim
    # Identity checks on purpose: only real booleans select a branch.
    if keepdims is False:
        if n_dims == 0:
            return (1,)
        if n_dims == 1:
            return (a.size,)
    elif keepdims is True:
        if n_dims == 0:
            return (1, 1)
        if n_dims == 1:
            assert axis is not None
            if axis == 0:
                return (1, a.size)
            return (a.size, 1)
    # pragma: no cover
    msg = f"{keepdims=}, {type(a)}"
    raise AssertionError(msg)
0 commit comments