22from __future__ import annotations
33
44from functools import partial , singledispatch
5- from typing import TYPE_CHECKING , overload
5+ from typing import TYPE_CHECKING , Any , cast , overload
66
77import numpy as np
8+ from numpy .typing import NDArray
89
910from .. import types
1011
1112
1213if TYPE_CHECKING :
13- from typing import Any , Literal
14+ from typing import Literal
1415
15- from numpy .typing import ArrayLike , DTypeLike , NDArray
16+ from numpy .typing import ArrayLike , DTypeLike
1617
1718
1819@overload
1920def sum (
20- x : ArrayLike , / , * , axis : None = None , dtype : DTypeLike | None = None
21+ x : ArrayLike | types . ZarrArray , / , * , axis : None = None , dtype : DTypeLike | None = None
2122) -> np .number [Any ]: ...
2223@overload
2324def sum (
24- x : ArrayLike , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
25+ x : ArrayLike | types . ZarrArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
2526) -> NDArray [Any ]: ...
2627@overload
2728def sum (
@@ -30,7 +31,11 @@ def sum(
3031
3132
3233def sum (
33- x : ArrayLike , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
34+ x : ArrayLike | types .ZarrArray ,
35+ / ,
36+ * ,
37+ axis : Literal [0 , 1 , None ] = None ,
38+ dtype : DTypeLike | None = None ,
3439) -> NDArray [Any ] | np .number [Any ] | types .DaskArray :
3540 """Sum over both or one axis.
3641
@@ -56,7 +61,7 @@ def _sum(
5661 dtype : DTypeLike | None = None ,
5762) -> NDArray [Any ] | np .number [Any ] | types .DaskArray :
5863 assert not isinstance (x , types .CSBase | types .DaskArray )
59- return np .sum (x , axis = axis , dtype = dtype ) # type: ignore[no-any-return]
64+ return cast ( NDArray [ Any ] | np .number [ Any ], np . sum (x , axis = axis , dtype = dtype ))
6065
6166
6267@_sum .register (types .CSBase )
6772
6873 if isinstance (x , types .CSMatrix ):
6974 x = sp .csr_array (x ) if x .format == "csr" else sp .csc_array (x )
70- return np .sum (x , axis = axis , dtype = dtype ) # type: ignore[no-any-return]
75+ return cast ( NDArray [ Any ] | np .number [ Any ], np . sum (x , axis = axis , dtype = dtype ))
7176
7277
7378@_sum .register (types .DaskArray )
@@ -108,11 +113,14 @@ def sum_drop_keepdims(
108113 # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
109114 dtype = np .zeros (1 , dtype = x .dtype ).sum ().dtype
110115
111- return reduction ( # type: ignore[no-any-return,no-untyped-call]
112- x ,
113- sum_drop_keepdims ,
114- partial (np .sum , dtype = dtype ),
115- axis = axis ,
116- dtype = dtype ,
117- meta = np .array ([], dtype = dtype ),
116+ return cast (
117+ types .DaskArray ,
118+ reduction ( # type: ignore[no-untyped-call]
119+ x ,
120+ sum_drop_keepdims ,
121+ partial (np .sum , dtype = dtype ),
122+ axis = axis ,
123+ dtype = dtype ,
124+ meta = np .array ([], dtype = dtype ),
125+ ),
118126 )
0 commit comments