From 836fabda64630138c456997a2df2e0e145b8a4c1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 19:52:03 +0100 Subject: [PATCH 01/11] FEA implement SMOTEN --- imblearn/over_sampling/__init__.py | 2 + imblearn/over_sampling/_smote.py | 68 +++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py index bd20b76ea..a959cbb43 100644 --- a/imblearn/over_sampling/__init__.py +++ b/imblearn/over_sampling/__init__.py @@ -10,6 +10,7 @@ from ._smote import KMeansSMOTE from ._smote import SVMSMOTE from ._smote import SMOTENC +from ._smote import SMOTEN __all__ = [ "ADASYN", @@ -19,4 +20,5 @@ "BorderlineSMOTE", "SVMSMOTE", "SMOTENC", + "SMOTEN", ] diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index ea66c7fec..2118f84b6 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -11,11 +11,12 @@ import numpy as np from scipy import sparse +from scipy import stats from sklearn.base import clone from sklearn.cluster import MiniBatchKMeans from sklearn.metrics import pairwise_distances -from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder from sklearn.svm import SVC from sklearn.utils import check_random_state from sklearn.utils import _safe_indexing @@ -25,6 +26,7 @@ from .base import BaseOverSampler from ..exceptions import raise_isinstance_error +from ..metrics.pairwise import ValueDifferenceMetric from ..utils import check_neighbors_object from ..utils import check_target_type from ..utils import Substitution @@ -1293,3 +1295,67 @@ def _fit_resample(self, X, y): y_resampled = np.hstack((y_resampled, y_new)) return X_resampled, y_resampled + + +class SMOTEN(SMOTE): + def _check_X_y(self, X, y): + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) + X, y = self._validate_data( + X, y, reset=True, dtype=None, accept_sparse=["csr", "csc"] + ) + return X, y, binarize_y + + def _validate_estimator(self): + super()._validate_estimator() + self.nn_k_.set_params(metric="precomputed") + + def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples): + random_state = check_random_state(self.random_state) + # generate sample indices that will be used to generate new samples + samples_indices = random_state.choice( + np.arange(X_class.shape[0]), size=n_samples, replace=True + ) + X_new = np.empty(shape=(n_samples, X_class.shape[1]), dtype=X_class.dtype) + for idx, sample_idx in enumerate(samples_indices): + X_new[idx, :] = stats.mode(X_class[nn_indices[sample_idx]], axis=0).mode + y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype) + return X_new, y_new + + def _fit_resample(self, X, y): + self._validate_estimator() + + X_resampled = [X.copy()] + y_resampled = [y.copy()] + + encoder = OrdinalEncoder(dtype=np.int32) + X_encoded = encoder.fit_transform(X) + + vdm = ValueDifferenceMetric( + n_categories=[len(cat) for cat in encoder.categories_] + ).fit(X_encoded, y) + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = _safe_indexing(X_encoded, target_class_indices) + + X_class_dist = vdm.pairwise(X_class) + self.nn_k_.fit(X_class_dist) + # should countain the point itself + nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False) + X_new, y_new = self._make_samples( + X_class, class_sample, y.dtype, 
nn_indices, n_samples + ) + + X_new = encoder.inverse_transform(X_new) + X_resampled.append(X_new) + y_resampled.append(y_new) + + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) + + return X_resampled, y_resampled From 049dde9d8ed286c13608a8cfbbe276b51738fa39 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 19:52:31 +0100 Subject: [PATCH 02/11] TST add first basic test --- imblearn/over_sampling/tests/test_smoten.py | 30 +++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 imblearn/over_sampling/tests/test_smoten.py diff --git a/imblearn/over_sampling/tests/test_smoten.py b/imblearn/over_sampling/tests/test_smoten.py new file mode 100644 index 000000000..3dc0d3709 --- /dev/null +++ b/imblearn/over_sampling/tests/test_smoten.py @@ -0,0 +1,30 @@ +from collections import Counter + +import numpy as np +import pytest + +from imblearn.over_sampling import SMOTEN + + +@pytest.fixture +def data(): + rng = np.random.RandomState(0) + + feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30 + feature_2 = ["A"] * 40 + ["B"] * 20 + feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10 + X = np.array([feature_1, feature_2, feature_3], dtype=object).T + rng.shuffle(X) + y = np.array([0] * 20 + [1] * 40, dtype=np.int32) + y_labels = np.array(["not apple", "apple"], dtype=object) + y = y_labels[y] + return X, y + + +def test_smoten(data): + X, y = data + print(X, y) + sampler = SMOTEN(random_state=0) + X_res, y_res = sampler.fit_resample(X, y) + print(X_res, y_res) + print(Counter(y_res)) From fbac276bf4dc7176a3bfc6b721ea6e92e6588964 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 20:26:23 +0100 Subject: [PATCH 03/11] add doc --- imblearn/over_sampling/_smote.py | 77 ++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 2118f84b6..e150eea88 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1297,7 +1297,75 @@ def _fit_resample(self, X, y): return X_resampled, y_resampled +@Substitution( + sampling_strategy=BaseOverSampler._sampling_strategy_docstring, + n_jobs=_n_jobs_docstring, + random_state=_random_state_docstring, +) class SMOTEN(SMOTE): + """Perform SMOTE over-sampling for nominal categorical features only. + + This method is refered as SMOTEN in [1]_. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + {sampling_strategy} + + {random_state} + + k_neighbors : int or object, default=5 + If ``int``, number of nearest neighbours to used to construct synthetic + samples. If object, an estimator that inherits from + :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to + find the k_neighbors. + + {n_jobs} + + See Also + -------- + SMOTE : Over-sample using SMOTE. + + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. + + SVMSMOTE : Over-sample using the SVM-SMOTE variant. + + ADASYN : Over-sample using ADASYN. + + KMeansSMOTE : Over-sample applying a clustering before to oversample using + SMOTE. + + Notes + ----- + See the original papers: [1]_ for more details. + + Supports multi-class resampling. A one-vs.-rest scheme is used as + originally proposed in [1]_. + + References + ---------- + .. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. 
P. Kegelmeyer, "SMOTE: + synthetic minority over-sampling technique," Journal of artificial + intelligence research, 321-357, 2002. + + Examples + -------- + >>> import numpy as np + >>> X = np.array(["A"] * 10 + ["B"] * 20 + ["C"] * 30, dtype=object).reshape(-1, 1) + >>> y = np.array([0] * 20 + [1] * 40, dtype=np.int32) + >>> from collections import Counter + >>> print(f"Original class counts: {{Counter(y)}}") + Original class counts: Counter({{1: 40, 0: 20}}) + >>> from imblearn.over_sampling import SMOTEN + >>> sampler = SMOTEN(random_state=0) + >>> X_res, y_res = sampler.fit_resample(X, y) + >>> print(f"Class counts after resampling {{Counter(y_res)}}") + Class counts after resampling Counter({{0: 40, 1: 40}}) + """ + def _check_X_y(self, X, y): y, binarize_y = check_target_type(y, indicate_one_vs_all=True) X, y = self._validate_data( @@ -1315,9 +1383,12 @@ def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples): samples_indices = random_state.choice( np.arange(X_class.shape[0]), size=n_samples, replace=True ) - X_new = np.empty(shape=(n_samples, X_class.shape[1]), dtype=X_class.dtype) - for idx, sample_idx in enumerate(samples_indices): - X_new[idx, :] = stats.mode(X_class[nn_indices[sample_idx]], axis=0).mode + # for each drawn samples, select its k-neighbors and generate a sample + # where for each feature individually, each category generated is the + # most common category + X_new = np.squeeze( + stats.mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1 + ) y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype) return X_new, y_new From 795b269e8e285f07a82f2d7ee923c1f7cd58cbaa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 20:31:02 +0100 Subject: [PATCH 04/11] DOC update other docstring --- imblearn/over_sampling/_adasyn.py | 9 +++++++++ imblearn/over_sampling/_random_over_sampler.py | 3 +++ imblearn/over_sampling/_smote.py | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index 768209f98..c0c535e2e 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -50,6 +50,15 @@ class ADASYN(BaseOverSampler): -------- SMOTE : Over-sample using SMOTE. + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + + SVMSMOTE : Over-sample using SVM-SMOTE variant. + + BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. + Notes ----- The implementation is based on [1]_. diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 1801e258f..40c4a61f6 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -76,6 +76,9 @@ class RandomOverSampler(BaseOverSampler): SMOTENC : Over-sample using SMOTE for continuous and categorical features. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. ADASYN : Over-sample using ADASYN. diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index e150eea88..ac5174afa 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -450,6 +450,9 @@ class SVMSMOTE(BaseSMOTE): SMOTENC : Over-sample using SMOTE for continuous and categorical features. 
+ SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + BorderlineSMOTE : Over-sample using Borderline-SMOTE. ADASYN : Over-sample using ADASYN. @@ -645,6 +648,9 @@ class SMOTE(BaseSMOTE): -------- SMOTENC : Over-sample using SMOTE for continuous and categorical features. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. SVMSMOTE : Over-sample using the SVM-SMOTE variant. @@ -813,6 +819,9 @@ class SMOTENC(SMOTE): -------- SMOTE : Over-sample using SMOTE. + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. @@ -1102,6 +1111,11 @@ class KMeansSMOTE(BaseSMOTE): -------- SMOTE : Over-sample using SMOTE. + SMOTENC : Over-sample using SMOTE for continuous and categorical features. + + SMOTEN : Over-sample using the SMOTE variable specifically for categorical + features only. + SVMSMOTE : Over-sample using SVM-SMOTE variant. BorderlineSMOTE : Over-sample using Borderline-SMOTE variant. From fee46f701d8f89a5894eb1ce43f0652da4f6981a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 20:40:57 +0100 Subject: [PATCH 05/11] iter --- imblearn/over_sampling/_adasyn.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index c0c535e2e..502c4093e 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -178,3 +178,8 @@ def _fit_resample(self, X, y): y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled + + def _more_tags(self): + return { + "X_types": ["2darray", "string"], + } From 3f20f39b3f12b3c0e951f88c75eb90b045557dec Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 21:27:53 +0100 Subject: [PATCH 06/11] iter --- imblearn/over_sampling/_adasyn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index 502c4093e..25bb95e8d 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -181,5 +181,5 @@ def _fit_resample(self, X, y): def _more_tags(self): return { - "X_types": ["2darray", "string"], + "X_types": ["2darray"], } From 8626040f5d2517de3c187c8ca32309cea4c44322 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 21:31:04 +0100 Subject: [PATCH 07/11] iter --- imblearn/over_sampling/_smote.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index ac5174afa..4252c5dd5 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1444,3 +1444,6 @@ def _fit_resample(self, X, y): y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled + + def _more_tags(self): + return {"X_types": ["2darray", "dataframe", "string"]} From 39dd844e481c351b1c9805749ac23c2aefc4335f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 22:53:40 +0100 Subject: [PATCH 08/11] iter --- imblearn/over_sampling/_smote.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 4252c5dd5..4cf9c526e 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -1381,13 +1381,19 @@ 
class SMOTEN(SMOTE): """ def _check_X_y(self, X, y): + """Check should accept strings and not sparse matrices.""" y, binarize_y = check_target_type(y, indicate_one_vs_all=True) X, y = self._validate_data( - X, y, reset=True, dtype=None, accept_sparse=["csr", "csc"] + X, + y, + reset=True, + dtype=None, + accept_sparse=False, ) return X, y, binarize_y def _validate_estimator(self): + """Force to use precomputed distance matrix.""" super()._validate_estimator() self.nn_k_.set_params(metric="precomputed") @@ -1427,7 +1433,8 @@ def _fit_resample(self, X, y): X_class_dist = vdm.pairwise(X_class) self.nn_k_.fit(X_class_dist) - # should countain the point itself + # the kneigbors search will include the sample itself which is + # expected from the original algorithm nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False) X_new, y_new = self._make_samples( X_class, class_sample, y.dtype, nn_indices, n_samples @@ -1437,10 +1444,7 @@ def _fit_resample(self, X, y): X_resampled.append(X_new) y_resampled.append(y_new) - if sparse.issparse(X): - X_resampled = sparse.vstack(X_resampled, format=X.format) - else: - X_resampled = np.vstack(X_resampled) + X_resampled = np.vstack(X_resampled) y_resampled = np.hstack(y_resampled) return X_resampled, y_resampled From 516b33c25efcf717646a25321b140fe001852b31 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 23:02:34 +0100 Subject: [PATCH 09/11] DOC add docstring to API --- doc/references/over_sampling.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/references/over_sampling.rst b/doc/references/over_sampling.rst index c3a932d22..1dba0cd2f 100644 --- a/doc/references/over_sampling.rst +++ b/doc/references/over_sampling.rst @@ -27,6 +27,7 @@ SMOTE algorithms SMOTE SMOTENC + SMOTEN ADASYN BorderlineSMOTE KMeansSMOTE From 57f839f07d12cfae541c2a8cf34bd7bf24bbd1e4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 23:53:02 +0100 Subject: [PATCH 10/11] DOC add whats new --- doc/over_sampling.rst | 38 +++++++++++++++++++++ doc/whats_new/v0.8.rst | 4 +++ imblearn/over_sampling/tests/test_smoten.py | 34 +++++++++++++++--- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst index 9ad88097b..7540252d7 100644 --- a/doc/over_sampling.rst +++ b/doc/over_sampling.rst @@ -211,6 +211,44 @@ Therefore, it can be seen that the samples generated in the first and last columns are belonging to the same categories originally presented without any other extra interpolation. +However, :class:`SMOTENC` is working with data composed of categorical data +only. WHen data are made of only nominal categorical data, one can use the +:class:`SMOTEN` variant :cite:`chawla2002smote`. The algorithm changes in +two ways: + +* the nearest neighbors search does not rely on the Euclidean distance. Indeed, + the value difference metric (VDM) also implemented in the class + :class:`~imblearn.metrics.ValueDifferenceMetric` is used. +* the new sample generation is based on majority vote per feature to generate + the most common category seen in the neighbors samples. + +Let's take the following example:: + + >>> import numpy as np + >>> X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, + ... dtype=object).reshape(-1, 1) + >>> y = np.array(["apple"] * 5 + ["not apple"] * 3 + ["apple"] * 7 + + ... ["not apple"] * 5 + ["apple"] * 2, dtype=object) + +We generate a dataset associating a color to being an apple or not an apple. 
+We strongly associated "green" and "red" to being an apple. The minority class +being "not apple", we expect new data generated belonging to the category +"blue":: + + >>> from imblearn.over_sampling import SMOTEN + >>> sampler = SMOTEN(random_state=0) + >>> X_res, y_res = sampler.fit_resample(X, y) + >>> X_res[y.size:] + array([['blue'], + ['blue'], + ['blue'], + ['blue'], + ['blue'], + ['blue']], dtype=object) + >>> y_res[y.size:] + array(['not apple', 'not apple', 'not apple', 'not apple', 'not apple', + 'not apple'], dtype=object) + Mathematical formulation ======================== diff --git a/doc/whats_new/v0.8.rst b/doc/whats_new/v0.8.rst index d5f2d8969..494c88166 100644 --- a/doc/whats_new/v0.8.rst +++ b/doc/whats_new/v0.8.rst @@ -19,6 +19,10 @@ New features compute pairwise distances between samples containing only nominal values. :pr:`796` by :user:`Guillaume Lemaitre `. +- Add the class :class:`imblearn.over_sampling.SMOTEN` to over-sample data + only containing nominal categorical features. + :pr:`802` by :user:`Guillaume Lemaitre `. + Enhancements ............ diff --git a/imblearn/over_sampling/tests/test_smoten.py b/imblearn/over_sampling/tests/test_smoten.py index 3dc0d3709..774ad9963 100644 --- a/imblearn/over_sampling/tests/test_smoten.py +++ b/imblearn/over_sampling/tests/test_smoten.py @@ -1,5 +1,3 @@ -from collections import Counter - import numpy as np import pytest @@ -22,9 +20,35 @@ def data(): def test_smoten(data): + # overall check for SMOTEN X, y = data - print(X, y) sampler = SMOTEN(random_state=0) X_res, y_res = sampler.fit_resample(X, y) - print(X_res, y_res) - print(Counter(y_res)) + + assert X_res.shape == (80, 3) + assert y_res.shape == (80,) + + +def test_smoten_resampling(): + # check if the SMOTEN resample data as expected + # we generate data such that "not apple" will be the minority class and + # samples from this class will be generated. We will force the "blue" + # category to be associated with this class. Therefore, the new generated + # samples should as well be from the "blue" category. + X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object).reshape( + -1, 1 + ) + y = np.array( + ["apple"] * 5 + + ["not apple"] * 3 + + ["apple"] * 7 + + ["not apple"] * 5 + + ["apple"] * 2, + dtype=object, + ) + sampler = SMOTEN(random_state=0) + X_res, y_res = sampler.fit_resample(X, y) + + X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :] + np.testing.assert_array_equal(X_generated, "blue") + np.testing.assert_array_equal(y_generated, "not apple") From a66bfaa729ed9a32296f84d5c28f50b39932f83c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 15 Feb 2021 23:54:45 +0100 Subject: [PATCH 11/11] DOC update whats new --- README.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index 476e3c547..f5ad91d21 100644 --- a/README.rst +++ b/README.rst @@ -168,11 +168,12 @@ Below is a list of the methods currently implemented in this module. 1. Random minority over-sampling with replacement 2. SMOTE - Synthetic Minority Over-sampling Technique [8]_ 3. SMOTENC - SMOTE for Nominal Continuous [8]_ - 4. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_ - 5. SVM SMOTE - Support Vectors SMOTE [10]_ - 6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_ - 7. KMeans-SMOTE [17]_ - 8. ROSE - Random OverSampling Examples [19]_ + 4. SMOTEN - SMMOTE for Nominal only [8]_ + 5. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_ + 6. 
SVM SMOTE - Support Vectors SMOTE [10]_ + 7. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_ + 8. KMeans-SMOTE [17]_ + 9. ROSE - Random OverSampling Examples [19]_ * Over-sampling followed by under-sampling 1. SMOTE + Tomek links [12]_
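
The patch series reads most naturally as a four-step pipeline: encode the nominal
features ordinally, measure closeness with the Value Difference Metric (the
``ValueDifferenceMetric`` class referenced in the whats_new entry for the companion
PR), run the k-NN search on the precomputed distance matrix, and build each new
sample as the per-feature mode of a seed sample's neighbourhood. The snippet below
strings those steps together outside the estimator as a reading aid for the diffs
above. It is only a sketch of the idea, not the code added by the patches, and the
helper names (``per_feature_mode``, ``smoten_once``) are invented for illustration::

    import numpy as np
    from sklearn.neighbors import NearestNeighbors
    from sklearn.preprocessing import OrdinalEncoder
    from imblearn.metrics.pairwise import ValueDifferenceMetric

    def per_feature_mode(neighbours):
        """Most frequent (ordinal-encoded) category in each column."""
        return np.array([np.bincount(col).argmax() for col in neighbours.T])

    def smoten_once(X, y, minority_class, n_samples, k_neighbors=5, seed=0):
        rng = np.random.RandomState(seed)

        # 1. encode the nominal features as integers
        encoder = OrdinalEncoder(dtype=np.int32)
        X_encoded = encoder.fit_transform(X)

        # 2. fit the Value Difference Metric on the full dataset
        vdm = ValueDifferenceMetric(
            n_categories=[len(cats) for cats in encoder.categories_]
        ).fit(X_encoded, y)

        # 3. neighbour search inside the class being resampled, on the
        #    precomputed VDM distances; the search includes the sample itself
        X_class = X_encoded[y == minority_class]
        dist = vdm.pairwise(X_class)
        nn = NearestNeighbors(n_neighbors=k_neighbors + 1, metric="precomputed")
        nn.fit(dist)
        nn_indices = nn.kneighbors(dist, return_distance=False)

        # 4. each new sample is the per-feature mode of a random seed's
        #    neighbourhood, decoded back to the original categories
        seeds = rng.choice(X_class.shape[0], size=n_samples, replace=True)
        X_new = np.vstack(
            [per_feature_mode(X_class[nn_indices[s]]) for s in seeds]
        )
        return encoder.inverse_transform(X_new), np.full(n_samples, minority_class)

Fitting the VDM on the whole dataset while searching neighbours only within the
class being resampled mirrors the loop in the patch's ``_fit_resample``.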