diff --git a/README.rst b/README.rst index 476e3c547..f5ad91d21 100644 --- a/README.rst +++ b/README.rst @@ -168,11 +168,12 @@ Below is a list of the methods currently implemented in this module. 1. Random minority over-sampling with replacement 2. SMOTE - Synthetic Minority Over-sampling Technique [8]_ 3. SMOTENC - SMOTE for Nominal Continuous [8]_ - 4. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_ - 5. SVM SMOTE - Support Vectors SMOTE [10]_ - 6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_ - 7. KMeans-SMOTE [17]_ - 8. ROSE - Random OverSampling Examples [19]_ + 4. SMOTEN - SMMOTE for Nominal only [8]_ + 5. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_ + 6. SVM SMOTE - Support Vectors SMOTE [10]_ + 7. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_ + 8. KMeans-SMOTE [17]_ + 9. ROSE - Random OverSampling Examples [19]_ * Over-sampling followed by under-sampling 1. SMOTE + Tomek links [12]_ diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 9ad88097b..7540252d7 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -211,6 +211,44 @@ Therefore, it can be seen that the samples generated in the first and last columns are belonging to the same categories originally presented without any other extra interpolation. +However, :class:`SMOTENC` is working with data composed of categorical data +only. WHen data are made of only nominal categorical data, one can use the +:class:`SMOTEN` variant :cite:`chawla2002smote`. The algorithm changes in +two ways: + +* the nearest neighbors search does not rely on the Euclidean distance. Indeed, + the value difference metric (VDM) also implemented in the class + :class:`~imblearn.metrics.ValueDifferenceMetric` is used. +* the new sample generation is based on majority vote per feature to generate + the most common category seen in the neighbors samples. + +Let's take the following example:: + + >>> import numpy as np + >>> X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, + ... dtype=object).reshape(-1, 1) + >>> y = np.array(["apple"] * 5 + ["not apple"] * 3 + ["apple"] * 7 + + ... ["not apple"] * 5 + ["apple"] * 2, dtype=object) + +We generate a dataset associating a color to being an apple or not an apple. +We strongly associated "green" and "red" to being an apple. The minority class +being "not apple", we expect new data generated belonging to the category +"blue":: + + >>> from imblearn.over_sampling import SMOTEN + >>> sampler = SMOTEN(random_state=0) + >>> X_res, y_res = sampler.fit_resample(X, y) + >>> X_res[y.size:] + array([['blue'], + ['blue'], + ['blue'], + ['blue'], + ['blue'], + ['blue']], dtype=object) + >>> y_res[y.size:] + array(['not apple', 'not apple', 'not apple', 'not apple', 'not apple', + 'not apple'], dtype=object) + Mathematical formulation ======================== diff --git a/doc/references/over_sampling.rst b/doc/references/over_sampling.rst index c3a932d22..1dba0cd2f 100644 --- a/doc/references/over_sampling.rst +++ b/doc/references/over_sampling.rst @@ -27,6 +27,7 @@ SMOTE algorithms SMOTE SMOTENC + SMOTEN ADASYN BorderlineSMOTE KMeansSMOTE diff --git a/doc/whats_new/v0.8.rst b/doc/whats_new/v0.8.rst index d5f2d8969..494c88166 100644 --- a/doc/whats_new/v0.8.rst +++ b/doc/whats_new/v0.8.rst @@ -19,6 +19,10 @@ New features compute pairwise distances between samples containing only nominal values. :pr:`796` by :user:`Guillaume Lemaitre `. +- Add the class :class:`imblearn.over_sampling.SMOTEN` to over-sample data + only containing nominal categorical features. + :pr:`802` by :user:`Guillaume Lemaitre `. + Enhancements ............ diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index bd20b76ea..a959cbb43 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -10,6 +10,7 @@ from ._smote import KMeansSMOTE from ._smote import SVMSMOTE from ._smote import SMOTENC +from ._smote import SMOTEN __all__ = [ "ADASYN", @@ -19,4 +20,5 @@ "BorderlineSMOTE", "SVMSMOTE", "SMOTENC", + "SMOTEN", ] diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index 768209f98..25bb95e8d 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -50,6 +50,15 @@ class ADASYN(BaseOverSampler): -------- SMOTE : Over-sample using SMOTE. + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. + Notes ----- The implementation is based on [1]_. @@ -169,3 +178,8 @@ def _fit_resample(self, X, y): y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled + + def _more_tags(self): + return { + "X_types": ["2darray"], + } diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 4d2795d7e..99168bbbd 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -76,6 +76,9 @@ class RandomOverSampler(BaseOverSampler): SMOTENC : Over-sample using SMOTE for continuous and categorical features. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. ADASYN : Over-sample using ADASYN. diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index cdc6483d0..5adb3f69b 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -11,11 +11,12 @@ import numpy as np from scipy import sparse +from scipy import stats from sklearn.base import clone from sklearn.cluster import MiniBatchKMeans from sklearn.metrics import pairwise_distances -from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder from sklearn.svm import SVC from sklearn.utils import check_random_state from sklearn.utils import _safe_indexing @@ -25,6 +26,7 @@ from .base import BaseOverSampler from ..exceptions import raise_isinstance_error +from ..metrics.pairwise import ValueDifferenceMetric from ..utils import check_neighbors_object from ..utils import check_target_type from ..utils import Substitution @@ -448,6 +450,9 @@ class SVMSMOTE(BaseSMOTE): SMOTENC : Over-sample using SMOTE for continuous and categorical features. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + BorderlineSMOTE : Over-sample using Borderline-SMOTE. ADASYN : Over-sample using ADASYN. @@ -643,6 +648,9 @@ class SMOTE(BaseSMOTE): -------- SMOTENC : Over-sample using SMOTE for continuous and categorical features. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. SVMSMOTE : Over-sample using the SVM-SMOTE variant. @@ -766,6 +774,9 @@ class SMOTENC(SMOTE): -------- SMOTE : Over-sample using SMOTE. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. @@ -1055,6 +1066,11 @@ class KMeansSMOTE(BaseSMOTE): -------- SMOTE : Over-sample using SMOTE. + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. @@ -1248,3 +1264,145 @@ def _fit_resample(self, X, y): y_resampled = np.hstack((y_resampled, y_new)) return X_resampled, y_resampled + + +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + n_jobs=_n_jobs_docstring, + random_state=_random_state_docstring, +) +class SMOTEN(SMOTE): + """Perform SMOTE over-sampling for nominal categorical features only. + + This method is refered as SMOTEN in [1]_. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + {sampling_strategy} + + {random_state} + + k_neighbors : int or object, default=5 + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. + + {n_jobs} + + See Also + -------- + SMOTE : Over-sample using SMOTE. + + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. + + SVMSMOTE : Over-sample using the SVM-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + KMeansSMOTE : Over-sample applying a clustering before to oversample using + SMOTE. + + Notes + ----- + See the original papers: [1]_ for more details. + + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + Examples + -------- + >>> import numpy as np + >>> X = np.array(["A"] * 10 + ["B"] * 20 + ["C"] * 30, dtype=object).reshape(-1, 1) + >>> y = np.array([0] * 20 + [1] * 40, dtype=np.int32) + >>> from collections import Counter + >>> print(f"Original class counts: {{Counter(y)}}") + Original class counts: Counter({{1: 40, 0: 20}}) + >>> from imblearn.over_sampling import SMOTEN + >>> sampler = SMOTEN(random_state=0) + >>> X_res, y_res = sampler.fit_resample(X, y) + >>> print(f"Class counts after resampling {{Counter(y_res)}}") + Class counts after resampling Counter({{0: 40, 1: 40}}) + """ + + def _check_X_y(self, X, y): + """Check should accept strings and not sparse matrices.""" + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) + X, y = self._validate_data( + X, + y, + reset=True, + dtype=None, + accept_sparse=False, + ) + return X, y, binarize_y + + def _validate_estimator(self): + """Force to use precomputed distance matrix.""" + super()._validate_estimator() + self.nn_k_.set_params(metric="precomputed") + + def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples): + random_state = check_random_state(self.random_state) + # generate sample indices that will be used to generate new samples + samples_indices = random_state.choice( + np.arange(X_class.shape[0]), size=n_samples, replace=True + ) + # for each drawn samples, select its k-neighbors and generate a sample + # where for each feature individually, each category generated is the + # most common category + X_new = np.squeeze( + stats.mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1 + ) + y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype) + return X_new, y_new + + def _fit_resample(self, X, y): + self._validate_estimator() + + X_resampled = [X.copy()] + y_resampled = [y.copy()] + + encoder = OrdinalEncoder(dtype=np.int32) + X_encoded = encoder.fit_transform(X) + + vdm = ValueDifferenceMetric( + n_categories=[len(cat) for cat in encoder.categories_] + ).fit(X_encoded, y) + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = _safe_indexing(X_encoded, target_class_indices) + + X_class_dist = vdm.pairwise(X_class) + self.nn_k_.fit(X_class_dist) + # the kneigbors search will include the sample itself which is + # expected from the original algorithm + nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False) + X_new, y_new = self._make_samples( + X_class, class_sample, y.dtype, nn_indices, n_samples + ) + + X_new = encoder.inverse_transform(X_new) + X_resampled.append(X_new) + y_resampled.append(y_new) + + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) + + return X_resampled, y_resampled + + def _more_tags(self): + return {"X_types": ["2darray", "dataframe", "string"]} diff --git a/imblearn/over_sampling/tests/test_smoten.py b/imblearn/over_sampling/tests/test_smoten.py new file mode 100644 index 000000000..774ad9963 --- /dev/null +++ b/imblearn/over_sampling/tests/test_smoten.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from imblearn.over_sampling import SMOTEN + + +@pytest.fixture +def data(): + rng = np.random.RandomState(0) + + feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30 + feature_2 = ["A"] * 40 + ["B"] * 20 + feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10 + X = np.array([feature_1, feature_2, feature_3], dtype=object).T + rng.shuffle(X) + y = np.array([0] * 20 + [1] * 40, dtype=np.int32) + y_labels = np.array(["not apple", "apple"], dtype=object) + y = y_labels[y] + return X, y + + +def test_smoten(data): + # overall check for SMOTEN + X, y = data + sampler = SMOTEN(random_state=0) + X_res, y_res = sampler.fit_resample(X, y) + + assert X_res.shape == (80, 3) + assert y_res.shape == (80,) + + +def test_smoten_resampling(): + # check if the SMOTEN resample data as expected + # we generate data such that "not apple" will be the minority class and + # samples from this class will be generated. We will force the "blue" + # category to be associated with this class. Therefore, the new generated + # samples should as well be from the "blue" category. + X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object).reshape( + -1, 1 + ) + y = np.array( + ["apple"] * 5 + + ["not apple"] * 3 + + ["apple"] * 7 + + ["not apple"] * 5 + + ["apple"] * 2, + dtype=object, + ) + sampler = SMOTEN(random_state=0) + X_res, y_res = sampler.fit_resample(X, y) + + X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :] + np.testing.assert_array_equal(X_generated, "blue") + np.testing.assert_array_equal(y_generated, "not apple")