66
77import numpy as np
88import pytest
9+ from numpy .exceptions import AxisError
910
1011from fast_array_utils import stats , types
1112from testing .fast_array_utils import SUPPORTED_TYPES , Flags
2627 DTypeIn = np .float32 | np .float64 | np .int32 | np .bool
2728 DTypeOut = np .float32 | np .float64 | np .int64
2829
29- NdAndAx : TypeAlias = tuple [Literal [2 ], Literal [0 , 1 , None ]]
30+ NdAndAx : TypeAlias = tuple [Literal [1 ], Literal [ None ]] | tuple [ Literal [ 2 ], Literal [0 , 1 , None ]]
3031
31- class BenchFun (Protocol ): # noqa: D101
32+ class StatFun (Protocol ): # noqa: D101
3233 def __call__ ( # noqa: D102
3334 self ,
34- arr : CpuArray ,
35+ arr : Array ,
3536 * ,
3637 axis : Literal [0 , 1 , None ] = None ,
3738 dtype : type [DTypeOut ] | None = None ,
@@ -41,6 +42,8 @@ def __call__( # noqa: D102
4142pytestmark = [pytest .mark .skipif (not find_spec ("numba" ), reason = "numba not installed" )]
4243
4344
45+ STAT_FUNCS = [stats .sum , stats .mean , stats .mean_var , stats .is_constant ]
46+
4447# can’t select these using a category filter
4548ATS_SPARSE_DS = {at for at in SUPPORTED_TYPES if at .mod == "anndata.abc" }
4649ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str (at )}
@@ -49,6 +52,7 @@ def __call__( # noqa: D102
4952@pytest .fixture (
5053 scope = "session" ,
5154 params = [
55+ pytest .param ((1 , None ), id = "1d-all" ),
5256 pytest .param ((2 , None ), id = "2d-all" ),
5357 pytest .param ((2 , 0 ), id = "2d-ax0" ),
5458 pytest .param ((2 , 1 ), id = "2d-ax1" ),
@@ -59,18 +63,31 @@ def ndim_and_axis(request: pytest.FixtureRequest) -> NdAndAx:
5963
6064
6165@pytest .fixture
62- def ndim (ndim_and_axis : NdAndAx ) -> Literal [2 ]:
63- return ndim_and_axis [0 ]
66+ def ndim (ndim_and_axis : NdAndAx , array_type : ArrayType ) -> Literal [1 , 2 ]:
67+ return check_ndim (array_type , ndim_and_axis [0 ])
68+
69+
70+ def check_ndim (array_type : ArrayType , ndim : Literal [1 , 2 ]) -> Literal [1 , 2 ]:
71+ inner_cls = array_type .inner .cls if array_type .inner else array_type .cls
72+ if ndim != 2 and issubclass (inner_cls , types .CSMatrix | types .CupyCSMatrix ):
73+ pytest .skip ("CSMatrix only supports 2D" )
74+ if ndim != 2 and inner_cls is types .csc_array :
75+ pytest .skip ("csc_array only supports 2D" )
76+ return ndim
6477
6578
6679@pytest .fixture (scope = "session" )
6780def axis (ndim_and_axis : NdAndAx ) -> Literal [0 , 1 , None ]:
6881 return ndim_and_axis [1 ]
6982
7083
71- @pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , np .int32 , np .bool ])
72- def dtype_in (request : pytest .FixtureRequest ) -> type [DTypeIn ]:
73- return cast ("type[DTypeIn]" , request .param )
84+ @pytest .fixture (params = [np .float32 , np .float64 , np .int32 , np .bool ])
85+ def dtype_in (request : pytest .FixtureRequest , array_type : ArrayType ) -> type [DTypeIn ]:
86+ dtype = cast ("type[DTypeIn]" , request .param )
87+ inner_cls = array_type .inner .cls if array_type .inner else array_type .cls
88+ if np .dtype (dtype ).kind not in "fdFD" and issubclass (inner_cls , types .CupyCSMatrix ):
89+ pytest .skip ("Cupy sparse matrices don’t support non-floating dtypes" )
90+ return dtype
7491
7592
7693@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , None ])
@@ -79,12 +96,33 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
7996
8097
8198@pytest .fixture
82- def np_arr (dtype_in : type [DTypeIn ]) -> NDArray [DTypeIn ]:
99+ def np_arr (dtype_in : type [DTypeIn ], ndim : Literal [ 1 , 2 ] ) -> NDArray [DTypeIn ]:
83100 np_arr = cast ("NDArray[DTypeIn]" , np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in ))
84101 np_arr .flags .writeable = False
102+ if ndim == 1 :
103+ np_arr = np_arr .flatten ()
85104 return np_arr
86105
87106
107+ @pytest .mark .array_type (skip = {* ATS_SPARSE_DS , Flags .Matrix })
108+ @pytest .mark .parametrize ("func" , STAT_FUNCS )
109+ @pytest .mark .parametrize (
110+ ("ndim" , "axis" ), [(1 , 0 ), (2 , 3 ), (2 , - 1 )], ids = ["1d-ax0" , "2d-ax3" , "2d-axneg" ]
111+ )
112+ def test_ndim_error (
113+ array_type : ArrayType [Array ], func : StatFun , ndim : Literal [1 , 2 ], axis : Literal [0 , 1 , None ]
114+ ) -> None :
115+ check_ndim (array_type , ndim )
116+ # not using the fixture because we don’t need to test multiple dtypes
117+ np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float32 )
118+ if ndim == 1 :
119+ np_arr = np_arr .flatten ()
120+ arr = array_type (np_arr )
121+
122+ with pytest .raises (AxisError ):
123+ func (arr , axis = axis )
124+
125+
88126@pytest .mark .array_type (skip = ATS_SPARSE_DS )
89127def test_sum (
90128 array_type : ArrayType [Array ],
@@ -93,8 +131,6 @@ def test_sum(
93131 axis : Literal [0 , 1 , None ],
94132 np_arr : NDArray [DTypeIn ],
95133) -> None :
96- if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
97- pytest .skip ("CuPy sparse matrices only support floats" )
98134 arr = array_type (np_arr .copy ())
99135 assert arr .dtype == dtype_in
100136
@@ -133,8 +169,6 @@ def test_sum(
133169def test_mean (
134170 array_type : ArrayType [Array ], axis : Literal [0 , 1 , None ], np_arr : NDArray [DTypeIn ]
135171) -> None :
136- if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
137- pytest .skip ("CuPy sparse matrices only support floats" )
138172 arr = array_type (np_arr )
139173
140174 result = stats .mean (arr , axis = axis ) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
@@ -148,26 +182,21 @@ def test_mean(
148182
149183
150184@pytest .mark .array_type (skip = Flags .Disk )
151- @pytest .mark .parametrize (
152- ("axis" , "mean_expected" , "var_expected" ),
153- [(None , 3.5 , 3.5 ), (0 , [2.5 , 3.5 , 4.5 ], [4.5 , 4.5 , 4.5 ]), (1 , [2.0 , 5.0 ], [1.0 , 1.0 ])],
154- )
155185def test_mean_var (
156186 array_type : ArrayType [CpuArray | GpuArray | types .DaskArray ],
157187 axis : Literal [0 , 1 , None ],
158- mean_expected : float | list [float ],
159- var_expected : float | list [float ],
188+ np_arr : NDArray [DTypeIn ],
160189) -> None :
161- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float64 )
162- np .testing .assert_array_equal (np .mean (np_arr , axis = axis ), mean_expected )
163- np .testing .assert_array_equal (np .var (np_arr , axis = axis , correction = 1 ), var_expected )
164-
165190 arr = array_type (np_arr )
191+
166192 mean , var = stats .mean_var (arr , axis = axis , correction = 1 )
167193 if isinstance (mean , types .DaskArray ) and isinstance (var , types .DaskArray ):
168194 mean , var = mean .compute (), var .compute () # type: ignore[assignment]
169195 if isinstance (mean , types .CupyArray ) and isinstance (var , types .CupyArray ):
170196 mean , var = mean .get (), var .get ()
197+
198+ mean_expected = np .mean (np_arr , axis = axis ) # type: ignore[arg-type]
199+ var_expected = np .var (np_arr , axis = axis , correction = 1 ) # type: ignore[arg-type]
171200 np .testing .assert_array_equal (mean , mean_expected )
172201 np .testing .assert_array_almost_equal (var , var_expected ) # type: ignore[arg-type]
173202
@@ -223,11 +252,11 @@ def test_dask_constant_blocks(
223252
224253@pytest .mark .benchmark
225254@pytest .mark .array_type (skip = Flags .Matrix | Flags .Dask | Flags .Disk | Flags .Gpu )
226- @pytest .mark .parametrize ("func" , [ stats . sum , stats . mean , stats . mean_var , stats . is_constant ] )
255+ @pytest .mark .parametrize ("func" , STAT_FUNCS )
227256@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 , np .int32 ])
228257def test_stats_benchmark (
229258 benchmark : BenchmarkFixture ,
230- func : BenchFun ,
259+ func : StatFun ,
231260 array_type : ArrayType [CpuArray , None ],
232261 axis : Literal [0 , 1 , None ],
233262 dtype : type [np .float32 | np .float64 ],
0 commit comments