
Commit bc96223

Refactor SMOTEN and SMOTENC to be more unified
1 parent c2fe8c2 commit bc96223

File tree

1 file changed (+27 -63 lines)

imblearn/over_sampling/_smote.py

Lines changed: 27 additions & 63 deletions
@@ -950,11 +950,13 @@ class SMOTENC(SMOTE):
 
     """
 
-    def __init__(self, categorical_features, sampling_strategy='auto',
+    def __init__(self, categorical_features, sampling_strategy='auto', kind='regular',
                  random_state=None, k_neighbors=5, n_jobs=1):
         super(SMOTENC, self).__init__(sampling_strategy=sampling_strategy,
                                       random_state=random_state,
                                       k_neighbors=k_neighbors,
+                                      n_jobs=n_jobs,
+                                      kind=kind,
                                       ratio=None)
         self.categorical_features = categorical_features
 
@@ -986,6 +988,15 @@ def _fit_resample(self, X, y):
         self.n_features_ = X.shape[1]
         self._validate_estimator()
 
+        X_encoded = self._encode(X, y)
+
+        X_resampled, y_resampled = super(SMOTENC, self)._fit_resample(
+            X_encoded, y)
+        X_resampled = self._decode(X, X_resampled)
+
+        return X_resampled, y_resampled
+
+    def _encode(self, X, y):
         # compute the median of the standard deviation of the minority class
         target_stats = Counter(y)
         class_minority = min(target_stats, key=target_stats.get)
@@ -1015,18 +1026,15 @@ def _fit_resample(self, X, y):
         X_ohe = self.ohe_.fit_transform(
             X_categorical.toarray() if sparse.issparse(X_categorical)
             else X_categorical)
-
         # we can replace the 1 entries of the categorical features with the
         # median of the standard deviation. It will ensure that whenever
         # distance is computed between 2 samples, the difference will be equal
         # to the median of the standard deviation as in the original paper.
         X_ohe.data = (np.ones_like(X_ohe.data, dtype=X_ohe.dtype) *
                       self.median_std_ / 2)
-        X_encoded = sparse.hstack((X_continuous, X_ohe), format='csr')
-
-        X_resampled, y_resampled = super(SMOTENC, self)._fit_resample(
-            X_encoded, y)
+        return sparse.hstack((X_continuous, X_ohe), format='csr')
 
+    def _decode(self, X, X_resampled):
         # reverse the encoding of the categorical features
         X_res_cat = X_resampled[:, self.continuous_features_.size:]
         X_res_cat.data = np.ones_like(X_res_cat.data)
@@ -1055,8 +1063,7 @@ def _fit_resample(self, X, y):
             X_resampled.indices = col_indices
         else:
             X_resampled = X_resampled[:, indices_reordered]
-
-        return X_resampled, y_resampled
+        return X_resampled
 
     def _generate_sample(self, X, nn_data, nn_num, row, col, step):
         """Generate a synthetic sample with an additional steps for the
@@ -1095,7 +1102,7 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
 # @Substitution(
 #     sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
 #     random_state=_random_state_docstring)
-class SMOTEN(SMOTE):
+class SMOTEN(SMOTENC):
     """Synthetic Minority Over-sampling Technique for Nominal data
     (SMOTE-N).
 
@@ -1212,73 +1219,30 @@ class SMOTEN(SMOTE):
 
     def __init__(self, sampling_strategy='auto', kind='regular',
                  random_state=None, k_neighbors=5, n_jobs=1):
-        super(SMOTEN, self).__init__(sampling_strategy=sampling_strategy,
+        super(SMOTEN, self).__init__(categorical_features=[],
+                                     sampling_strategy=sampling_strategy,
                                      random_state=random_state,
                                      k_neighbors=k_neighbors,
-                                     ratio=None,
                                      n_jobs=n_jobs,
                                      kind=kind)
 
-    @staticmethod
-    def _check_X_y(X, y):
-        """Overwrite the checking to let pass some string for categorical
-        features.
-        """
-        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
-        return X, y, binarize_y
-
     def _validate_estimator(self):
+        self.categorical_features = np.asarray(range(self.n_features_))
+        self.continuous_features_ = np.asarray([])
         super(SMOTEN, self)._validate_estimator()
 
-    def _fit_resample(self, X, y):
-        self.n_features_ = X.shape[1]
-        self._validate_estimator()
+    def _decode(self, X, X_resampled):
+        X_unstacked = self.ohe_.inverse_transform(X_resampled)
+        if sparse.issparse(X):
+            X_unstacked = sparse.csr_matrix(X_unstacked)
+        return X_unstacked
 
+    def _encode(self, X, y):
         self.ohe_ = OneHotEncoder(sparse=True, handle_unknown='ignore',
                                   dtype=np.float64)
         # the input of the OneHotEncoder needs to be dense
-        X_ohe = self.ohe_.fit_transform(
+        return self.ohe_.fit_transform(
             X.toarray() if sparse.issparse(X)
             else X)
 
-        X_resampled, y_resampled = super(SMOTEN, self)._fit_resample(
-            X_ohe, y)
-        X_resampled = self.ohe_.inverse_transform(X_resampled)
-
-        if sparse.issparse(X):
-            X_resampled = sparse.csr_matrix(X_resampled)
-
-        return X_resampled, y_resampled
-
-    def _generate_sample(self, X, nn_data, nn_num, row, col, step):
-        """Generate a synthetic sample with an additional steps for the
-        categorical features.
 
-        Each new sample is generated the same way than in SMOTE. However, the
-        categorical features are mapped to the most frequent nearest neighbors
-        of the majority class.
-        """
-        rng = check_random_state(self.random_state)
-        sample = super(SMOTEN, self)._generate_sample(X, nn_data, nn_num,
-                                                      row, col, step)
-        # To avoid conversion and since there is only few samples used, we
-        # convert those samples to dense array.
-        sample = (sample.toarray().squeeze()
-                  if sparse.issparse(sample) else sample)
-        all_neighbors = nn_data[nn_num[row]]
-        all_neighbors = (all_neighbors.toarray()
-                         if sparse.issparse(all_neighbors) else all_neighbors)
-
-        categories_size = [cat.size for cat in self.ohe_.categories_]
-
-        for start_idx, end_idx in zip(np.cumsum(categories_size)[:-1],
-                                      np.cumsum(categories_size)[1:]):
-            col_max = all_neighbors[:, start_idx:end_idx].sum(axis=0)
-            # tie breaking argmax
-            col_sel = rng.choice(np.flatnonzero(
-                np.isclose(col_max, col_max.max())))
-            sample[start_idx:end_idx] = 0
-            sample[start_idx + col_sel] = 1
-
-        return sparse.csr_matrix(sample) if sparse.issparse(X) else sample
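The refactor gives SMOTENC._fit_resample a fixed template: encode the input, hand the encoded matrix to the shared SMOTE._fit_resample, then decode the result back to the original representation, so SMOTEN only overrides the _encode/_decode hooks instead of reimplementing the whole pipeline. A minimal standalone sketch of that shape (the class and method names below are illustrative, not part of imbalanced-learn, and the resampling step is a no-op stand-in):

# Illustrative encode -> resample -> decode template; hypothetical names.
class BaseResampler:
    def _fit_resample(self, X, y):
        # fixed skeleton: subclasses customise only the two hooks
        X_encoded = self._encode(X, y)
        X_resampled, y_resampled = self._resample(X_encoded, y)
        return self._decode(X, X_resampled), y_resampled

    def _resample(self, X, y):
        # stand-in for the shared SMOTE machinery
        return X, y

    def _encode(self, X, y):
        return X

    def _decode(self, X, X_resampled):
        return X_resampled


class OneHotResampler(BaseResampler):
    # analogue of SMOTEN for a single nominal column: only the hooks change
    def _encode(self, X, y):
        self.categories_ = sorted(set(X))
        return [[float(v == cat) for cat in self.categories_] for v in X]

    def _decode(self, X, X_resampled):
        # map each encoded row back to the label with the largest entry
        return [self.categories_[row.index(max(row))] for row in X_resampled]


X_res, y_res = OneHotResampler()._fit_resample(["red", "blue", "red"], [0, 1, 0])
# X_res == ["red", "blue", "red"] because _resample is a pass-through here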
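For reference, the public entry point left unchanged by this refactor is SMOTENC.fit_resample with the indices of the categorical columns. A minimal usage sketch against the released imbalanced-learn API (it does not assume the kind parameter or the SMOTEN class introduced in this branch):

import numpy as np
from imblearn.over_sampling import SMOTENC

rng = np.random.RandomState(0)

# toy imbalanced dataset: two continuous columns and one categorical column (index 2)
n_majority, n_minority = 30, 10
X = np.hstack([
    rng.randn(n_majority + n_minority, 2),
    rng.choice([0, 1, 2], size=(n_majority + n_minority, 1)),
])
y = np.array([0] * n_majority + [1] * n_minority)

smote_nc = SMOTENC(categorical_features=[2], random_state=0)
X_res, y_res = smote_nc.fit_resample(X, y)

# the minority class is oversampled and column 2 keeps only existing categories
print(sorted(np.unique(y_res, return_counts=True)[1]))  # e.g. [30, 30]
print(np.unique(X_res[:, 2]))                           # subset of {0, 1, 2}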

0 commit comments
