1717 from typing import Any , Literal
1818
1919 import numpy as np
20- from numpy .typing import ArrayLike , DTypeLike , NDArray
20+ from numpy .typing import DTypeLike , NDArray
2121 from optype .numpy import ToDType
2222
2323 from .. import types
2727
2828
2929@overload
30- def is_constant (a : types .DaskArray , / , * , axis : Literal [0 , 1 , None ] = None ) -> types .DaskArray : ...
30+ def is_constant (
31+ a : NDArray [Any ] | types .CSBase | types .CupyArray , / , * , axis : None = None
32+ ) -> bool : ...
33+ @overload
34+ def is_constant (a : NDArray [Any ] | types .CSBase , / , * , axis : Literal [0 , 1 ]) -> NDArray [np .bool ]: ...
3135@overload
32- def is_constant (a : CpuArray , / , * , axis : None = None ) -> bool : ...
36+ def is_constant (a : types . CupyArray , / , * , axis : Literal [ 0 , 1 ] ) -> types . CupyArray : ...
3337@overload
34- def is_constant (a : CpuArray , / , * , axis : Literal [0 , 1 ] ) -> NDArray [ np . bool ] : ...
38+ def is_constant (a : types . DaskArray , / , * , axis : Literal [0 , 1 , None ] = None ) -> types . DaskArray : ...
3539
3640
3741def is_constant (
38- a : CpuArray | types .DaskArray , / , * , axis : Literal [0 , 1 , None ] = None
39- ) -> bool | NDArray [np .bool ] | types .DaskArray :
42+ a : NDArray [Any ] | types .CSBase | types .CupyArray | types .DaskArray ,
43+ / ,
44+ * ,
45+ axis : Literal [0 , 1 , None ] = None ,
46+ ) -> bool | NDArray [np .bool ] | types .CupyArray | types .DaskArray :
4047 """Check whether values in array are constant.
4148
4249 Params
@@ -82,9 +89,13 @@ def mean(
8289) -> np .number [Any ]: ...
8390@overload
8491def mean (
85- x : CpuArray | GpuArray | DiskArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
92+ x : CpuArray | DiskArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
8693) -> NDArray [np .number [Any ]]: ...
8794@overload
95+ def mean (
96+ x : GpuArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
97+ ) -> types .CupyArray : ...
98+ @overload
8899def mean (
89100 x : types .DaskArray , / , * , axis : Literal [0 , 1 ], dtype : ToDType [Any ] | None = None
90101) -> types .DaskArray : ...
@@ -96,7 +107,7 @@ def mean(
96107 * ,
97108 axis : Literal [0 , 1 , None ] = None ,
98109 dtype : DTypeLike | None = None ,
99- ) -> NDArray [np .number [Any ]] | np .number [Any ] | types .DaskArray :
110+ ) -> NDArray [np .number [Any ]] | types . CupyArray | np .number [Any ] | types .DaskArray :
100111 """Mean over both or one axis.
101112
102113 Returns
@@ -115,11 +126,15 @@ def mean(
115126@overload
116127def mean_var (
117128 x : CpuArray | GpuArray , / , * , axis : Literal [None ] = None , correction : int = 0
129+ ) -> tuple [np .float64 , np .float64 ]: ...
130+ @overload
131+ def mean_var (
132+ x : CpuArray , / , * , axis : Literal [0 , 1 ], correction : int = 0
118133) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]: ...
119134@overload
120135def mean_var (
121- x : CpuArray | GpuArray , / , * , axis : Literal [0 , 1 ], correction : int = 0
122- ) -> tuple [np . float64 , np . float64 ]: ...
136+ x : GpuArray , / , * , axis : Literal [0 , 1 ], correction : int = 0
137+ ) -> tuple [types . CupyArray , types . CupyArray ]: ...
123138@overload
124139def mean_var (
125140 x : types .DaskArray , / , * , axis : Literal [0 , 1 , None ] = None , correction : int = 0
@@ -133,8 +148,9 @@ def mean_var(
133148 axis : Literal [0 , 1 , None ] = None ,
134149 correction : int = 0 ,
135150) -> (
136- tuple [NDArray [np .float64 ], NDArray [np .float64 ]]
137- | tuple [np .float64 , np .float64 ]
151+ tuple [np .float64 , np .float64 ]
152+ | tuple [NDArray [np .float64 ], NDArray [np .float64 ]]
153+ | tuple [types .CupyArray , types .CupyArray ]
138154 | tuple [types .DaskArray , types .DaskArray ]
139155):
140156 """Mean and variance over both or one axis.
@@ -169,33 +185,29 @@ def mean_var(
169185# https://github.com/scverse/fast-array-utils/issues/52
170186@overload
171187def sum (
172- x : ArrayLike | CpuArray | GpuArray | DiskArray ,
173- / ,
174- * ,
175- axis : None = None ,
176- dtype : DTypeLike | None = None ,
188+ x : CpuArray | GpuArray | DiskArray , / , * , axis : None = None , dtype : DTypeLike | None = None
177189) -> np .number [Any ]: ...
178190@overload
179191def sum (
180- x : ArrayLike | CpuArray | GpuArray | DiskArray ,
181- / ,
182- * ,
183- axis : Literal [0 , 1 ],
184- dtype : DTypeLike | None = None ,
192+ x : CpuArray | DiskArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
185193) -> NDArray [Any ]: ...
186194@overload
195+ def sum (
196+ x : GpuArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
197+ ) -> types .CupyArray : ...
198+ @overload
187199def sum (
188200 x : types .DaskArray , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
189201) -> types .DaskArray : ...
190202
191203
192204def sum (
193- x : ArrayLike | CpuArray | GpuArray | DiskArray | types .DaskArray ,
205+ x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
194206 / ,
195207 * ,
196208 axis : Literal [0 , 1 , None ] = None ,
197209 dtype : DTypeLike | None = None ,
198- ) -> NDArray [Any ] | np .number [Any ] | types .DaskArray :
210+ ) -> NDArray [Any ] | types . CupyArray | np .number [Any ] | types .DaskArray :
199211 """Sum over both or one axis.
200212
201213 Returns
@@ -209,4 +221,4 @@ def sum(
209221
210222 """
211223 validate_axis (axis )
212- return sum_ (x , axis = axis , dtype = dtype ) # type: ignore[arg-type] # literally the same type, wtf mypy
224+ return sum_ (x , axis = axis , dtype = dtype )
0 commit comments