diff --git a/doc/whats_new/v0.12.rst b/doc/whats_new/v0.12.rst index 1c4325356..88017b547 100644 --- a/doc/whats_new/v0.12.rst +++ b/doc/whats_new/v0.12.rst @@ -42,3 +42,9 @@ Deprecations - Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule. It will be removed in 0.14. The parameter does not have any effect. :pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`. + +Enhancements +............ + +- Allow samplers to output a sparse dataframe when a sparse dataframe is provided + as input. :pr:`1059` by :user:`ts2095 <ts2095>`. diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index a36e6d81b..bf1d8351f 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -10,6 +10,7 @@ from numbers import Integral, Real import numpy as np +from scipy.sparse import issparse from sklearn.base import clone from sklearn.neighbors import NearestNeighbors from sklearn.utils import check_array, column_or_1d @@ -61,7 +62,10 @@ def _transfrom_one(self, array, props): elif type_ == "dataframe": import pandas as pd - ret = pd.DataFrame(array, columns=props["columns"]) + if issparse(array): + ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"]) + else: + ret = pd.DataFrame(array, columns=props["columns"]) ret = ret.astype(props["dtypes"]) elif type_ == "series": import pandas as pd diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index eae78099e..570427759 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -108,6 +108,7 @@ def _yield_sampler_checks(sampler): yield check_samplers_sparse if "dataframe" in tags["X_types"]: yield check_samplers_pandas + yield check_samplers_pandas_sparse if "string" in tags["X_types"]: yield check_samplers_string if tags["allow_nan"]: @@ -312,6 +313,34 @@ def check_samplers_sparse(name, sampler_orig): assert_allclose(y_res_sparse, y_res) +def check_samplers_pandas_sparse(name, sampler_orig): + pd = pytest.importorskip("pandas") 
+ sampler = clone(sampler_orig) + # Check that the samplers handle sparse pandas dataframe and pandas series + X, y = sample_dataset_generator() + X_df = pd.DataFrame( + X, columns=[str(i) for i in range(X.shape[1])], dtype=pd.SparseDtype(float, 0) + ) + y_s = pd.Series(y, name="class") + + X_res_df, y_res_s = sampler.fit_resample(X_df, y_s) + X_res, y_res = sampler.fit_resample(X, y) + + # check that we return the same type for dataframes or series types + assert isinstance(X_res_df, pd.DataFrame) + assert isinstance(y_res_s, pd.Series) + + for column_dtype in X_res_df.dtypes: + assert isinstance(column_dtype, pd.SparseDtype) + + assert X_df.columns.tolist() == X_res_df.columns.tolist() + assert y_s.name == y_res_s.name + + # FIXME: we should use to_numpy with pandas >= 0.25 + assert_allclose(X_res_df.values, X_res) + assert_allclose(y_res_s.values, y_res) + + def check_samplers_pandas(name, sampler_orig): pd = pytest.importorskip("pandas") sampler = clone(sampler_orig)