diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index d0ce77834..e5b58374f 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -16,6 +16,20 @@ jobs:
       ./build_tools/circle/linting.sh
     displayName: Run linting

+- template: build_tools/azure/posix.yml
+  parameters:
+    name: Linux_Runs
+    vmImage: ubuntu-18.04
+    matrix:
+      pylatest_pip_openblas_pandas:
+        DISTRIB: 'conda-pip-latest'
+        PYTHON_VERSION: '3.9'
+        COVERAGE: 'true'
+        PANDAS_VERSION: '*'
+        TEST_DOCSTRINGS: 'true'
+        JOBLIB_VERSION: '*'
+        CHECK_WARNINGS: 'true'
+
 - template: build_tools/azure/posix.yml
   parameters:
     name: Linux
@@ -29,15 +43,6 @@ jobs:
         DISTRIB: 'ubuntu'
         PYTHON_VERSION: '3.6'
         JOBLIB_VERSION: '*'
-      # Linux environment to test the latest available dependencies and MKL.
-      pylatest_pip_openblas_pandas:
-        DISTRIB: 'conda-pip-latest'
-        PYTHON_VERSION: '3.9'
-        COVERAGE: 'true'
-        PANDAS_VERSION: '*'
-        TEST_DOCSTRINGS: 'true'
-        JOBLIB_VERSION: '*'
-        CHECK_WARNINGS: 'true'
       pylatest_conda_pandas_keras:
         DISTRIB: 'conda'
         PYTHON_VERSION: '3.7'
diff --git a/build_tools/circle/linting.sh b/build_tools/circle/linting.sh
index 8a08be987..97bdb43af 100755
--- a/build_tools/circle/linting.sh
+++ b/build_tools/circle/linting.sh
@@ -140,7 +140,7 @@ else
     check_files "$(echo "$MODIFIED_FILES" | grep -v ^examples)"
     check_files "$(echo "$MODIFIED_FILES" | grep ^examples)" \
-        --config ./examples/.flake8
+        --config ./setup.cfg
 fi
 echo -e "No problem detected by flake8\n"
diff --git a/doc/api.rst b/doc/api.rst
index bdf85e0bb..04203bc3f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -76,7 +76,6 @@ Prototype selection
    over_sampling.SMOTE
    over_sampling.SMOTENC
    over_sampling.SVMSMOTE
-   over_sampling.ROSE

 .. _combine_ref:

diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index e5948cbd6..623b61bf2 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -80,6 +80,19 @@ It would also work with pandas dataframe::
    >>> df_resampled, y_resampled = ros.fit_resample(df_adult, y_adult)
    >>> df_resampled.head()  # doctest: +SKIP

+If repeating samples is an issue, the parameter `smoothed_bootstrap` can be
+set to `True` to create a smoothed bootstrap instead. However, the original
+data needs to be numerical. The `shrinkage` parameter controls the dispersion
+of the newly generated samples. The example below illustrates that the new
+samples no longer overlap once a smoothed bootstrap is used. This way of
+generating a smoothed bootstrap is also known as Random Over-Sampling
+Examples (ROSE) :cite:`torelli2014rose`.
+
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_003.png
+   :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
+   :scale: 60
+   :align: center
+
 .. _smote_adasyn:

 From random over-sampling to SMOTE and ADASYN
@@ -104,7 +117,7 @@ the same manner::
 The figure below illustrates the major difference of the different
 over-sampling methods.

-.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_003.png
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_004.png
    :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
    :scale: 60
    :align: center
@@ -122,14 +135,14 @@ implementation of :class:`SMOTE` will not make any distinction between easy
 and hard samples to be classified using the nearest neighbors rule. Therefore,
 the decision function found during training will be different among the
 algorithms.

-.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_004.png
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_005.png
    :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
    :align: center

 The sampling particularities of these two algorithms can lead to some peculiar
 behavior as shown below.

-.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_005.png
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_006.png
    :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
    :scale: 60
    :align: center
@@ -144,7 +157,7 @@ samples. Those methods focus on samples near of the border of the optimal
 decision function and will generate samples in the opposite direction of the
 nearest neighbors class. Those variants are presented in the figure below.

-.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_006.png
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_007.png
    :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
    :scale: 60
    :align: center
@@ -198,29 +211,14 @@ Therefore, it can be seen that the samples generated in the first and last
 columns are belonging to the same categories originally presented without any
 other extra interpolation.

-.. _rose:
-
-ROSE (Random Over-Sampling Examples)
-------------------------------------
-
-ROSE uses smoothed bootstrapping to draw artificial samples from the
-feature space neighborhood around selected classes, using a multivariate
-Gaussian kernel around randomly selected samples. First, random samples are
-selected from original classes. Then the smoothing kernel distribution
-is computed around the samples: :math:`\hat f(x|y=Y_i) = \sum_i^{n_j}
-p_i Pr(x|x_i)=\sum_i^{n_j} \frac{1}{n_j} Pr(x|x_i)=\sum_i^{n_j}
-\frac{1}{n_j} K_{H_j}(x|x_i)`.
-
-Then new samples are drawn from the computed distribution.
-
 Mathematical formulation
 ========================

 Sample generation
 -----------------

-Both SMOTE and ADASYN use the same algorithm to generate new samples.
-Considering a sample :math:`x_i`, a new sample :math:`x_{new}` will be
+Both :class:`SMOTE` and :class:`ADASYN` use the same algorithm to generate new
+samples. Considering a sample :math:`x_i`, a new sample :math:`x_{new}` will be
 generated considering its k neareast-neighbors (corresponding to
 ``k_neighbors``). For instance, the 3 nearest-neighbors are included in the
 blue circle as illustrated in the figure below. Then, one of these
diff --git a/doc/whats_new/v0.7.rst b/doc/whats_new/v0.7.rst
index 4691c112c..78502200f 100644
--- a/doc/whats_new/v0.7.rst
+++ b/doc/whats_new/v0.7.rst
@@ -72,8 +72,12 @@ Enhancements
 - Lazy import `keras` module when importing `imblearn.keras`
   :pr:`719` by :user:`Guillaume Lemaitre <glemaitre>`.

-- Added Random Over-Sampling Examples (ROSE) class.
-  :pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>`.
+- Added an option to generate a smoothed bootstrap in
+  :class:`imblearn.over_sampling.RandomOverSampler`. It is controlled by the
+  parameters `smoothed_bootstrap` and `shrinkage`. This method is also known
+  as Random Over-Sampling Examples (ROSE).
+  :pr:`754` by :user:`Andrea Lorenzon <andrealorenzon>` and
+  :user:`Guillaume Lemaitre <glemaitre>`.
 - Add option `output_dict` in
   :func:`imblearn.metrics.classification_report_imbalanced` to return a
diff --git a/examples/over-sampling/plot_comparison_over_sampling.py b/examples/over-sampling/plot_comparison_over_sampling.py
index 49c1f4f18..37c370c38 100644
--- a/examples/over-sampling/plot_comparison_over_sampling.py
+++ b/examples/over-sampling/plot_comparison_over_sampling.py
@@ -106,16 +106,15 @@ def plot_decision_function(X, y, clf, ax):
 # data using a linear SVM classifier. Greater is the difference between the
 # number of samples in each class, poorer are the classfication results.

-fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
+fig, axs = plt.subplots(2, 2, figsize=(15, 12))

-ax_arr = (ax1, ax2, ax3, ax4)
 weights_arr = (
     (0.01, 0.01, 0.98),
     (0.01, 0.05, 0.94),
     (0.2, 0.1, 0.7),
     (0.33, 0.33, 0.33),
 )
-for ax, weights in zip(ax_arr, weights_arr):
+for ax, weights in zip(axs.ravel(), weights_arr):
     X, y = create_dataset(n_samples=1000, weights=weights)
     clf = LinearSVC().fit(X, y)
     plot_decision_function(X, y, clf, ax)
@@ -129,20 +128,40 @@ def plot_decision_function(X, y, clf, ax):
 ###############################################################################
 # Random over-sampling can be used to repeat some samples and balance the
 # number of samples between the dataset. It can be seen that with this trivial
-# approach the boundary decision is already less biaised toward the majority
+# approach the boundary decision is already less biased toward the majority
 # class.

-fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
+fig, axs = plt.subplots(1, 2, figsize=(15, 7))
 X, y = create_dataset(n_samples=10000, weights=(0.01, 0.05, 0.94))

 clf = LinearSVC().fit(X, y)
-plot_decision_function(X, y, clf, ax1)
-ax1.set_title(f"Linear SVC with y={Counter(y)}")
+plot_decision_function(X, y, clf, axs[0])
+axs[0].set_title(f"Linear SVC with y={Counter(y)}")
 pipe = make_pipeline(RandomOverSampler(random_state=0), LinearSVC())
 pipe.fit(X, y)
-plot_decision_function(X, y, pipe, ax2)
-ax2.set_title("Decision function for RandomOverSampler")
+plot_decision_function(X, y, pipe, axs[1])
+axs[1].set_title("Decision function for RandomOverSampler")
 fig.tight_layout()

+###############################################################################
+# By default, random over-sampling generates a bootstrap. The parameter
+# `smoothed_bootstrap` allows adding a small perturbation to the generated data
+# to generate a smoothed bootstrap instead. The plot below shows the difference
+# between the two data generation strategies.
+
+fig, axs = plt.subplots(1, 2, figsize=(15, 7))
+sampler = RandomOverSampler(random_state=0)
+plot_resampling(X, y, sampler, ax=axs[0])
+axs[0].set_title("RandomOverSampler with normal bootstrap")
+sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0.2, random_state=0)
+plot_resampling(X, y, sampler, ax=axs[1])
+axs[1].set_title("RandomOverSampler with smoothed bootstrap")
+fig.tight_layout()
+
+###############################################################################
+# The smoothed bootstrap appears to contain more samples. This is because the
+# generated samples do not superimpose on the original samples.
+# ############################################################################### # More advanced over-sampling using ADASYN and SMOTE ############################################################################### @@ -161,16 +180,15 @@ def _fit_resample(self, X, y): return X, y -fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15)) +fig, axs = plt.subplots(2, 2, figsize=(15, 15)) X, y = create_dataset(n_samples=10000, weights=(0.01, 0.05, 0.94)) sampler = FakeSampler() clf = make_pipeline(sampler, LinearSVC()) -plot_resampling(X, y, sampler, ax1) -ax1.set_title(f"Original data - y={Counter(y)}") +plot_resampling(X, y, sampler, axs[0, 0]) +axs[0, 0].set_title(f"Original data - y={Counter(y)}") -ax_arr = (ax2, ax3, ax4) for ax, sampler in zip( - ax_arr, + axs.ravel()[1:], ( RandomOverSampler(random_state=0), SMOTE(random_state=0), @@ -189,33 +207,32 @@ def _fit_resample(self, X, y): # nearest-neighbors rule while regular SMOTE will not make any distinction. # Therefore, the decision function depending of the algorithm. -fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6)) +fig, axs = plt.subplots(1, 3, figsize=(20, 6)) X, y = create_dataset(n_samples=10000, weights=(0.01, 0.05, 0.94)) clf = LinearSVC().fit(X, y) -plot_decision_function(X, y, clf, ax1) -ax1.set_title(f"Linear SVC with y={Counter(y)}") +plot_decision_function(X, y, clf, axs[0]) +axs[0].set_title(f"Linear SVC with y={Counter(y)}") sampler = SMOTE() clf = make_pipeline(sampler, LinearSVC()) clf.fit(X, y) -plot_decision_function(X, y, clf, ax2) -ax2.set_title(f"Decision function for {sampler.__class__.__name__}") +plot_decision_function(X, y, clf, axs[1]) +axs[1].set_title(f"Decision function for {sampler.__class__.__name__}") sampler = ADASYN() clf = make_pipeline(sampler, LinearSVC()) clf.fit(X, y) -plot_decision_function(X, y, clf, ax3) -ax3.set_title(f"Decision function for {sampler.__class__.__name__}") +plot_decision_function(X, y, clf, axs[2]) +axs[2].set_title(f"Decision function for {sampler.__class__.__name__}") fig.tight_layout() ############################################################################### # Due to those sampling particularities, it can give rise to some specific # issues as illustrated below. -fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 15)) +fig, axs = plt.subplots(2, 2, figsize=(15, 15)) X, y = create_dataset(n_samples=5000, weights=(0.01, 0.05, 0.94), class_sep=0.8) -ax_arr = ((ax1, ax2), (ax3, ax4)) -for ax, sampler in zip(ax_arr, (SMOTE(random_state=0), ADASYN(random_state=0))): +for ax, sampler in zip(axs, (SMOTE(random_state=0), ADASYN(random_state=0))): clf = make_pipeline(sampler, LinearSVC()) clf.fit(X, y) plot_decision_function(X, y, clf, ax[0]) @@ -232,16 +249,11 @@ def _fit_resample(self, X, y): # the KMeans version will make a clustering before to generate samples in each # cluster independently depending each cluster density. 
-(
-    fig,
-    ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8), (ax9, ax10)),
-) = plt.subplots(5, 2, figsize=(15, 30))
+fig, axs = plt.subplots(5, 2, figsize=(15, 30))
 X, y = create_dataset(n_samples=5000, weights=(0.01, 0.05, 0.94), class_sep=0.8)

-
-ax_arr = ((ax1, ax2), (ax3, ax4), (ax5, ax6), (ax7, ax8), (ax9, ax10))
 for ax, sampler in zip(
-    ax_arr,
+    axs,
     (
         SMOTE(random_state=0),
         BorderlineSMOTE(random_state=0, kind="borderline-1"),
@@ -282,5 +294,3 @@ def _fit_resample(self, X, y):
 print(sorted(Counter(y_resampled).items()))
 print("SMOTE-NC will generate categories for the categorical features:")
 print(X_resampled[-5:])
-
-plt.show()
diff --git a/examples/over-sampling/plot_shrinkage_effect.py b/examples/over-sampling/plot_shrinkage_effect.py
new file mode 100644
index 000000000..14504cef5
--- /dev/null
+++ b/examples/over-sampling/plot_shrinkage_effect.py
@@ -0,0 +1,113 @@
+"""
+======================================================
+Effect of the shrinkage factor in random over-sampling
+======================================================
+
+This example shows the effect of the shrinkage factor used to generate the
+smoothed bootstrap using the
+:class:`~imblearn.over_sampling.RandomOverSampler`.
+"""
+
+# Authors: Guillaume Lemaitre
+# License: MIT

+print(__doc__)
+
+# %%
+# First, we will generate a toy classification dataset with only a few
+# samples. The ratio between the classes will be imbalanced.
+from collections import Counter
+from sklearn.datasets import make_classification
+
+X, y = make_classification(
+    n_samples=100,
+    n_features=2,
+    n_redundant=0,
+    weights=[0.1, 0.9],
+    random_state=0,
+)
+Counter(y)
+
+
+# %%
+import matplotlib.pyplot as plt
+
+fig, ax = plt.subplots()
+scatter = plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.4)
+class_legend = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
+ax.add_artist(class_legend)
+ax.set_xlabel("Feature #1")
+_ = ax.set_ylabel("Feature #2")
+
+# %%
+# Now, we will use a :class:`~imblearn.over_sampling.RandomOverSampler` to
+# generate a bootstrap for the minority class with as many samples as in the
+# majority class.
+from imblearn.over_sampling import RandomOverSampler
+
+sampler = RandomOverSampler(random_state=0)
+X_res, y_res = sampler.fit_resample(X, y)
+Counter(y_res)
+
+# %%
+fig, ax = plt.subplots()
+scatter = plt.scatter(X_res[:, 0], X_res[:, 1], c=y_res, alpha=0.4)
+class_legend = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
+ax.add_artist(class_legend)
+ax.set_xlabel("Feature #1")
+_ = ax.set_ylabel("Feature #2")
+
+# %%
+# We observe that the minority samples are less transparent than the samples
+# from the majority class. This is because the samples of the minority class
+# are repeated during the bootstrap generation.
+#
+# We can set `smoothed_bootstrap=True` to add a small perturbation to the
+# samples created and therefore create a smoothed bootstrap.
+sampler = RandomOverSampler(smoothed_bootstrap=True, random_state=0)
+X_res, y_res = sampler.fit_resample(X, y)
+Counter(y_res)
+
+# %%
+fig, ax = plt.subplots()
+scatter = plt.scatter(X_res[:, 0], X_res[:, 1], c=y_res, alpha=0.4)
+class_legend = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
+ax.add_artist(class_legend)
+ax.set_xlabel("Feature #1")
+_ = ax.set_ylabel("Feature #2")
+
+# %%
+# In this case, we see that the samples of the minority class are no longer
+# overlapping due to the added noise.
+#
+# The parameter `shrinkage` allows adding more or less perturbation. Let's
+# add more perturbation when generating the smoothed bootstrap.
+sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=3, random_state=0)
+X_res, y_res = sampler.fit_resample(X, y)
+Counter(y_res)
+
+# %%
+fig, ax = plt.subplots()
+scatter = plt.scatter(X_res[:, 0], X_res[:, 1], c=y_res, alpha=0.4)
+class_legend = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
+ax.add_artist(class_legend)
+ax.set_xlabel("Feature #1")
+_ = ax.set_ylabel("Feature #2")
+
+# %%
+# Increasing the value of `shrinkage` will disperse the new samples. Forcing
+# the shrinkage to 0 will be equivalent to generating a normal bootstrap.
+sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0, random_state=0)
+X_res, y_res = sampler.fit_resample(X, y)
+Counter(y_res)
+
+# %%
+fig, ax = plt.subplots()
+scatter = plt.scatter(X_res[:, 0], X_res[:, 1], c=y_res, alpha=0.4)
+class_legend = ax.legend(*scatter.legend_elements(), loc="lower left", title="Classes")
+ax.add_artist(class_legend)
+ax.set_xlabel("Feature #1")
+_ = ax.set_ylabel("Feature #2")
+
+# %%
+# Therefore, the `shrinkage` parameter is handy for manually tuning the
+# dispersion of the new samples.
diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index 3be402135..bd20b76ea 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -10,7 +10,6 @@
 from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
-from ._rose import ROSE

 __all__ = [
     "ADASYN",
@@ -20,5 +19,4 @@
     "BorderlineSMOTE",
     "SVMSMOTE",
     "SMOTENC",
-    "ROSE"
 ]
diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py
index be63ae308..928e5d24d 100644
--- a/imblearn/over_sampling/_random_over_sampler.py
+++ b/imblearn/over_sampling/_random_over_sampler.py
@@ -4,11 +4,13 @@
 # Christos Aridas
 # License: MIT

-from collections import Counter
+from numbers import Real

 import numpy as np
-from sklearn.utils import check_random_state
+from scipy import sparse
+from sklearn.utils import check_array, check_random_state
 from sklearn.utils import _safe_indexing
+from sklearn.utils.sparsefuncs import mean_variance_axis

 from .base import BaseOverSampler
 from ..utils import check_target_type
@@ -25,7 +27,7 @@ class RandomOverSampler(BaseOverSampler):
     """Class to perform random over-sampling.

     Object to over-sample the minority class(es) by picking samples at random
-    with replacement.
+    with replacement. The bootstrap can be generated in a smoothed manner.

     Read more in the :ref:`User Guide <random_over_sampler>`.

@@ -35,6 +37,23 @@ class RandomOverSampler(BaseOverSampler):

     {random_state}

+    smoothed_bootstrap : bool, default=False
+        Whether or not to generate smoothed bootstrap samples. When this
+        option is enabled, be aware that the data to be resampled needs to be
+        numerical since a Gaussian perturbation will be generated and added
+        to the bootstrap.
+
+        .. versionadded:: 0.7
+
+    shrinkage : float or dict, default=1.0
+        Factor to shrink the covariance matrix used to generate the
+        smoothed bootstrap. A single factor can be shared by all classes by
+        providing a float, or a different factor can be used for each class
+        over-sampled by providing a dictionary where the keys are the classes
+        targeted and the values are the shrinkage factors.
+
+        .. versionadded:: 0.7
+
     Attributes
     ----------
     sample_indices_ : ndarray of shape (n_new_samples,)

         .. versionadded:: 0.4

+    shrinkage_ : dict or None
+        The per-class shrinkage factor used to generate the smoothed
+        bootstrap sample. `None` when `smoothed_bootstrap=False`.
+
+        .. versionadded:: 0.7
+
     See Also
     --------
-    ROSE : Random Over-Sampling Examples.
-
-    BorderlineSMOTE : Over-sample using the bordeline-SMOTE variant.
+    BorderlineSMOTE : Over-sample using the borderline-SMOTE variant.

     SMOTE : Over-sample using SMOTE.

@@ -65,6 +88,20 @@ class RandomOverSampler(BaseOverSampler):
     Supports heterogeneous data as object array containing string and numeric
     data.

+    When generating a smoothed bootstrap, this method is also known as Random
+    Over-Sampling Examples (ROSE) [1]_.
+
+    .. warning::
+       Since smoothed bootstrap samples are generated by adding a small
+       perturbation to the drawn samples, this method is not adequate when
+       working with sparse matrices.
+
+    References
+    ----------
+    .. [1] G. Menardi, N. Torelli, "Training and assessing classification
+       rules with imbalanced data," Data Mining and Knowledge
+       Discovery, 28(1), pp.92-122, 2014.
+
     Examples
     --------
     >>> from collections import Counter
@@ -83,9 +120,18 @@
     """

     @_deprecate_positional_args
-    def __init__(self, *, sampling_strategy="auto", random_state=None):
+    def __init__(
+        self,
+        *,
+        sampling_strategy="auto",
+        random_state=None,
+        smoothed_bootstrap=False,
+        shrinkage=1.0,
+    ):
         super().__init__(sampling_strategy=sampling_strategy)
         self.random_state = random_state
+        self.smoothed_bootstrap = smoothed_bootstrap
+        self.shrinkage = shrinkage

     def _check_X_y(self, X, y):
         y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
@@ -101,23 +147,85 @@ def _check_X_y(self, X, y):

     def _fit_resample(self, X, y):
         random_state = check_random_state(self.random_state)

-        target_stats = Counter(y)
-        sample_indices = range(X.shape[0])
+        if self.smoothed_bootstrap:
+            if isinstance(self.shrinkage, Real):
+                self.shrinkage_ = {
+                    klass: self.shrinkage for klass in self.sampling_strategy_
+                }
+            else:
+                missing_shrinkage_keys = (
+                    self.sampling_strategy_.keys() - self.shrinkage.keys()
+                )
+                if missing_shrinkage_keys:
+                    raise ValueError(
+                        f"`shrinkage` should contain a shrinkage factor for "
+                        f"each class that will be resampled. The missing "
+                        f"classes are: {repr(missing_shrinkage_keys)}"
+                    )
+                self.shrinkage_ = self.shrinkage
+            # the smoothed bootstrap requires numerical operations; we need
+            # to be sure to have only numerical data in X
+            try:
+                X = check_array(X, accept_sparse=["csr", "csc"], dtype="numeric")
+            except ValueError as exc:
+                raise ValueError(
+                    "When smoothed_bootstrap=True, X needs to contain only "
+                    "numerical data to later generate a smoothed bootstrap "
+                    "sample."
+ ) from exc + else: + self.shrinkage_ = None + + X_resampled = [X.copy()] + y_resampled = [y.copy()] + sample_indices = range(X.shape[0]) for class_sample, num_samples in self.sampling_strategy_.items(): target_class_indices = np.flatnonzero(y == class_sample) - indices = random_state.randint( - low=0, high=target_stats[class_sample], size=num_samples + bootstrap_indices = random_state.choice( + target_class_indices, + size=num_samples, + replace=True, ) + sample_indices = np.append(sample_indices, bootstrap_indices) + if self.smoothed_bootstrap: + # generate a smoothed bootstrap with a perturbation + n_samples, n_features = X.shape + smoothing_constant = (4 / ((n_features + 2) * n_samples)) ** ( + 1 / (n_features + 4) + ) + if sparse.issparse(X): + _, X_class_variance = mean_variance_axis( + X[target_class_indices, :], + axis=0, + ) + X_class_scale = np.sqrt(X_class_variance, out=X_class_variance) + else: + X_class_scale = np.std(X[target_class_indices, :], axis=0) + smoothing_matrix = np.diagflat( + self.shrinkage_[class_sample] * smoothing_constant * X_class_scale + ) + X_new = random_state.randn(num_samples, n_features) + X_new = X_new.dot(smoothing_matrix) + X[bootstrap_indices, :] + if sparse.issparse(X): + X_new = sparse.csr_matrix(X_new, dtype=X.dtype) + X_resampled.append(X_new) + else: + # generate a bootstrap + X_resampled.append(_safe_indexing(X, bootstrap_indices)) + + y_resampled.append(_safe_indexing(y, bootstrap_indices)) - sample_indices = np.append(sample_indices, target_class_indices[indices]) self.sample_indices_ = np.array(sample_indices) - return ( - _safe_indexing(X, sample_indices), - _safe_indexing(y, sample_indices), - ) + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) + + return X_resampled, y_resampled def _more_tags(self): return { diff --git a/imblearn/over_sampling/_rose.py b/imblearn/over_sampling/_rose.py deleted file mode 100644 index 125fd6d3b..000000000 --- a/imblearn/over_sampling/_rose.py +++ /dev/null @@ -1,172 +0,0 @@ -"""Class to perform over-sampling using ROSE.""" - -import numpy as np -from scipy import sparse -from sklearn.utils import check_random_state - -from .base import BaseOverSampler -from ..utils import Substitution -from ..utils._docstring import _random_state_docstring -from ..utils._docstring import _n_jobs_docstring -from ..utils._validation import _deprecate_positional_args - - -@Substitution( - sampling_strategy=BaseOverSampler._sampling_strategy_docstring, - random_state=_random_state_docstring, - n_jobs=_n_jobs_docstring -) -class ROSE(BaseOverSampler): - """Random Over-Sampling Examples (ROSE). - - This object is the implementation of ROSE algorithm. - It generates new samples by a smoothed bootstrap approach, - taking a random subsample of original data and adding a - multivariate kernel density estimate :math:`f(x|Y_i)` around - them with a smoothing matrix :math:`H_j`, and finally sampling - from this distribution. A shrinking matrix can be provided, to - set the bandwidth of the gaussian kernel. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - {sampling_strategy} - - shrink_factors : dict, default=None - Dictionary where the key is the label and the value is the shrinkage - factor. If `None`, each class shrinkage is equal to 1. - The shrinkage applies to the gaussian kernels. It can be used to - compress/dilate the kernel. 
- - {random_state} - - {n_jobs} - - See Also - -------- - BorderlineSMOTE : Over-sample using the bordeline-SMOTE variant. - - SMOTE : Over-sample using SMOTE. - - SMOTENC : Over-sample using SMOTE for continuous and categorical features. - - SVMSMOTE : Over-sample using SVM-SMOTE variant. - - ADASYN : Over-sample using ADASYN. - - KMeansSMOTE : Over-sample applying a clustering before to oversample using - SMOTE. - - References - ---------- - .. [1] N. Lunardon, G. Menardi, N.Torelli, "ROSE: A Package for Binary - Imbalanced Learning," R Journal, 6(1), 2014. - - .. [2] G Menardi, N. Torelli, "Training and assessing classification - rules with imbalanced data," Data Mining and Knowledge - Discovery, 28(1), pp.92-122, 2014. - - Examples - -------- - >>> from imblearn.over_sampling import ROSE - >>> from sklearn.datasets import make_classification - >>> from collections import Counter - >>> r = ROSE(shrink_factors={{0:1, 1:0.5, 2:0.7}}) - >>> X, y = make_classification(n_classes=3, class_sep=2, - ... weights=[0.1, 0.7, 0.2], n_informative=3, n_redundant=1, flip_y=0, - ... n_features=20, n_clusters_per_class=1, n_samples=2000, random_state=10) - >>> print('Original dataset shape %s' % Counter(y)) - Original dataset shape Counter({{1: 1400, 2: 400, 0: 200}}) - >>> X_res, y_res = r.fit_resample(X, y) - >>> print('Resampled dataset shape %s' % Counter(y_res)) - Resampled dataset shape Counter({{2: 1400, 1: 1400, 0: 1400}}) - """ - - @_deprecate_positional_args - def __init__(self, *, sampling_strategy="auto", shrink_factors=None, - random_state=None, n_jobs=None): - super().__init__(sampling_strategy=sampling_strategy) - self.random_state = random_state - self.shrink_factors = shrink_factors - self.n_jobs = n_jobs - - def _make_samples(self, - X, - class_indices, - n_class_samples, - h_shrink): - """A support function that returns artificial samples constructed - from a random subsample of the data, by adding a multiviariate - gaussian kernel and sampling from this distribution. An optional - shrink factor can be included, to compress/dilate the kernel. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Observations from which the samples will be created. - - class_indices : ndarray of shape (n_class_samples,) - The target class indices - - n_class_samples : int - The total number of samples per class to generate - - h_shrink : int - the shrink factor - - Returns - ------- - X_new : {ndarray, sparse matrix} of shape (n_samples, n_features) - Synthetically generated samples. - - y_new : ndarray of shape (n_samples,) - Target values for synthetic samples. 
- """ - - number_of_features = X.shape[1] - random_state = check_random_state(self.random_state) - samples_indices = random_state.choice( - class_indices, size=n_class_samples, replace=True) - minimize_amise = (4 / ((number_of_features + 2) * len( - class_indices))) ** (1 / (number_of_features + 4)) - if sparse.issparse(X): - variances = np.diagflat( - np.std(X[class_indices, :].toarray(), axis=0, ddof=1)) - else: - variances = np.diagflat( - np.std(X[class_indices, :], axis=0, ddof=1)) - h_opt = h_shrink * minimize_amise * variances - randoms = random_state.standard_normal(size=(n_class_samples, - number_of_features)) - Xrose = np.matmul(randoms, h_opt) + X[samples_indices, :] - if sparse.issparse(X): - return sparse.csr_matrix(Xrose) - return Xrose - - def _fit_resample(self, X, y): - - X_resampled = X.copy() - y_resampled = y.copy() - - if self.shrink_factors is None: - self.shrink_factors = { - key: 1 for key in self.sampling_strategy_.keys()} - - for class_sample, n_samples in self.sampling_strategy_.items(): - class_indices = np.flatnonzero(y == class_sample) - n_class_samples = n_samples - X_new = self._make_samples(X, - class_indices, - n_samples, - self.shrink_factors[class_sample]) - y_new = np.array([class_sample] * n_class_samples) - - if sparse.issparse(X_new): - X_resampled = sparse.vstack([X_resampled, X_new]) - else: - X_resampled = np.concatenate((X_resampled, X_new)) - - y_resampled = np.hstack((y_resampled, y_new)) - - return X_resampled.astype(X.dtype), y_resampled.astype(y.dtype) diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py index 7240a86fa..ea66c7fec 100644 --- a/imblearn/over_sampling/_smote.py +++ b/imblearn/over_sampling/_smote.py @@ -247,8 +247,6 @@ class BorderlineSMOTE(BaseSMOTE): See Also -------- - ROSE : Random Over-Sampling Examples. - SMOTE : Over-sample using SMOTE. SMOTENC : Over-sample using SMOTE for continuous and categorical features. @@ -446,8 +444,6 @@ class SVMSMOTE(BaseSMOTE): See Also -------- - ROSE : Random Over-Sampling Examples. - SMOTE : Over-sample using SMOTE. SMOTENC : Over-sample using SMOTE for continuous and categorical features. @@ -645,8 +641,6 @@ class SMOTE(BaseSMOTE): See Also -------- - ROSE : Random Over-Sampling Examples. - SMOTENC : Over-sample using SMOTE for continuous and categorical features. BorderlineSMOTE : Over-sample using the borderline-SMOTE variant. @@ -815,8 +809,6 @@ class SMOTENC(SMOTE): See Also -------- - ROSE : Random Over-Sampling Examples. - SMOTE : Over-sample using SMOTE. SVMSMOTE : Over-sample using SVM-SMOTE variant. @@ -1106,8 +1098,6 @@ class KMeansSMOTE(BaseSMOTE): See Also -------- - ROSE : Random Over-Sampling Examples. - SMOTE : Over-sample using SMOTE. SVMSMOTE : Over-sample using SVM-SMOTE variant. 
diff --git a/imblearn/over_sampling/tests/test_random_over_sampler.py b/imblearn/over_sampling/tests/test_random_over_sampler.py index 27f936636..fb448970a 100644 --- a/imblearn/over_sampling/tests/test_random_over_sampler.py +++ b/imblearn/over_sampling/tests/test_random_over_sampler.py @@ -10,25 +10,31 @@ from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal +from sklearn.utils._testing import _convert_container from imblearn.over_sampling import RandomOverSampler RND_SEED = 0 -X = np.array( - [ - [0.04352327, -0.20515826], - [0.92923648, 0.76103773], - [0.20792588, 1.49407907], - [0.47104475, 0.44386323], - [0.22950086, 0.33367433], - [0.15490546, 0.3130677], - [0.09125309, -0.85409574], - [0.12372842, 0.6536186], - [0.13347175, 0.12167502], - [0.094035, -2.55298982], - ] -) -Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1]) + + +@pytest.fixture +def data(): + X = np.array( + [ + [0.04352327, -0.20515826], + [0.92923648, 0.76103773], + [0.20792588, 1.49407907], + [0.47104475, 0.44386323], + [0.22950086, 0.33367433], + [0.15490546, 0.3130677], + [0.09125309, -0.85409574], + [0.12372842, 0.6536186], + [0.13347175, 0.12167502], + [0.094035, -2.55298982], + ] + ) + Y = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1]) + return X, Y def test_ros_init(): @@ -37,14 +43,15 @@ def test_ros_init(): assert ros.random_state == RND_SEED -@pytest.mark.parametrize("as_frame", [True, False], ids=["dataframe", "array"]) -def test_ros_fit_resample(as_frame): - if as_frame: - pd = pytest.importorskip("pandas") - X_ = pd.DataFrame(X) - else: - X_ = X - ros = RandomOverSampler(random_state=RND_SEED) +@pytest.mark.parametrize( + "params", + [{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}] +) +@pytest.mark.parametrize("X_type", ["array", "dataframe"]) +def test_ros_fit_resample(X_type, data, params): + X, Y = data + X_ = _convert_container(X, X_type) + ros = RandomOverSampler(**params, random_state=RND_SEED) X_resampled, y_resampled = ros.fit_resample(X_, Y) X_gt = np.array( [ @@ -66,17 +73,29 @@ def test_ros_fit_resample(as_frame): ) y_gt = np.array([1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0]) - if as_frame: + if X_type == "dataframe": assert hasattr(X_resampled, "loc") X_resampled = X_resampled.to_numpy() assert_allclose(X_resampled, X_gt) assert_array_equal(y_resampled, y_gt) + if not params["smoothed_bootstrap"]: + assert ros.shrinkage_ is None + else: + assert ros.shrinkage_ == {0: 0} + -def test_ros_fit_resample_half(): +@pytest.mark.parametrize( + "params", + [{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}] +) +def test_ros_fit_resample_half(data, params): + X, Y = data sampling_strategy = {0: 3, 1: 7} - ros = RandomOverSampler(sampling_strategy=sampling_strategy, random_state=RND_SEED) + ros = RandomOverSampler( + **params, sampling_strategy=sampling_strategy, random_state=RND_SEED + ) X_resampled, y_resampled = ros.fit_resample(X, Y) X_gt = np.array( [ @@ -96,20 +115,38 @@ def test_ros_fit_resample_half(): assert_allclose(X_resampled, X_gt) assert_array_equal(y_resampled, y_gt) + if not params["smoothed_bootstrap"]: + assert ros.shrinkage_ is None + else: + assert ros.shrinkage_ == {0: 0, 1: 0} + -def test_multiclass_fit_resample(): +@pytest.mark.parametrize( + "params", + [{"smoothed_bootstrap": False}, {"smoothed_bootstrap": True, "shrinkage": 0}] +) +def test_multiclass_fit_resample(data, params): + # check the random over-sampling with a multiclass problem + X, Y = data y = Y.copy() y[5] = 2 y[6] = 2 - 
ros = RandomOverSampler(random_state=RND_SEED) + ros = RandomOverSampler(**params, random_state=RND_SEED) X_resampled, y_resampled = ros.fit_resample(X, y) count_y_res = Counter(y_resampled) assert count_y_res[0] == 5 assert count_y_res[1] == 5 assert count_y_res[2] == 5 + if not params["smoothed_bootstrap"]: + assert ros.shrinkage_ is None + else: + assert ros.shrinkage_ == {0: 0, 2: 0} + def test_random_over_sampling_heterogeneous_data(): + # check that resampling with heterogeneous dtype is working with basic + # resampling X_hetero = np.array( [["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object ) @@ -123,9 +160,10 @@ def test_random_over_sampling_heterogeneous_data(): assert X_res[-1, 0] in X_hetero[:, 0] -def test_random_over_sampling_nan_inf(): +def test_random_over_sampling_nan_inf(data): # check that we can oversample even with missing or infinite data # regression tests for #605 + X, Y = data rng = np.random.RandomState(42) n_not_finite = X.shape[0] // 3 row_indices = rng.choice(np.arange(X.shape[0]), size=n_not_finite) @@ -141,3 +179,85 @@ def test_random_over_sampling_nan_inf(): assert y_res.shape == (14,) assert X_res.shape == (14, 2) assert np.any(~np.isfinite(X_res)) + + +def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap(): + # check that we raise an error when heterogeneous dtype data are given + # and a smoothed bootstrap is requested + X_hetero = np.array( + [["xxx", 1, 1.0], ["yyy", 2, 2.0], ["zzz", 3, 3.0]], dtype=object + ) + y = np.array([0, 0, 1]) + ros = RandomOverSampler( + smoothed_bootstrap=True, + random_state=RND_SEED, + ) + err_msg = "When smoothed_bootstrap=True, X needs to contain only numerical" + with pytest.raises(ValueError, match=err_msg): + ros.fit_resample(X_hetero, y) + + +@pytest.mark.parametrize("X_type", ["dataframe", "array", "sparse_csr", "sparse_csc"]) +def test_random_over_sampler_smoothed_bootstrap(X_type, data): + # check that smoothed bootstrap is working for numerical array + X, y = data + sampler = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1) + X = _convert_container(X, X_type) + X_res, y_res = sampler.fit_resample(X, y) + + assert y_res.shape == (14,) + assert X_res.shape == (14, 2) + + if X_type == "dataframe": + assert hasattr(X_res, "loc") + + +def test_random_over_sampler_equivalence_shrinkage(data): + # check that a shrinkage factor of 0 is equivalent to not create a smoothed + # bootstrap + X, y = data + + ros_not_shrink = RandomOverSampler( + smoothed_bootstrap=True, shrinkage=0, random_state=0 + ) + ros_hard_bootstrap = RandomOverSampler(smoothed_bootstrap=False, random_state=0) + + X_res_not_shrink, y_res_not_shrink = ros_not_shrink.fit_resample(X, y) + X_res, y_res = ros_hard_bootstrap.fit_resample(X, y) + + assert_allclose(X_res_not_shrink, X_res) + assert_allclose(y_res_not_shrink, y_res) + + assert y_res.shape == (14,) + assert X_res.shape == (14, 2) + assert y_res_not_shrink.shape == (14,) + assert X_res_not_shrink.shape == (14, 2) + + +def test_random_over_sampler_shrinkage_behaviour(data): + # check the behaviour of the shrinkage parameter + # the covariance of the data generated with the larger shrinkage factor + # should also be larger. 
+    X, y = data
+
+    ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=1, random_state=0)
+    X_res_shrink_1, y_res_shrink_1 = ros.fit_resample(X, y)
+
+    ros.set_params(shrinkage=5)
+    X_res_shrink_5, y_res_shrink_5 = ros.fit_resample(X, y)
+
+    dispersion_shrink_1 = np.linalg.det(np.cov(X_res_shrink_1[y_res_shrink_1 == 0].T))
+    dispersion_shrink_5 = np.linalg.det(np.cov(X_res_shrink_5[y_res_shrink_5 == 0].T))
+
+    assert dispersion_shrink_1 < dispersion_shrink_5
+
+
+def test_random_over_sampler_shrinkage_error(data):
+    # check that we raise a proper error when `shrinkage` does not contain
+    # the necessary information
+    X, y = data
+    shrinkage = {}
+    ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=shrinkage)
+    err_msg = "`shrinkage` should contain a shrinkage factor for each class"
+    with pytest.raises(ValueError, match=err_msg):
+        ros.fit_resample(X, y)
diff --git a/imblearn/over_sampling/tests/test_rose.py b/imblearn/over_sampling/tests/test_rose.py
deleted file mode 100644
index 42cafd2dc..000000000
--- a/imblearn/over_sampling/tests/test_rose.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Test the module ROSE."""
-# Authors: Andrea Lorenzon
-# License: MIT
-
-import numpy as np
-
-from imblearn.over_sampling import ROSE
-from sklearn.datasets import make_spd_matrix as SymPosDef
-from sklearn.utils._testing import assert_allclose
-from sklearn.utils._testing import assert_array_equal
-
-
-def test_rose():
-
-    """Check ROSE use"""
-
-    RND_SEED = 0
-
-    nCols = 3
-    ns = [50000, 50000, 75000]
-    # generate covariance matrices
-    cov1 = SymPosDef(nCols)
-    cov2 = SymPosDef(nCols)
-    cov3 = SymPosDef(nCols)
-
-    # generate data blobs
-    cl1 = np.array(np.random.multivariate_normal([1, 1, 1],
-                                                 cov=cov1,
-                                                 size=ns[0]))
-    cl2 = np.array(np.random.multivariate_normal([7, 7, 7],
-                                                 cov=cov2,
-                                                 size=ns[1]))
-    cl3 = np.array(np.random.multivariate_normal([2, 9, 9],
-                                                 cov=cov3,
-                                                 size=ns[2]))
-    # assemble dataset
-    X = np.vstack((cl1, cl2, cl3))
-    y = np.hstack((np.array([1] * ns[0]),
-                   np.array([2] * ns[1]),
-                   np.array([3] * ns[2])))
-
-    r = ROSE(random_state=RND_SEED)
-    res, lab = r.fit_resample(X, y)
-
-    # compute and check similarity of covariance matrices
-    res_cov1 = np.cov(res[lab == 1], rowvar=False)
-    res_cov2 = np.cov(res[lab == 2], rowvar=False)
-    res_cov3 = np.cov(res[lab == 3], rowvar=False)
-
-    assert res_cov1.shape == cov1.shape
-    assert res_cov2.shape == cov2.shape
-    assert res_cov3.shape == cov3.shape
-
-
-def test_rose_resampler():
-    """Test ROSE resampled data matches expectation."""
-
-    RND_SEED = 0
-    X = np.array(
-        [
-            [0.11622591, -0.0317206],
-            [0.77481731, 0.60935141],
-            [1.25192108, -0.22367336],
-            [0.53366841, -0.30312976],
-            [1.52091956, -0.49283504],
-            [-0.28162401, -2.10400981],
-            [0.83680821,
1.72827342], - [0.3084254, 0.33299982], - [0.70472253, -0.73309052], - [0.28893132, -0.38761769], - [1.15514042, 0.0129463], - [0.88407872, 0.35454207], - [1.31301027, -0.92648734], - [-1.11515198, -0.93689695], - [-0.18410027, -0.45194484], - [0.9281014, 0.53085498], - [-0.14374509, 0.27370049], - [-0.41635887, -0.38299653], - [0.08711622, 0.93259929], - [1.70580611, -0.11219234], - [1.39400832, 0.94383454], - [2.67881738, -0.36918919], - [1.80801323, -0.96629007], - [0.06244814, 0.07625536], - ] - ) - - y_gt = np.array( - [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, - 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0] - ) - assert_allclose(X_resampled, X_gt, rtol=R_TOL) - assert_array_equal(y_resampled, y_gt) diff --git a/maint_tools/test_docstring.py b/maint_tools/test_docstring.py index 7bf8424fb..c9f802c4a 100644 --- a/maint_tools/test_docstring.py +++ b/maint_tools/test_docstring.py @@ -58,9 +58,6 @@ "Pipeline.fit_transform", "Pipeline.fit_resample", "Pipeline.fit_predict", - "ROSE$", - "ROSE.", - "ROSE", "RUSBoostClassifier$", "RUSBoostClassifier.", "RandomOverSampler$", diff --git a/setup.cfg b/setup.cfg index d5cbdb514..f55a13f95 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,13 +2,13 @@ current_version = 0.7.0 tag = False parse = (?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))? -serialize = +serialize = {major}.{minor}.{patch}.{release}{dev} {major}.{minor}.{patch} [bumpversion:part:release] optional_value = gamma -values = +values = dev gamma @@ -21,7 +21,7 @@ test = pytest [tool:pytest] doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS -addopts = +addopts = --ignore build_tools --ignore benchmarks --ignore doc @@ -29,11 +29,14 @@ addopts = --ignore maint_tools --doctest-modules -rs -filterwarnings = +filterwarnings = ignore:the matrix subclass:PendingDeprecationWarning [flake8] max-line-length = 88 -ignore = - E203, # space before : (needed for how black formats slicing) - W503 # line break before binary operator \ No newline at end of file +# Default flake8 3.5 ignored flags +ignore=E121,E123,E126,E226,E24,E704,W503,W504,E203 +# It's fine not to put the import at the top of the file in the examples +# folder. +per-file-ignores = + examples/*: E402
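For reference, below is a minimal usage sketch of the option introduced by
this diff. It is a sketch under the assumption that the diff above is applied
as-is; the toy dataset and the printed values are illustrative only.

    from collections import Counter

    from sklearn.datasets import make_classification
    from imblearn.over_sampling import RandomOverSampler

    X, y = make_classification(
        n_samples=200, n_features=2, n_redundant=0, weights=[0.1, 0.9],
        random_state=0,
    )

    # plain bootstrap: minority samples are repeated verbatim
    ros = RandomOverSampler(random_state=0)
    X_res, y_res = ros.fit_resample(X, y)
    print(Counter(y_res))  # both classes now have the same number of samples

    # smoothed bootstrap (ROSE): each drawn sample is perturbed with Gaussian
    # noise scaled by the per-feature standard deviation, the AMISE-optimal
    # smoothing constant, and the `shrinkage` factor
    ros = RandomOverSampler(smoothed_bootstrap=True, shrinkage=0.5, random_state=0)
    X_res, y_res = ros.fit_resample(X, y)
    print(ros.shrinkage_)  # {0: 0.5}: only the over-sampled (minority) class

With the default `sampling_strategy="auto"`, `shrinkage_` only contains
entries for the classes that were over-sampled; a dictionary such as
`shrinkage={0: 0.5}` can be passed instead of a float to control each class
separately.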