@@ -108,6 +108,7 @@ def _yield_sampler_checks(sampler):
108108 yield check_samplers_sparse
109109 if "dataframe" in tags ["X_types" ]:
110110 yield check_samplers_pandas
111+ yield check_samplers_pandas_sparse
111112 if "string" in tags ["X_types" ]:
112113 yield check_samplers_string
113114 if tags ["allow_nan" ]:
@@ -312,6 +313,34 @@ def check_samplers_sparse(name, sampler_orig):
312313 assert_allclose (y_res_sparse , y_res )
313314
314315
def check_samplers_pandas_sparse(name, sampler_orig):
    """Check that a sampler preserves sparse-dtype pandas containers.

    The sampler is cloned and resampled twice on the same generated
    dataset: once with ``X`` wrapped in a DataFrame whose columns use
    ``pd.SparseDtype(float, 0)`` (and ``y`` wrapped in a named Series), and
    once with the raw ``X`` and ``y``. The pandas output must keep its
    container types, sparse column dtypes, column names and series name, and
    its values must match the raw-input output.

    Parameters
    ----------
    name : str
        Name of the sampler under test (not used in the body).
    sampler_orig : object
        Sampler to check. It is cloned, so the instance passed in is not
        fitted by this check.
    """
    # Skip, rather than fail, when pandas cannot be imported.
    pd = pytest.importorskip("pandas")
    sampler = clone(sampler_orig)

    X, y = sample_dataset_generator()
    feature_names = [str(i) for i in range(X.shape[1])]
    sparse_float = pd.SparseDtype(float, 0)
    X_df = pd.DataFrame(X, columns=feature_names, dtype=sparse_float)
    y_s = pd.Series(y, name="class")

    # Resample the pandas containers first, then the raw inputs.
    X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
    X_res, y_res = sampler.fit_resample(X, y)

    # Container types must be preserved.
    assert isinstance(X_res_df, pd.DataFrame)
    assert isinstance(y_res_s, pd.Series)

    # Every resampled column must still carry a sparse dtype.
    assert all(
        isinstance(column_dtype, pd.SparseDtype)
        for column_dtype in X_res_df.dtypes
    )

    # Labels (column names and series name) must be preserved.
    assert X_df.columns.tolist() == X_res_df.columns.tolist()
    assert y_s.name == y_res_s.name

    # The pandas path and the raw-input path must agree numerically.
    # FIXME: we should use to_numpy with pandas >= 0.25
    assert_allclose(X_res_df.values, X_res)
    assert_allclose(y_res_s.values, y_res)
342+
343+
315344def check_samplers_pandas (name , sampler_orig ):
316345 pd = pytest .importorskip ("pandas" )
317346 sampler = clone (sampler_orig )
0 commit comments