diff --git a/README.rst b/README.rst
index 98561a41f..d86120e2e 100644
--- a/README.rst
+++ b/README.rst
@@ -155,6 +155,7 @@ Below is a list of the methods currently implemented in this module.
     5. SVM SMOTE - Support Vectors SMOTE [10]_
     6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
     7. KMeans-SMOTE [17]_
+    8. ROSE - Random OverSampling Examples [19]_
 
 * Over-sampling followed by under-sampling
     1. SMOTE + Tomek links [12]_
@@ -210,4 +211,6 @@ References:
 .. [17] : Felix Last, Georgios Douzas, Fernando Bacao, "Oversampling for Imbalanced Learning Based on K-Means and SMOTE"
 
-.. [18] : Seiffert, C., Khoshgoftaar, T. M., Van Hulse, J., & Napolitano, A. "RUSBoost: A hybrid approach to alleviating class imbalance." IEEE Transactions on Systems, Man, and Cybernetics-Part A: Systems and Humans 40.1 (2010): 185-197.
\ No newline at end of file
+.. [18] : Seiffert, C., Khoshgoftaar, T. M., Van Hulse, J., & Napolitano, A. "RUSBoost: A hybrid approach to alleviating class imbalance." IEEE Transactions on Systems, Man, and Cybernetics-Part A: Systems and Humans 40.1 (2010): 185-197.
+
+.. [19] : Menardi, G., & Torelli, N. "Training and assessing classification rules with imbalanced data." Data Mining and Knowledge Discovery 28.1 (2014): 92-122.
\ No newline at end of file
diff --git a/doc/api.rst b/doc/api.rst
index dc9ead953..f4c84687c 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -76,6 +76,7 @@ Prototype selection
    over_sampling.SMOTE
    over_sampling.SMOTENC
    over_sampling.SVMSMOTE
+   over_sampling.ROSE
 
 .. _combine_ref:
diff --git a/doc/bibtex/refs.bib b/doc/bibtex/refs.bib
index 2573abc63..c440ffa40 100644
--- a/doc/bibtex/refs.bib
+++ b/doc/bibtex/refs.bib
@@ -193,3 +193,18 @@ @article{smith2014instance
   year={2014},
   publisher={Springer}
 }
+
+@article{torelli2014rose,
+  author={Menardi, Giovanna and Torelli, Nicola},
+  title={Training and assessing classification rules with imbalanced data},
+  journal={Data Mining and Knowledge Discovery},
+  volume={28},
+  number={1},
+  pages={92--122},
+  year={2014},
+  publisher={Springer},
+  issn={1573-756X},
+  url={https://doi.org/10.1007/s10618-012-0295-5},
+  doi={10.1007/s10618-012-0295-5}
+}
\ No newline at end of file
diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index 58f0c2d58..f3419ca9a 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -198,6 +198,23 @@
 Therefore, it can be seen that the samples generated in the first and last
 columns are belonging to the same categories originally presented without any
 other extra interpolation.
 
+.. _rose:
+
+ROSE (Random Over-Sampling Examples)
+------------------------------------
+
+ROSE uses smoothed bootstrapping to draw artificial samples from the
+feature-space neighbourhood of the classes to resample, placing a
+multivariate Gaussian kernel around randomly selected samples. First,
+samples are selected at random, with replacement, from the original
+class. Then, the smoothed kernel density estimate is computed around
+the selected samples:
+
+.. math::
+
+   \hat{f}(x | y = Y_j) = \sum_{i=1}^{n_j} p_i \Pr(x | x_i)
+   = \sum_{i=1}^{n_j} \frac{1}{n_j} K_{H_j}(x - x_i)
+
+where :math:`n_j` is the number of samples of class :math:`Y_j`,
+:math:`p_i = 1/n_j`, and :math:`K_{H_j}` is a Gaussian kernel with
+smoothing matrix :math:`H_j`. Finally, the new samples are drawn from
+this estimated distribution.
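+
+A minimal usage sketch on a toy imbalanced dataset (a two-class variant
+of the class docstring example; the dataset and printed counts below are
+illustrative):
+
+>>> from collections import Counter
+>>> from sklearn.datasets import make_classification
+>>> from imblearn.over_sampling import ROSE
+>>> X, y = make_classification(n_classes=2, weights=[0.1, 0.9],
+...                            flip_y=0, n_samples=1000, random_state=10)
+>>> print(sorted(Counter(y).items()))
+[(0, 100), (1, 900)]
+>>> rose = ROSE(random_state=42)
+>>> X_res, y_res = rose.fit_resample(X, y)
+>>> print(sorted(Counter(y_res).items()))
+[(0, 900), (1, 900)]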
+
+
 
 Mathematical formulation
 ========================
diff --git a/doc/whats_new/v0.7.rst b/doc/whats_new/v0.7.rst
index 26f48e06c..eb931a2e8 100644
--- a/doc/whats_new/v0.7.rst
+++ b/doc/whats_new/v0.7.rst
@@ -63,6 +63,9 @@ Enhancements
 - Lazy import `keras` module when importing `imblearn.keras`
   :pr:`719` by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- Added Random Over-Sampling Examples (ROSE) class.
+  :pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>`.
+
 Deprecation
 ...........
diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index bd20b76ea..3be402135 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
 from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
+from ._rose import ROSE
 
 __all__ = [
     "ADASYN",
@@ -19,4 +20,5 @@
     "BorderlineSMOTE",
     "SVMSMOTE",
     "SMOTENC",
+    "ROSE",
 ]
diff --git a/imblearn/over_sampling/_rose.py b/imblearn/over_sampling/_rose.py
new file mode 100644
index 000000000..48c0a3ddf
--- /dev/null
+++ b/imblearn/over_sampling/_rose.py
@@ -0,0 +1,202 @@
+"""Class to perform over-sampling using ROSE."""
+
+import numpy as np
+from scipy import sparse
+from sklearn.utils import check_random_state
+
+from .base import BaseOverSampler
+from ..utils._validation import _deprecate_positional_args
+
+
+class ROSE(BaseOverSampler):
+    """Random Over-Sampling Examples (ROSE).
+
+    This object is an implementation of the ROSE algorithm.
+    It generates new samples by a smoothed bootstrap approach:
+    it takes a random subsample of the original data, places a
+    multivariate Gaussian kernel density estimate
+    :math:`\\hat{f}(x|y=Y_j)` with smoothing matrix :math:`H_j`
+    around it, and finally samples from this distribution.
+    Per-class shrink factors can be provided to set the bandwidth
+    of the Gaussian kernel.
+
+    Read more in the :ref:`User Guide <rose>`.
+
+    Parameters
+    ----------
+    sampling_strategy : float, str, dict or callable, default='auto'
+        Sampling information to resample the data set.
+
+        - When ``float``, it corresponds to the desired ratio of the number of
+          samples in the minority class over the number of samples in the
+          majority class after resampling. Therefore, the ratio is expressed as
+          :math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the
+          number of samples in the minority class after resampling and
+          :math:`N_{M}` is the number of samples in the majority class.
+
+          .. warning::
+             ``float`` is only available for **binary** classification. An
+             error is raised for multi-class classification.
+
+        - When ``str``, specify the class targeted by the resampling. The
+          number of samples in the different classes will be equalized.
+          Possible choices are:
+
+          ``'minority'``: resample only the minority class;
+
+          ``'not minority'``: resample all classes but the minority class;
+
+          ``'not majority'``: resample all classes but the majority class;
+
+          ``'all'``: resample all classes;
+
+          ``'auto'``: equivalent to ``'not majority'``.
+
+        - When ``dict``, the keys correspond to the targeted classes. The
+          values correspond to the desired number of samples for each targeted
+          class.
+
+        - When callable, function taking ``y`` and returns a ``dict``. The keys
+          correspond to the targeted classes. The values correspond to the
+          desired number of samples for each class.
+
+    shrink_factors : dict, default=None
+        Dictionary of ``{class: shrink_factor}`` items applied to the
+        Gaussian kernel of each class; it can be used to compress or
+        dilate the kernel. If ``None``, a shrink factor of 1 is used
+        for every class.
+
+    random_state : int, RandomState instance, default=None
+        Control the randomization of the algorithm.
+
+        - If int, ``random_state`` is the seed used by the random number
+          generator;
+        - If ``RandomState`` instance, random_state is the random number
+          generator;
+        - If ``None``, the random number generator is the ``RandomState``
+          instance used by ``np.random``.
+
+    n_jobs : int, default=None
+        Number of CPU cores used during the cross-validation loop.
+        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
+        ``-1`` means using all processors. See
+        `Glossary <https://scikit-learn.org/stable/glossary.html#term-n_jobs>`_
+        for more details.
+
+    See Also
+    --------
+    SMOTE : Over-sample using SMOTE.
+
+    Notes
+    -----
+    Supports multi-class resampling by sampling each targeted class
+    independently.
+
+    References
+    ----------
+    .. [1] N. Lunardon, G. Menardi, N. Torelli, "ROSE: A Package for Binary
+       Imbalanced Learning," R Journal, 6(1), 2014.
+
+    .. [2] G. Menardi, N. Torelli, "Training and assessing classification
+       rules with imbalanced data," Data Mining and Knowledge
+       Discovery, 28(1), pp.92-122, 2014.
+
+    Examples
+    --------
+    >>> from imblearn.over_sampling import ROSE
+    >>> from sklearn.datasets import make_classification
+    >>> from collections import Counter
+    >>> r = ROSE(shrink_factors={0: 1, 1: 0.5, 2: 0.7})
+    >>> X, y = make_classification(n_classes=3, class_sep=2,
+    ... weights=[0.1, 0.7, 0.2], n_informative=3, n_redundant=1, flip_y=0,
+    ... n_features=20, n_clusters_per_class=1, n_samples=2000, random_state=10)
+    >>> print('Original dataset shape %s' % Counter(y))
+    Original dataset shape Counter({1: 1400, 2: 400, 0: 200})
+    >>> X_res, y_res = r.fit_resample(X, y)
+    >>> print('Resampled dataset shape %s' % Counter(y_res))
+    Resampled dataset shape Counter({2: 1400, 1: 1400, 0: 1400})
+    """
+
+    @_deprecate_positional_args
+    def __init__(self, *, sampling_strategy="auto", shrink_factors=None,
+                 random_state=None, n_jobs=None):
+        super().__init__(sampling_strategy=sampling_strategy)
+        self.random_state = random_state
+        self.shrink_factors = shrink_factors
+        self.n_jobs = n_jobs
+
+    def _make_samples(self,
+                      X,
+                      class_indices,
+                      n_class_samples,
+                      h_shrink):
+        """Return artificial samples constructed from a random subsample
+        of the data, by adding a multivariate Gaussian kernel and
+        sampling from this distribution. An optional shrink factor can
+        be included, to compress/dilate the kernel.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            Observations from which the samples will be created.
+
+        class_indices : ndarray
+            Indices of the samples of the target class in ``X``.
+
+        n_class_samples : int
+            The number of new samples to generate for the class.
+
+        h_shrink : float
+            The shrink factor applied to the kernel bandwidth.
+
+        Returns
+        -------
+        X_new : {ndarray, sparse matrix}, shape (n_class_samples, n_features)
+            Synthetically generated samples.
+ + """ + + number_of_features = X.shape[1] + random_state = check_random_state(self.random_state) + samples_indices = random_state.choice( + class_indices, size=n_class_samples, replace=True) + minimize_amise = (4 / ((number_of_features + 2) * len( + class_indices))) ** (1 / (number_of_features + 4)) + if sparse.issparse(X): + variances = np.diagflat( + np.std(X[class_indices, :].toarray(), axis=0, ddof=1)) + else: + variances = np.diagflat( + np.std(X[class_indices, :], axis=0, ddof=1)) + h_opt = h_shrink * minimize_amise * variances + randoms = random_state.standard_normal(size=(n_class_samples, + number_of_features)) + Xrose = np.matmul(randoms, h_opt) + X[samples_indices, :] + if sparse.issparse(X): + return sparse.csr_matrix(Xrose) + return Xrose + + def _fit_resample(self, X, y): + + X_resampled = X.copy() + y_resampled = y.copy() + + if self.shrink_factors is None: + self.shrink_factors = { + key: 1 for key in self.sampling_strategy_.keys()} + + for class_sample, n_samples in self.sampling_strategy_.items(): + class_indices = np.flatnonzero(y == class_sample) + n_class_samples = n_samples + X_new = self._make_samples(X, + class_indices, + n_samples, + self.shrink_factors[class_sample]) + y_new = np.array([class_sample] * n_class_samples) + + if sparse.issparse(X_new): + X_resampled = sparse.vstack([X_resampled, X_new]) + else: + X_resampled = np.concatenate((X_resampled, X_new)) + + y_resampled = np.hstack((y_resampled, y_new)) + + return X_resampled.astype(X.dtype), y_resampled.astype(y.dtype) diff --git a/imblearn/over_sampling/tests/test_rose.py b/imblearn/over_sampling/tests/test_rose.py new file mode 100644 index 000000000..42cafd2dc --- /dev/null +++ b/imblearn/over_sampling/tests/test_rose.py @@ -0,0 +1,123 @@ +"""Test the module ROSE.""" +# Authors: Andrea Lorenzon +# License: MIT + +import numpy as np + +from imblearn.over_sampling import ROSE +from sklearn.datasets import make_spd_matrix as SymPosDef +from sklearn.utils._testing import assert_allclose +from sklearn.utils._testing import assert_array_equal + + +def test_rose(): + + """Check ROSE use""" + + RND_SEED = 0 + + nCols = 3 + ns = [50000, 50000, 75000] + # generate covariance matrices + cov1 = SymPosDef(nCols) + cov2 = SymPosDef(nCols) + cov3 = SymPosDef(nCols) + + # generate data blobs + cl1 = np.array(np.random.multivariate_normal([1, 1, 1], + cov=cov1, + size=ns[0])) + cl2 = np.array(np.random.multivariate_normal([7, 7, 7], + cov=cov2, + size=ns[1])) + cl3 = np.array(np.random.multivariate_normal([2, 9, 9], + cov=cov3, + size=ns[2])) + # assemble dataset + X = np.vstack((cl1, cl2, cl3)) + y = np.hstack((np.array([1] * ns[0]), + np.array([2] * ns[1]), + np.array([3] * ns[2]))) + + r = ROSE(random_state=RND_SEED) + res, lab = r.fit_resample(X, y) + + # compute and check similarity of covariance matrices + res_cov1 = np.cov(res[lab == 1], rowvar=False) + res_cov2 = np.cov(res[lab == 2], rowvar=False) + res_cov3 = np.cov(res[lab == 3], rowvar=False) + + assert res_cov1.shape == cov1.shape + assert res_cov2.shape == cov2.shape + assert res_cov3.shape == cov3.shape + + +def test_rose_resampler(): + """Test ROSE resampled data matches expectation.""" + + RND_SEED = 0 + X = np.array( + [ + [0.11622591, -0.0317206], + [0.77481731, 0.60935141], + [1.25192108, -0.22367336], + [0.53366841, -0.30312976], + [1.52091956, -0.49283504], + [-0.28162401, -2.10400981], + [0.83680821, 1.72827342], + [0.3084254, 0.33299982], + [0.70472253, -0.73309052], + [0.28893132, -0.38761769], + [1.15514042, 0.0129463], + [0.88407872, 
+            [1.31301027, -0.92648734],
+            [-1.11515198, -0.93689695],
+            [-0.18410027, -0.45194484],
+            [0.9281014, 0.53085498],
+            [-0.14374509, 0.27370049],
+            [-0.41635887, -0.38299653],
+            [0.08711622, 0.93259929],
+            [1.70580611, -0.11219234],
+        ]
+    )
+    Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    R_TOL = 1e-4
+
+    rose = ROSE(random_state=RND_SEED)
+    X_resampled, y_resampled = rose.fit_resample(X, Y)
+
+    # the first 20 rows are the original data; the last 4 rows are the
+    # synthetic samples generated for the minority class
+    X_gt = np.array(
+        [
+            [0.11622591, -0.0317206],
+            [0.77481731, 0.60935141],
+            [1.25192108, -0.22367336],
+            [0.53366841, -0.30312976],
+            [1.52091956, -0.49283504],
+            [-0.28162401, -2.10400981],
+            [0.83680821, 1.72827342],
+            [0.3084254, 0.33299982],
+            [0.70472253, -0.73309052],
+            [0.28893132, -0.38761769],
+            [1.15514042, 0.0129463],
+            [0.88407872, 0.35454207],
+            [1.31301027, -0.92648734],
+            [-1.11515198, -0.93689695],
+            [-0.18410027, -0.45194484],
+            [0.9281014, 0.53085498],
+            [-0.14374509, 0.27370049],
+            [-0.41635887, -0.38299653],
+            [0.08711622, 0.93259929],
+            [1.70580611, -0.11219234],
+            [1.39400832, 0.94383454],
+            [2.67881738, -0.36918919],
+            [1.80801323, -0.96629007],
+            [0.06244814, 0.07625536],
+        ]
+    )
+
+    y_gt = np.array(
+        [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,
+         0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
+    )
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
diff --git a/maint_tools/test_docstring.py b/maint_tools/test_docstring.py
index a7e05cf75..b615bdce1 100644
--- a/maint_tools/test_docstring.py
+++ b/maint_tools/test_docstring.py
@@ -47,6 +47,7 @@
     "Pipeline.fit_transform",
     "Pipeline.fit_resample",
     "Pipeline.fit_predict",
+    "ROSE$", "ROSE.",
     "RUSBoostClassifier$", "RUSBoostClassifier.",
     "RandomOverSampler$", "RandomOverSampler.",
     "RandomUnderSampler$", "RandomUnderSampler.",
diff --git a/references.bib b/references.bib
index d1c4079ab..a1fbcd513 100644
--- a/references.bib
+++ b/references.bib
@@ -185,3 +185,18 @@ @article{chen2004using
   pages={1--12},
   year={2004}
 }
+
+@article{torelli2014rose,
+  author={Menardi, Giovanna and Torelli, Nicola},
+  title={Training and assessing classification rules with imbalanced data},
+  journal={Data Mining and Knowledge Discovery},
+  volume={28},
+  number={1},
+  pages={92--122},
+  year={2014},
+  publisher={Springer},
+  issn={1573-756X},
+  url={https://doi.org/10.1007/s10618-012-0295-5},
+  doi={10.1007/s10618-012-0295-5}
+}
\ No newline at end of file