Skip to content

Commit 2013939

Browse files
authored
FIX handle sparse matrices in SMOTEN (#1003)
1 parent c65e21f commit 2013939

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

doc/whats_new/v0.11.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ Bug fixes
1313
`bool` and `pd.category` by delegating the conversion to scikit-learn encoder.
1414
:pr:`1002` by :user:`Guillaume Lemaitre <glemaitre>`.
1515

16+
- Handle sparse matrices in :class:`~imblearn.over_sampling.SMOTEN` and raise a warning
17+
since the input has to be converted to a dense array internally.
18+
:pr:`1003` by :user:`Guillaume Lemaitre <glemaitre>`.
19+
1620
Compatibility
1721
.............
1822

imblearn/over_sampling/_smote/base.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import numpy as np
1515
from scipy import sparse
1616
from sklearn.base import clone
17+
from sklearn.exceptions import DataConversionWarning
1718
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1819
from sklearn.utils import _safe_indexing, check_array, check_random_state
1920
from sklearn.utils.sparsefuncs_fast import (
@@ -893,7 +894,7 @@ def _check_X_y(self, X, y):
893894
y,
894895
reset=True,
895896
dtype=None,
896-
accept_sparse=False,
897+
accept_sparse=["csr", "csc"],
897898
)
898899
return X, y, binarize_y
899900

@@ -927,6 +928,17 @@ def _fit_resample(self, X, y):
927928
FutureWarning,
928929
)
929930

931+
if sparse.issparse(X):
932+
X_sparse_format = X.format
933+
X = X.toarray()
934+
warnings.warn(
935+
"Passing a sparse matrix to SMOTEN is not really efficient since it is"
936+
" converted to a dense array internally.",
937+
DataConversionWarning,
938+
)
939+
else:
940+
X_sparse_format = None
941+
930942
self._validate_estimator()
931943

932944
X_resampled = [X.copy()]
@@ -964,7 +976,12 @@ def _fit_resample(self, X, y):
964976
X_resampled = np.vstack(X_resampled)
965977
y_resampled = np.hstack(y_resampled)
966978

967-
return X_resampled, y_resampled
979+
if X_sparse_format == "csr":
980+
return sparse.csr_matrix(X_resampled), y_resampled
981+
elif X_sparse_format == "csc":
982+
return sparse.csc_matrix(X_resampled), y_resampled
983+
else:
984+
return X_resampled, y_resampled
968985

969986
def _more_tags(self):
970987
return {"X_types": ["2darray", "dataframe", "string"]}

imblearn/over_sampling/_smote/tests/test_smoten.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
3-
from sklearn.preprocessing import OrdinalEncoder
3+
from sklearn.exceptions import DataConversionWarning
4+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
5+
from sklearn.utils._testing import _convert_container
46

57
from imblearn.over_sampling import SMOTEN
68

@@ -56,6 +58,24 @@ def test_smoten_resampling():
5658
np.testing.assert_array_equal(y_generated, "not apple")
5759

5860

61+
@pytest.mark.parametrize("sparse_format", ["sparse_csr", "sparse_csc"])
62+
def test_smoten_sparse_input(data, sparse_format):
63+
"""Check that we handle sparse input in SMOTEN even if it is not efficient.
64+
65+
Non-regression test for:
66+
https://github.com/scikit-learn-contrib/imbalanced-learn/issues/971
67+
"""
68+
X, y = data
69+
X = OneHotEncoder().fit_transform(X)
70+
X = _convert_container(X, sparse_format)
71+
72+
with pytest.warns(DataConversionWarning, match="is not really efficient"):
73+
X_res, y_res = SMOTEN(random_state=0).fit_resample(X, y)
74+
75+
assert X_res.format == X.format
76+
assert X_res.shape[0] == len(y_res)
77+
78+
5979
def test_smoten_categorical_encoder(data):
6080
"""Check that `categorical_encoder` is used when provided."""
6181

0 commit comments

Comments
 (0)