Skip to content

Commit 7914e0c

Browse files
Refactor SMOTEN and SMOTENC to be more unified
1 parent c2fe8c2 commit 7914e0c

File tree

1 file changed

+25
-64
lines changed

1 file changed

+25
-64
lines changed

imblearn/over_sampling/_smote.py

Lines changed: 25 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,15 @@ def _fit_resample(self, X, y):
986986
self.n_features_ = X.shape[1]
987987
self._validate_estimator()
988988

989+
X_encoded = self._encode(X, y)
990+
991+
X_resampled, y_resampled = super(SMOTENC, self)._fit_resample(
992+
X_encoded, y)
993+
X_resampled = self._decode(X, X_resampled)
994+
995+
return X_resampled, y_resampled
996+
997+
def _encode(self, X, y):
989998
# compute the median of the standard deviation of the minority class
990999
target_stats = Counter(y)
9911000
class_minority = min(target_stats, key=target_stats.get)
@@ -1015,18 +1024,15 @@ def _fit_resample(self, X, y):
10151024
X_ohe = self.ohe_.fit_transform(
10161025
X_categorical.toarray() if sparse.issparse(X_categorical)
10171026
else X_categorical)
1018-
10191027
# we can replace the 1 entries of the categorical features with the
10201028
# median of the standard deviation. It will ensure that whenever
10211029
# distance is computed between 2 samples, the difference will be equal
10221030
# to the median of the standard deviation as in the original paper.
10231031
X_ohe.data = (np.ones_like(X_ohe.data, dtype=X_ohe.dtype) *
10241032
self.median_std_ / 2)
1025-
X_encoded = sparse.hstack((X_continuous, X_ohe), format='csr')
1026-
1027-
X_resampled, y_resampled = super(SMOTENC, self)._fit_resample(
1028-
X_encoded, y)
1033+
return sparse.hstack((X_continuous, X_ohe), format='csr')
10291034

1035+
def _decode(self, X, X_resampled):
10301036
# reverse the encoding of the categorical features
10311037
X_res_cat = X_resampled[:, self.continuous_features_.size:]
10321038
X_res_cat.data = np.ones_like(X_res_cat.data)
@@ -1055,8 +1061,7 @@ def _fit_resample(self, X, y):
10551061
X_resampled.indices = col_indices
10561062
else:
10571063
X_resampled = X_resampled[:, indices_reordered]
1058-
1059-
return X_resampled, y_resampled
1064+
return X_resampled
10601065

10611066
def _generate_sample(self, X, nn_data, nn_num, row, col, step):
10621067
"""Generate a synthetic sample with an additional step for the
@@ -1095,7 +1100,7 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
10951100
# @Substitution(
10961101
# sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
10971102
# random_state=_random_state_docstring)
1098-
class SMOTEN(SMOTE):
1103+
class SMOTEN(SMOTENC):
10991104
"""Synthetic Minority Over-sampling Technique for Nominal data
11001105
(SMOTE-N).
11011106
@@ -1212,73 +1217,29 @@ class SMOTEN(SMOTE):
12121217

12131218
def __init__(self, sampling_strategy='auto', kind='regular',
12141219
random_state=None, k_neighbors=5, n_jobs=1):
1215-
super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy,
1220+
super(SMOTEN, self).__init__(categorical_features=[],
1221+
sampling_strategy=sampling_strategy,
12161222
random_state=random_state,
12171223
k_neighbors=k_neighbors,
1218-
ratio=None,
1219-
n_jobs=n_jobs,
1220-
kind=kind)
1221-
1222-
@staticmethod
1223-
def _check_X_y(X, y):
1224-
"""Overwrite the checking to let pass some string for categorical
1225-
features.
1226-
"""
1227-
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
1228-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
1229-
return X, y, binarize_y
1224+
n_jobs=n_jobs)
12301225

12311226
def _validate_estimator(self):
1227+
self.categorical_features = np.asarray(range(self.n_features_))
1228+
self.continuous_features_ = np.asarray([])
12321229
super(SMOTEN, self)._validate_estimator()
12331230

1234-
def _fit_resample(self, X, y):
1235-
self.n_features_ = X.shape[1]
1236-
self._validate_estimator()
1231+
def _decode(self, X, X_resampled):
1232+
X_unstacked = self.ohe_.inverse_transform(X_resampled)
1233+
if sparse.issparse(X):
1234+
X_unstacked = sparse.csr_matrix(X_unstacked)
1235+
return X_unstacked
12371236

1237+
def _encode(self, X, y):
12381238
self.ohe_ = OneHotEncoder(sparse=True, handle_unknown='ignore',
12391239
dtype=np.float64)
12401240
# the input of the OneHotEncoder needs to be dense
1241-
X_ohe = self.ohe_.fit_transform(
1241+
return self.ohe_.fit_transform(
12421242
X.toarray() if sparse.issparse(X)
12431243
else X)
12441244

1245-
X_resampled, y_resampled = super(SMOTEN, self)._fit_resample(
1246-
X_ohe, y)
1247-
X_resampled = self.ohe_.inverse_transform(X_resampled)
1248-
1249-
if sparse.issparse(X):
1250-
X_resampled = sparse.csr_matrix(X_resampled)
1251-
1252-
return X_resampled, y_resampled
1253-
1254-
def _generate_sample(self, X, nn_data, nn_num, row, col, step):
1255-
"""Generate a synthetic sample with an additional step for the
1256-
categorical features.
12571245

1258-
Each new sample is generated the same way as in SMOTE. However, the
1259-
categorical features are mapped to the most frequent nearest neighbors
1260-
of the majority class.
1261-
"""
1262-
rng = check_random_state(self.random_state)
1263-
sample = super(SMOTEN, self)._generate_sample(X, nn_data, nn_num,
1264-
row, col, step)
1265-
# To avoid conversion and since there is only few samples used, we
1266-
# convert those samples to dense array.
1267-
sample = (sample.toarray().squeeze()
1268-
if sparse.issparse(sample) else sample)
1269-
all_neighbors = nn_data[nn_num[row]]
1270-
all_neighbors = (all_neighbors.toarray()
1271-
if sparse.issparse(all_neighbors) else all_neighbors)
1272-
1273-
categories_size = [cat.size for cat in self.ohe_.categories_]
1274-
1275-
for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1],
1276-
np.cumsum(categories_size)[1:]):
1277-
col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0)
1278-
# tie breaking argmax
1279-
col_sel = rng.choice(np.flatnonzero(
1280-
np.isclose(col_max, col_max.max())))
1281-
sample[start_idx:end_idx] = 0
1282-
sample[start_idx + col_sel] = 1
1283-
1284-
return sparse.csr_matrix(sample) if sparse.issparse(X) else sample

0 commit comments

Comments
 (0)