diff --git a/doc/whats_new/v0.7.rst b/doc/whats_new/v0.7.rst index eb931a2e8..ab9fa3943 100644 --- a/doc/whats_new/v0.7.rst +++ b/doc/whats_new/v0.7.rst @@ -49,6 +49,11 @@ Bug fixes :class:`imblearn.over_sampling.SMOTENC`. :pr:`675` by :user:`bganglia `. +- Fix a bug in :class:`imblearn.under_sampling.ClusterCentroids` where + `voting="hard"` could have lead to select a sample from any class instead of + the targeted class. + :pr:`769` by :user:`Guillaume Lemaitre `. + Enhancements ............ diff --git a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py index ebdbed27b..90df4d471 100644 --- a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py @@ -166,17 +166,20 @@ def _fit_resample(self, X, y): X_resampled, y_resampled = [], [] for target_class in np.unique(y): + target_class_indices = np.flatnonzero(y == target_class) if target_class in self.sampling_strategy_.keys(): n_samples = self.sampling_strategy_[target_class] self.estimator_.set_params(**{"n_clusters": n_samples}) - self.estimator_.fit(X[y == target_class]) + self.estimator_.fit(_safe_indexing(X, target_class_indices)) X_new, y_new = self._generate_sample( - X, y, self.estimator_.cluster_centers_, target_class + _safe_indexing(X, target_class_indices), + _safe_indexing(y, target_class_indices), + self.estimator_.cluster_centers_, + target_class, ) X_resampled.append(X_new) y_resampled.append(y_new) else: - target_class_indices = np.flatnonzero(y == target_class) X_resampled.append(_safe_indexing(X, target_class_indices)) y_resampled.append(_safe_indexing(y, target_class_indices)) diff --git a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py index 6163afa64..aaffea261 100644 --- a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py @@ -6,6 +6,7 @@ from scipy import sparse from sklearn.cluster import KMeans +from sklearn.datasets import make_classification from imblearn.under_sampling import ClusterCentroids @@ -121,3 +122,37 @@ def test_cluster_centroids_n_jobs(): cc.fit_resample(X, Y) assert len(record) == 1 assert "'n_jobs' was deprecated" in record[0].message.args[0] + + +def test_cluster_centroids_hard_target_class(): + # check that the samples selecting by the hard voting corresponds to the + # targeted class + # non-regression test for: + # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/738 + X, y = make_classification( + n_samples=1000, + n_features=2, + n_informative=1, + n_redundant=0, + n_repeated=0, + n_clusters_per_class=1, + weights=[0.3, 0.7], + class_sep=0.01, + random_state=0, + ) + + cc = ClusterCentroids(voting="hard", random_state=0) + X_res, y_res = cc.fit_resample(X, y) + + minority_class_indices = np.flatnonzero(y == 0) + X_minority_class = X[minority_class_indices] + + resampled_majority_class_indices = np.flatnonzero(y_res == 1) + X_res_majority = X_res[resampled_majority_class_indices] + + sample_from_minority_in_majority = [ + np.all(np.isclose(selected_sample, minority_sample)) + for selected_sample in X_res_majority + for minority_sample in X_minority_class + ] + assert sum(sample_from_minority_in_majority) == 0