From d1ac5bd39a1276fb3bc29b53dd7cdad9f9706a13 Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 19:34:39 +0200 Subject: [PATCH 1/8] Add SMOTE for pure categorical data --- imblearn/over_sampling/__init__.py | 3 +- imblearn/over_sampling/_smote.py | 192 +++++++++++++++++++ imblearn/over_sampling/tests/test_smote_n.py | 162 ++++++++++++++++ 3 files changed, 356 insertions(+), 1 deletion(-) create mode 100644 imblearn/over_sampling/tests/test_smote_n.py diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index 9cd63ac87..db2768c1a 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -9,6 +9,7 @@ from ._smote import BorderlineSMOTE from ._smote import SVMSMOTE from ._smote import SMOTENC +from ._smote import SMOTEN __all__ = ['ADASYN', 'RandomOverSampler', - 'SMOTE', 'BorderlineSMOTE', 'SVMSMOTE', 'SMOTENC'] + 'SMOTE', 'BorderlineSMOTE', 'SVMSMOTE', 'SMOTEN', 'SMOTENC'] diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 60cff7d34..85841fb02 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1090,3 +1090,195 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step): sample[start_idx + col_sel] = 1 return sparse.csr_matrix(sample) if sparse.issparse(X) else sample + + +# @Substitution( +# sampling_strategy=BaseOverSampler._sampling_strategy_docstring, +# random_state=_random_state_docstring) +class SMOTEN(SMOTE): + """Synthetic Minority Over-sampling Technique for Nominal + (SMOTE-NC). + + Unlike :class:`SMOTE`, SMOTE-N operates on datasets containing categorical + features. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + sampling_strategy : float, str, dict or callable, (default='auto') + Sampling information to resample the data set. + + - When ``float``, it corresponds to the desired ratio of the number of + samples in the minority class over the number of samples in the + majority class after resampling. Therefore, the ratio is expressed as + :math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the + number of samples in the minority class after resampling and + :math:`N_{M}` is the number of samples in the majority class. + + .. warning:: + ``float`` is only available for **binary** classification. An + error is raised for multi-class classification. + + - When ``str``, specify the class targeted by the resampling. The + number of samples in the different classes will be equalized. + Possible choices are: + + ``'minority'``: resample only the minority class; + + ``'not minority'``: resample all classes but the minority class; + + ``'not majority'``: resample all classes but the majority class; + + ``'all'``: resample all classes; + + ``'auto'``: equivalent to ``'not majority'``. + + - When ``dict``, the keys correspond to the targeted classes. The + values correspond to the desired number of samples for each targeted + class. + + - When callable, function taking ``y`` and returns a ``dict``. The keys + correspond to the targeted classes. The values correspond to the + desired number of samples for each class. + + random_state : int, RandomState instance or None, optional (default=None) + Control the randomization of the algorithm. + + - If int, ``random_state`` is the seed used by the random number + generator; + - If ``RandomState`` instance, random_state is the random number + generator; + - If ``None``, the random number generator is the ``RandomState`` + instance used by ``np.random``. + + k_neighbors : int or object, optional (default=5) + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. + + n_jobs : int, optional (default=1) + The number of threads to open if possible. + + Notes + ----- + See the original paper [1]_ for more details. + + Supports mutli-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. + + See + :ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`, + and :ref:`sphx_glr_auto_examples_over-sampling_plot_smote.py`. + + See also + -------- + SMOTE : Over-sample using SMOTE. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + Examples + -------- + + >>> from collections import Counter + >>> from numpy.random import RandomState + >>> from sklearn.datasets import make_classification + >>> from imblearn.over_sampling import SMOTEN + >>> X, y = make_classification(n_classes=2, class_sep=2, + ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, + ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + >>> print('Original dataset shape (%s, %s)' % X.shape) + Original dataset shape (1000, 20) + >>> print('Original dataset samples per class {}'.format(Counter(y))) + Original dataset samples per class Counter({1: 900, 0: 100}) + >>> # simulate the 2 last columns to be categorical features + >>> X[:, ] = RandomState(10).randint(0, 4, size=(1000, 2)) + >>> sm = SMOTEN(random_state=42, categorical_features=[18, 19]) + >>> X_res, y_res = sm.fit_resample(X, y) + >>> print('Resampled dataset samples per class {}'.format(Counter(y_res))) + Resampled dataset samples per class Counter({0: 900, 1: 900}) + + """ + + def __init__(self, sampling_strategy='auto', + random_state=None, k_neighbors=5, n_jobs=1): + super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy, + random_state=random_state, + k_neighbors=k_neighbors, + ratio=None, + n_jobs=n_jobs) + + @staticmethod + def _check_X_y(X, y): + """Overwrite the checking to let pass some string for categorical + features. + """ + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) + X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None) + return X, y, binarize_y + + def _validate_estimator(self): + super(SMOTEN, self)._validate_estimator() + + def _fit_resample(self, X, y): + self.n_features_ = X.shape[1] + self._validate_estimator() + + self.ohe_ = OneHotEncoder(sparse=True, handle_unknown='ignore', + dtype=np.float64) + # the input of the OneHotEncoder needs to be dense + X_ohe = self.ohe_.fit_transform( + X.toarray() if sparse.issparse(X) + else X) + + X_resampled, y_resampled = super(SMOTEN, self)._fit_resample( + X_ohe, y) + X_resampled = self.ohe_.inverse_transform(X_resampled) + + if sparse.issparse(X): + X_resampled = sparse.csr_matrix(X_resampled) + + return X_resampled, y_resampled + + def _generate_sample(self, X, nn_data, nn_num, row, col, step): + """Generate a synthetic sample with an additional steps for the + categorical features. + + Each new sample is generated the same way than in SMOTE. However, the + categorical features are mapped to the most frequent nearest neighbors + of the majority class. + """ + rng = check_random_state(self.random_state) + sample = super(SMOTEN, self)._generate_sample(X, nn_data, nn_num, + row, col, step) + # To avoid conversion and since there is only few samples used, we + # convert those samples to dense array. + sample = (sample.toarray().squeeze() + if sparse.issparse(sample) else sample) + all_neighbors = nn_data[nn_num[row]] + all_neighbors = (all_neighbors.toarray() + if sparse.issparse(all_neighbors) else all_neighbors) + + categories_size = [cat.size for cat in self.ohe_.categories_] + + for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1], + np.cumsum(categories_size)[1:]): + col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0) + # tie breaking argmax + col_sel = rng.choice(np.flatnonzero( + np.isclose(col_max, col_max.max()))) + sample[start_idx:end_idx] = 0 + sample[start_idx + col_sel] = 1 + + return sparse.csr_matrix(sample) if sparse.issparse(X) else sample diff --git a/imblearn/over_sampling/tests/test_smote_n.py b/imblearn/over_sampling/tests/test_smote_n.py new file mode 100644 index 000000000..a310cd188 --- /dev/null +++ b/imblearn/over_sampling/tests/test_smote_n.py @@ -0,0 +1,162 @@ +"""Test the module smoten.""" +# Authors: Guillaume Lemaitre +# Christos Aridas +# Dzianis Dudnik +# License: MIT + +from collections import Counter + +import pytest + +import numpy as np +from scipy import sparse + +from sklearn.datasets import make_classification +from sklearn.utils.testing import assert_allclose + +from imblearn.over_sampling import SMOTEN + + +def data_heterogneous_ordered(): + rng = np.random.RandomState(42) + X = np.empty((30, 2), dtype=object) + # create a categorical feature using some string + X[:, 0] = rng.choice(['a', 'b', 'c'], size=30).astype(object) + # create a categorical feature using some integer + X[:, 1] = rng.randint(3, size=30) + y = np.array([0] * 10 + [1] * 20) + # return the categories + return X, y + + +def data_heterogneous_unordered(): + rng = np.random.RandomState(42) + X = np.empty((30, 2), dtype=object) + # create a categorical feature using some string + X[:, 0] = rng.choice(['a', 'b', 'c'], size=30).astype(object) + # create a categorical feature using some integer + X[:, 1] = rng.randint(3, size=30) + y = np.array([0] * 10 + [1] * 20) + # return the categories + return X, y + + +def data_heterogneous_unordered_multiclass(): + rng = np.random.RandomState(42) + X = np.empty((50, 2), dtype=object) + # create a categorical feature using some string + X[:, 0] = rng.choice(['a', 'b', 'c'], size=50).astype(object) + # create a categorical feature using some integer + X[:, 1] = rng.randint(3, size=50) + y = np.array([0] * 10 + [1] * 15 + [2] * 25) + # return the categories + return X, y + + +def data_sparse(format): + rng = np.random.RandomState(42) + X = np.empty((30, 2), dtype=np.float64) + # create a categorical feature using some string + X[:, 0] = rng.randint(3, size=30) + # create a categorical feature using some integer + X[:, 1] = rng.randint(3, size=30) + y = np.array([0] * 10 + [1] * 20) + X = sparse.csr_matrix(X) if format == 'csr' else sparse.csc_matrix(X) + return X, y + + +@pytest.mark.parametrize( + "data", + [data_heterogneous_ordered(), data_heterogneous_unordered(), + data_sparse('csr'), data_sparse('csc')] +) +def test_smoten(data): + X, y = data + smote = SMOTEN(random_state=0) + X_resampled, y_resampled = smote.fit_resample(X, y) + + assert X_resampled.dtype == X.dtype + + categorical_features = np.array([0, 1]) + if categorical_features.dtype == bool: + categorical_features = np.flatnonzero(categorical_features) + for cat_idx in categorical_features: + if sparse.issparse(X): + assert set(X[:, cat_idx].data) == set(X_resampled[:, cat_idx].data) + assert X[:, cat_idx].dtype == X_resampled[:, cat_idx].dtype + else: + assert set(X[:, cat_idx]) == set(X_resampled[:, cat_idx]) + assert X[:, cat_idx].dtype == X_resampled[:, cat_idx].dtype + + +# part of the common test which apply to SMOTE-N even if it is not default +# constructible +def test_smoten_check_target_type(): + X, _ = data_heterogneous_unordered() + y = np.linspace(0, 1, 30) + smote = SMOTEN(random_state=0) + with pytest.raises(ValueError, match="Unknown label type: 'continuous'"): + smote.fit_resample(X, y) + rng = np.random.RandomState(42) + y = rng.randint(2, size=(20, 3)) + with pytest.raises(ValueError, match="'y' should encode the multiclass"): + smote.fit_resample(X, y) + + +def test_smoten_samplers_one_label(): + X, _ = data_heterogneous_unordered() + y = np.zeros(30) + smote = SMOTEN(random_state=0) + with pytest.raises(ValueError, match='needs to have more than 1 class'): + smote.fit(X, y) + + +def test_smoten_fit(): + X, y = data_heterogneous_unordered() + smote = SMOTEN(random_state=0) + smote.fit_resample(X, y) + assert hasattr(smote, 'sampling_strategy_'), \ + "No fitted attribute sampling_strategy_" + + +def test_smoten_fit_resample(): + X, y = data_heterogneous_unordered() + target_stats = Counter(y) + smote = SMOTEN(random_state=0) + X_res, y_res = smote.fit_resample(X, y) + n_samples = max(target_stats.values()) + assert all(value >= n_samples for value in Counter(y_res).values()) + + +def test_smoten_fit_resample_sampling_strategy(): + X, y = data_heterogneous_unordered_multiclass() + expected_stat = Counter(y)[1] + smote = SMOTEN(random_state=0) + sampling_strategy = {2: 25, 0: 25} + smote.set_params(sampling_strategy=sampling_strategy) + X_res, y_res = smote.fit_resample(X, y) + assert Counter(y_res)[1] == expected_stat + + +def test_smoten_pandas(): + pd = pytest.importorskip("pandas") + # Check that the samplers handle pandas dataframe and pandas series + X, y = data_heterogneous_unordered_multiclass() + X_pd = pd.DataFrame(X) + smote = SMOTEN(random_state=0) + X_res_pd, y_res_pd = smote.fit_resample(X_pd, y) + X_res, y_res = smote.fit_resample(X, y) + assert X_res_pd.tolist() == X_res.tolist() + assert_allclose(y_res_pd, y_res) + + +def test_smoten_preserve_dtype(): + X, y = make_classification(n_samples=50, n_classes=3, n_informative=4, + weights=[0.2, 0.3, 0.5], random_state=0) + # Cast X and y to not default dtype + X = X.astype(np.float32) + y = y.astype(np.int32) + smote = SMOTEN(random_state=0) + X_res, y_res = smote.fit_resample(X, y) + assert X.dtype == X_res.dtype, "X dtype is not preserved" + assert y.dtype == y_res.dtype, "y dtype is not preserved" From e17ba76ad9c39a96b34c8ec80cedebe3b156ebc7 Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 19:44:12 +0200 Subject: [PATCH 2/8] Fix PEP issues --- imblearn/over_sampling/_smote.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 85841fb02..b36c4d111 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1214,10 +1214,10 @@ class SMOTEN(SMOTE): def __init__(self, sampling_strategy='auto', random_state=None, k_neighbors=5, n_jobs=1): super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy, - random_state=random_state, - k_neighbors=k_neighbors, - ratio=None, - n_jobs=n_jobs) + random_state=random_state, + k_neighbors=k_neighbors, + ratio=None, + n_jobs=n_jobs) @staticmethod def _check_X_y(X, y): @@ -1261,7 +1261,7 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step): """ rng = check_random_state(self.random_state) sample = super(SMOTEN, self)._generate_sample(X, nn_data, nn_num, - row, col, step) + row, col, step) # To avoid conversion and since there is only few samples used, we # convert those samples to dense array. sample = (sample.toarray().squeeze() From 5f1ca14a3537939262fea7edcd54e1383704deeb Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 21:04:39 +0200 Subject: [PATCH 3/8] Fix failing tests --- imblearn/over_sampling/_smote.py | 11 ++++++----- imblearn/utils/estimator_checks.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index b36c4d111..97fdce78a 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1197,27 +1197,28 @@ class SMOTEN(SMOTE): >>> from imblearn.over_sampling import SMOTEN >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, - ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) + ... n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=10) >>> print('Original dataset shape (%s, %s)' % X.shape) Original dataset shape (1000, 20) >>> print('Original dataset samples per class {}'.format(Counter(y))) Original dataset samples per class Counter({1: 900, 0: 100}) >>> # simulate the 2 last columns to be categorical features - >>> X[:, ] = RandomState(10).randint(0, 4, size=(1000, 2)) - >>> sm = SMOTEN(random_state=42, categorical_features=[18, 19]) + >>> X[:, ] = RandomState(10).randint(0, 4, size=(1000, 5)) + >>> sm = SMOTEN(random_state=42) >>> X_res, y_res = sm.fit_resample(X, y) >>> print('Resampled dataset samples per class {}'.format(Counter(y_res))) Resampled dataset samples per class Counter({0: 900, 1: 900}) """ - def __init__(self, sampling_strategy='auto', + def __init__(self, sampling_strategy='auto', kind='regular', random_state=None, k_neighbors=5, n_jobs=1): super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy, random_state=random_state, k_neighbors=k_neighbors, ratio=None, - n_jobs=n_jobs) + n_jobs=n_jobs, + kind=kind) @staticmethod def _check_X_y(X, y): diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 7d08f3313..70e0cfe45 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -33,8 +33,8 @@ from imblearn.over_sampling import SMOTE from imblearn.under_sampling import NearMiss, ClusterCentroids -DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE'] -SUPPORT_STRING = ['RandomUnderSampler', 'RandomOverSampler'] +DONT_SUPPORT_RATIO = ['SVMSMOTE', 'BorderlineSMOTE', 'SMOTEN'] +SUPPORT_STRING = ['SMOTEN', 'RandomUnderSampler', 'RandomOverSampler'] HAVE_SAMPLE_INDICES = [ 'RandomOverSampler', 'RandomUnderSampler', 'InstanceHardnessThreshold', 'NearMiss', 'TomekLinks', 'EditedNearestNeighbours', From c2fc4dac25563ec2424a1ccabdb624b91a6b415f Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 21:23:01 +0200 Subject: [PATCH 4/8] Fix doctest error --- imblearn/over_sampling/_smote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 97fdce78a..203c106e7 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1199,7 +1199,7 @@ class SMOTEN(SMOTE): ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=10) >>> print('Original dataset shape (%s, %s)' % X.shape) - Original dataset shape (1000, 20) + Original dataset shape (1000, 5) >>> print('Original dataset samples per class {}'.format(Counter(y))) Original dataset samples per class Counter({1: 900, 0: 100}) >>> # simulate the 2 last columns to be categorical features From 9db10719ef0f89d8de06a2f7722c1a4b8ffb8946 Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 21:52:46 +0200 Subject: [PATCH 5/8] Change order of counter keys in docstring --- imblearn/over_sampling/_smote.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 203c106e7..340a585d2 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1207,7 +1207,7 @@ class SMOTEN(SMOTE): >>> sm = SMOTEN(random_state=42) >>> X_res, y_res = sm.fit_resample(X, y) >>> print('Resampled dataset samples per class {}'.format(Counter(y_res))) - Resampled dataset samples per class Counter({0: 900, 1: 900}) + Resampled dataset samples per class Counter({1: 900, 0: 900}) """ From c2fe8c2a39620153cedd6ff2e5810bc1699b4a2f Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Sun, 5 May 2019 22:43:00 +0200 Subject: [PATCH 6/8] Rephrase docstring --- imblearn/over_sampling/_smote.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 340a585d2..c1b08b240 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1096,8 +1096,8 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step): # sampling_strategy=BaseOverSampler._sampling_strategy_docstring, # random_state=_random_state_docstring) class SMOTEN(SMOTE): - """Synthetic Minority Over-sampling Technique for Nominal - (SMOTE-NC). + """Synthetic Minority Over-sampling Technique for Nominal data + (SMOTE-N). Unlike :class:`SMOTE`, SMOTE-N operates on datasets containing categorical features. @@ -1200,14 +1200,13 @@ class SMOTEN(SMOTE): ... n_features=5, n_clusters_per_class=1, n_samples=1000, random_state=10) >>> print('Original dataset shape (%s, %s)' % X.shape) Original dataset shape (1000, 5) - >>> print('Original dataset samples per class {}'.format(Counter(y))) - Original dataset samples per class Counter({1: 900, 0: 100}) - >>> # simulate the 2 last columns to be categorical features + >>> print('Original dataset samples in class 0: {}'.format(sum(y == 0))) + Original dataset samples in class 0: 100 >>> X[:, ] = RandomState(10).randint(0, 4, size=(1000, 5)) >>> sm = SMOTEN(random_state=42) >>> X_res, y_res = sm.fit_resample(X, y) - >>> print('Resampled dataset samples per class {}'.format(Counter(y_res))) - Resampled dataset samples per class Counter({1: 900, 0: 900}) + >>> print('Resampled dataset samples in class 0: {}'.format(sum(y_res == 0))) + Resampled dataset samples in class 0: 900 """ From bc9622370dbf53665b1865c79537c0cd90285921 Mon Sep 17 00:00:00 2001 From: Thomas Kluiters Date: Tue, 7 May 2019 16:32:44 +0200 Subject: [PATCH 7/8] Refactor SMOTEN and SMOTENC to be more unified --- imblearn/over_sampling/_smote.py | 90 ++++++++++---------------------- 1 file changed, 27 insertions(+), 63 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index c1b08b240..9f388c597 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -950,11 +950,13 @@ class SMOTENC(SMOTE): """ - def __init__(self, categorical_features, sampling_strategy='auto', + def __init__(self, categorical_features, sampling_strategy='auto', kind='regular', random_state=None, k_neighbors=5, n_jobs=1): super(SMOTENC, self).__init__(sampling_strategy=sampling_strategy, random_state=random_state, k_neighbors=k_neighbors, + n_jobs=n_jobs, + kind=kind, ratio=None) self.categorical_features = categorical_features @@ -986,6 +988,15 @@ def _fit_resample(self, X, y): self.n_features_ = X.shape[1] self._validate_estimator() + X_encoded = self._encode(X, y) + + X_resampled, y_resampled = super(SMOTENC, self)._fit_resample( + X_encoded, y) + X_resampled = self._decode(X, X_resampled) + + return X_resampled, y_resampled + + def _encode(self, X, y): # compute the median of the standard deviation of the minority class target_stats = Counter(y) class_minority = min(target_stats, key=target_stats.get) @@ -1015,18 +1026,15 @@ def _fit_resample(self, X, y): X_ohe = self.ohe_.fit_transform( X_categorical.toarray() if sparse.issparse(X_categorical) else X_categorical) - # we can replace the 1 entries of the categorical features with the # median of the standard deviation. It will ensure that whenever # distance is computed between 2 samples, the difference will be equal # to the median of the standard deviation as in the original paper. X_ohe.data = (np.ones_like(X_ohe.data, dtype=X_ohe.dtype) * self.median_std_ / 2) - X_encoded = sparse.hstack((X_continuous, X_ohe), format='csr') - - X_resampled, y_resampled = super(SMOTENC, self)._fit_resample( - X_encoded, y) + return sparse.hstack((X_continuous, X_ohe), format='csr') + def _decode(self, X, X_resampled): # reverse the encoding of the categorical features X_res_cat = X_resampled[:, self.continuous_features_.size:] X_res_cat.data = np.ones_like(X_res_cat.data) @@ -1055,8 +1063,7 @@ def _fit_resample(self, X, y): X_resampled.indices = col_indices else: X_resampled = X_resampled[:, indices_reordered] - - return X_resampled, y_resampled + return X_resampled def _generate_sample(self, X, nn_data, nn_num, row, col, step): """Generate a synthetic sample with an additional steps for the @@ -1095,7 +1102,7 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step): # @Substitution( # sampling_strategy=BaseOverSampler._sampling_strategy_docstring, # random_state=_random_state_docstring) -class SMOTEN(SMOTE): +class SMOTEN(SMOTENC): """Synthetic Minority Over-sampling Technique for Nominal data (SMOTE-N). @@ -1212,73 +1219,30 @@ class SMOTEN(SMOTE): def __init__(self, sampling_strategy='auto', kind='regular', random_state=None, k_neighbors=5, n_jobs=1): - super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy, + super(SMOTEN, self).__init__(categorical_features=[], + sampling_strategy=sampling_strategy, random_state=random_state, k_neighbors=k_neighbors, - ratio=None, n_jobs=n_jobs, kind=kind) - @staticmethod - def _check_X_y(X, y): - """Overwrite the checking to let pass some string for categorical - features. - """ - y, binarize_y = check_target_type(y, indicate_one_vs_all=True) - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None) - return X, y, binarize_y - def _validate_estimator(self): + self.categorical_features = np.asarray(range(self.n_features_)) + self.continuous_features_ = np.asarray([]) super(SMOTEN, self)._validate_estimator() - def _fit_resample(self, X, y): - self.n_features_ = X.shape[1] - self._validate_estimator() + def _decode(self, X, X_resampled): + X_unstacked = self.ohe_.inverse_transform(X_resampled) + if sparse.issparse(X): + X_unstacked = sparse.csr_matrix(X_unstacked) + return X_unstacked + def _encode(self, X, y): self.ohe_ = OneHotEncoder(sparse=True, handle_unknown='ignore', dtype=np.float64) # the input of the OneHotEncoder needs to be dense - X_ohe = self.ohe_.fit_transform( + return self.ohe_.fit_transform( X.toarray() if sparse.issparse(X) else X) - X_resampled, y_resampled = super(SMOTEN, self)._fit_resample( - X_ohe, y) - X_resampled = self.ohe_.inverse_transform(X_resampled) - - if sparse.issparse(X): - X_resampled = sparse.csr_matrix(X_resampled) - - return X_resampled, y_resampled - - def _generate_sample(self, X, nn_data, nn_num, row, col, step): - """Generate a synthetic sample with an additional steps for the - categorical features. - Each new sample is generated the same way than in SMOTE. However, the - categorical features are mapped to the most frequent nearest neighbors - of the majority class. - """ - rng = check_random_state(self.random_state) - sample = super(SMOTEN, self)._generate_sample(X, nn_data, nn_num, - row, col, step) - # To avoid conversion and since there is only few samples used, we - # convert those samples to dense array. - sample = (sample.toarray().squeeze() - if sparse.issparse(sample) else sample) - all_neighbors = nn_data[nn_num[row]] - all_neighbors = (all_neighbors.toarray() - if sparse.issparse(all_neighbors) else all_neighbors) - - categories_size = [cat.size for cat in self.ohe_.categories_] - - for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1], - np.cumsum(categories_size)[1:]): - col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0) - # tie breaking argmax - col_sel = rng.choice(np.flatnonzero( - np.isclose(col_max, col_max.max()))) - sample[start_idx:end_idx] = 0 - sample[start_idx + col_sel] = 1 - - return sparse.csr_matrix(sample) if sparse.issparse(X) else sample From c7e7036d48684cfcef11f91a4c37779acf9dbe82 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 7 May 2019 22:39:01 +0200 Subject: [PATCH 8/8] Add documentation for SMOTEN --- doc/over_sampling.rst | 9 +++++++++ doc/whats_new/v0.5.rst | 3 +++ 2 files changed, 12 insertions(+) diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 448f4a15a..004e63a71 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -198,6 +198,15 @@ Therefore, it can be seen that the samples generated in the first and last columns are belonging to the same categories originally presented without any other extra interpolation. +Furthermore, if the dataset solely consists of categorical features one may use the :class:`SMOTEN` class. This class generates samples in an identical fashion to :class:`SMOTENC` - however - only categorical features are permitted. Each feature is treated as a categorical feature and therefore it is not advised to use `SMOTEN` for datasets that contain both categorical and continious features:: + + >>> from imblearn.over_sampling import SMOTEN + >>> smote_n = SMOTEN(random_state=0) + >>> X[:, 1] = rng.randint(2, size=n_samples) + >>> X_resampled, y_resampled = smote_n.fit_resample(X, y) + >>> print(sorted(Counter(y_resampled).items())) + [(0, 30), (1, 30)] + .. topic:: References .. [HWB2005] H. Han, W. Wen-Yuan, M. Bing-Huan, "Borderline-SMOTE: a new diff --git a/doc/whats_new/v0.5.rst b/doc/whats_new/v0.5.rst index 2b892e6c1..28192c8af 100644 --- a/doc/whats_new/v0.5.rst +++ b/doc/whats_new/v0.5.rst @@ -27,6 +27,9 @@ Enhancement and issue template showing how to print system and dependency information from the command line. :issue:`557` by :user:`Alexander L. Hayes `. +- Add :class:`SMOTEN`. Add ability to use SMOTE on pure categorical features. + by :user:`Thomas Kluiters