@@ -293,8 +293,8 @@ def test_mean_var_sparse_32(array_type: ArrayType[types.CSArray]) -> None:
293293 assert resid_fau < resid_skl
294294
295295
296- @pytest .mark .array_type (Flags .Dask )
297- def test_mean_var_pbmc_dask (array_type : ArrayType [types .DaskArray ], axis : Literal [ 0 , 1 ] | None , pbmc64k_reduced_raw : sps .csr_array [np .float32 ]) -> None :
296+ @pytest .mark .array_type ({ at for at in SUPPORTED_TYPES if at . flags & Flags .Sparse and at . flags & Flags . Dask } )
297+ def test_mean_var_pbmc_dask (array_type : ArrayType [types .DaskArray ], pbmc64k_reduced_raw : sps .csr_array [np .float32 ]) -> None :
298298 """Test float32 precision for bigger data.
299299
300300 This test is flaky for sparse-in-dask for some reason.
@@ -308,8 +308,8 @@ def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], axis: Litera
308308 mat_sc = sc .datasets .pbmc68k_reduced ().raw .X
309309 np .testing .assert_array_equal (mat .toarray (), mat_sc .toarray ())
310310
311- mean_mat , var_mat = stats .mean_var (mat , axis = axis , correction = 1 )
312- mean_arr , var_arr = (to_np_dense_checked (a , axis , arr ) for a in stats .mean_var (arr , axis = axis , correction = 1 ))
311+ mean_mat , var_mat = stats .mean_var (mat , axis = 0 , correction = 1 )
312+ mean_arr , var_arr = (to_np_dense_checked (a , 0 , arr ) for a in stats .mean_var (arr , axis = 0 , correction = 1 ))
313313
314314 np .testing .assert_array_almost_equal (mean_arr , mean_mat )
315315 np .testing .assert_array_almost_equal (var_arr , var_mat )
0 commit comments