Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v0.7.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ Bug fixes
:class:`imblearn.over_sampling.SMOTENC`.
:pr:`675` by :user:`bganglia <bganglia>`.

- Fix a bug in :class:`imblearn.under_sampling.ClusterCentroids` where
`voting="hard"` could lead to selecting a sample from any class instead of
the targeted class.
:pr:`769` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,20 @@ def _fit_resample(self, X, y):

X_resampled, y_resampled = [], []
for target_class in np.unique(y):
target_class_indices = np.flatnonzero(y == target_class)
if target_class in self.sampling_strategy_.keys():
n_samples = self.sampling_strategy_[target_class]
self.estimator_.set_params(**{"n_clusters": n_samples})
self.estimator_.fit(X[y == target_class])
self.estimator_.fit(_safe_indexing(X, target_class_indices))
X_new, y_new = self._generate_sample(
X, y, self.estimator_.cluster_centers_, target_class
_safe_indexing(X, target_class_indices),
_safe_indexing(y, target_class_indices),
self.estimator_.cluster_centers_,
target_class,
)
X_resampled.append(X_new)
y_resampled.append(y_new)
else:
target_class_indices = np.flatnonzero(y == target_class)
X_resampled.append(_safe_indexing(X, target_class_indices))
y_resampled.append(_safe_indexing(y, target_class_indices))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy import sparse

from sklearn.cluster import KMeans
from sklearn.datasets import make_classification

from imblearn.under_sampling import ClusterCentroids

Expand Down Expand Up @@ -121,3 +122,37 @@ def test_cluster_centroids_n_jobs():
cc.fit_resample(X, Y)
assert len(record) == 1
assert "'n_jobs' was deprecated" in record[0].message.args[0]


def test_cluster_centroids_hard_target_class():
    """Check that hard voting only selects samples of the targeted class.

    Non-regression test for:
    https://github.com/scikit-learn-contrib/imbalanced-learn/issues/738
    """
    # Class 0 is the minority (weight 0.3) and class 1 the majority
    # (weight 0.7); the tiny ``class_sep`` keeps the two classes close
    # to each other.
    X, y = make_classification(
        n_samples=1000,
        n_features=2,
        n_informative=1,
        n_redundant=0,
        n_repeated=0,
        n_clusters_per_class=1,
        weights=[0.3, 0.7],
        class_sep=0.01,
        random_state=0,
    )

    sampler = ClusterCentroids(voting="hard", random_state=0)
    X_res, y_res = sampler.fit_resample(X, y)

    original_minority = X[y == 0]
    kept_majority = X_res[y_res == 1]

    # Compare every kept majority sample with every original minority
    # sample in one broadcast call: entry (i, j) is True when
    # kept_majority[i] matches original_minority[j] on all features.
    pairwise_match = np.isclose(
        kept_majority[:, np.newaxis, :],
        original_minority[np.newaxis, :, :],
    ).all(axis=-1)

    # No sample returned for the majority class may come from the minority.
    assert np.count_nonzero(pairwise_match) == 0