@@ -90,20 +90,42 @@ def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTyp
9090 return dtype
9191
9292
93- @pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , None ])
93+ @pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , np . int64 , None ])
9494def dtype_arg (request : pytest .FixtureRequest ) -> type [DTypeOut ] | None :
9595 return cast ("type[DTypeOut] | None" , request .param )
9696
9797
9898@pytest .fixture
9999def np_arr (dtype_in : type [DTypeIn ], ndim : Literal [1 , 2 ]) -> NDArray [DTypeIn ]:
100100 np_arr = cast ("NDArray[DTypeIn]" , np .array ([[1 , 0 ], [3 , 0 ], [5 , 6 ]], dtype = dtype_in ))
101+ if np .dtype (dtype_in ).kind == "f" :
102+ np_arr /= 4 # type: ignore[misc]
101103 np_arr .flags .writeable = False
102104 if ndim == 1 :
103105 np_arr = np_arr .flatten ()
104106 return np_arr
105107
106108
109+ def to_np_dense_checked (
110+ stat : NDArray [DTypeOut ] | np .number [Any ] | types .DaskArray , axis : Literal [0 , 1 ] | None , arr : CpuArray | GpuArray | DiskArray | types .DaskArray
111+ ) -> NDArray [DTypeOut ] | np .number [Any ]:
112+ match axis , arr :
113+ case _, types .DaskArray ():
114+ assert isinstance (stat , types .DaskArray ), type (stat )
115+ stat = stat .compute () # type: ignore[assignment]
116+ return to_np_dense_checked (stat , axis , arr .compute ())
117+ case None , _:
118+ assert isinstance (stat , np .floating | np .integer ), type (stat )
119+ case 0 | 1 , types .CupyArray () | types .CupyCSRMatrix () | types .CupyCSCMatrix () | types .CupyCOOMatrix ():
120+ assert isinstance (stat , types .CupyArray ), type (stat )
121+ return to_np_dense_checked (stat .get (), axis , arr .get ())
122+ case 0 | 1 , _:
123+ assert isinstance (stat , np .ndarray ), type (stat )
124+ case _:
125+ pytest .fail (f"Unhandled case axis { axis } for { type (arr )} : { type (stat )} " )
126+ return stat
127+
128+
107129@pytest .mark .array_type (skip = {* ATS_SPARSE_DS , Flags .Matrix })
108130@pytest .mark .parametrize ("func" , STAT_FUNCS )
109131@pytest .mark .parametrize (("ndim" , "axis" ), [(1 , 0 ), (2 , 3 ), (2 , - 1 )], ids = ["1d-ax0" , "2d-ax3" , "2d-axneg" ])
@@ -127,26 +149,13 @@ def test_sum(
127149 axis : Literal [0 , 1 ] | None ,
128150 np_arr : NDArray [DTypeIn ],
129151) -> None :
152+ if np .dtype (dtype_arg ).kind in "iu" and (array_type .flags & Flags .Gpu ) and (array_type .flags & Flags .Sparse ):
153+ pytest .skip ("GPU sparse matrices don’t support int dtypes" )
130154 arr = array_type (np_arr .copy ())
131155 assert arr .dtype == dtype_in
132156
133157 sum_ = stats .sum (arr , axis = axis , dtype = dtype_arg )
134-
135- match axis , arr :
136- case _, types .DaskArray ():
137- assert isinstance (sum_ , types .DaskArray ), type (sum_ )
138- sum_ = sum_ .compute () # type: ignore[assignment]
139- if isinstance (sum_ , types .CupyArray ):
140- sum_ = sum_ .get ()
141- case None , _:
142- assert isinstance (sum_ , np .floating | np .integer ), type (sum_ )
143- case 0 | 1 , types .CupyArray () | types .CupyCSRMatrix () | types .CupyCSCMatrix ():
144- assert isinstance (sum_ , types .CupyArray ), type (sum_ )
145- sum_ = sum_ .get ()
146- case 0 | 1 , _:
147- assert isinstance (sum_ , np .ndarray ), type (sum_ )
148- case _:
149- pytest .fail (f"Unhandled case axis { axis } for { type (arr )} : { type (sum_ )} " )
158+ sum_ = to_np_dense_checked (sum_ , axis , arr ) # type: ignore[arg-type]
150159
151160 assert sum_ .shape == () if axis is None else arr .shape [axis ], (sum_ .shape , arr .shape )
152161
@@ -161,6 +170,19 @@ def test_sum(
161170 np .testing .assert_array_equal (sum_ , expected )
162171
163172
173+ @pytest .mark .array_type (skip = {* ATS_SPARSE_DS , Flags .Gpu })
174+ def test_sum_to_int (array_type : ArrayType [CpuArray | DiskArray | types .DaskArray ], axis : Literal [0 , 1 ] | None ) -> None :
175+ rng = np .random .default_rng (0 )
176+ np_arr = rng .random ((100 , 100 ))
177+ arr = array_type (np_arr )
178+
179+ sum_ = stats .sum (arr , axis = axis , dtype = np .int64 )
180+ sum_ = to_np_dense_checked (sum_ , axis , arr )
181+
182+ expected = np .zeros (() if axis is None else arr .shape [axis ], dtype = np .int64 )
183+ np .testing .assert_array_equal (sum_ , expected )
184+
185+
164186@pytest .mark .parametrize (
165187 "data" ,
166188 [
0 commit comments