Skip to content

Commit 836fabd

Browse files
committed
FEA implement SMOTEN
1 parent 6155658 commit 836fabd

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

imblearn/over_sampling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ._smote import KMeansSMOTE
1111
from ._smote import SVMSMOTE
1212
from ._smote import SMOTENC
13+
from ._smote import SMOTEN
1314

1415
__all__ = [
1516
"ADASYN",
@@ -19,4 +20,5 @@
1920
"BorderlineSMOTE",
2021
"SVMSMOTE",
2122
"SMOTENC",
23+
"SMOTEN",
2224
]

imblearn/over_sampling/_smote.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
import numpy as np
1313
from scipy import sparse
14+
from scipy import stats
1415

1516
from sklearn.base import clone
1617
from sklearn.cluster import MiniBatchKMeans
1718
from sklearn.metrics import pairwise_distances
18-
from sklearn.preprocessing import OneHotEncoder
19+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1920
from sklearn.svm import SVC
2021
from sklearn.utils import check_random_state
2122
from sklearn.utils import _safe_indexing
@@ -25,6 +26,7 @@
2526

2627
from .base import BaseOverSampler
2728
from ..exceptions import raise_isinstance_error
29+
from ..metrics.pairwise import ValueDifferenceMetric
2830
from ..utils import check_neighbors_object
2931
from ..utils import check_target_type
3032
from ..utils import Substitution
@@ -1293,3 +1295,67 @@ def _fit_resample(self, X, y):
12931295
y_resampled = np.hstack((y_resampled, y_new))
12941296

12951297
return X_resampled, y_resampled
1298+
1299+
1300+
class SMOTEN(SMOTE):
1301+
def _check_X_y(self, X, y):
1302+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
1303+
X, y = self._validate_data(
1304+
X, y, reset=True, dtype=None, accept_sparse=["csr", "csc"]
1305+
)
1306+
return X, y, binarize_y
1307+
1308+
def _validate_estimator(self):
1309+
super()._validate_estimator()
1310+
self.nn_k_.set_params(metric="precomputed")
1311+
1312+
def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
1313+
random_state = check_random_state(self.random_state)
1314+
# generate sample indices that will be used to generate new samples
1315+
samples_indices = random_state.choice(
1316+
np.arange(X_class.shape[0]), size=n_samples, replace=True
1317+
)
1318+
X_new = np.empty(shape=(n_samples, X_class.shape[1]), dtype=X_class.dtype)
1319+
for idx, sample_idx in enumerate(samples_indices):
1320+
X_new[idx, :] = stats.mode(X_class[nn_indices[sample_idx]], axis=0).mode
1321+
y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
1322+
return X_new, y_new
1323+
1324+
def _fit_resample(self, X, y):
1325+
self._validate_estimator()
1326+
1327+
X_resampled = [X.copy()]
1328+
y_resampled = [y.copy()]
1329+
1330+
encoder = OrdinalEncoder(dtype=np.int32)
1331+
X_encoded = encoder.fit_transform(X)
1332+
1333+
vdm = ValueDifferenceMetric(
1334+
n_categories=[len(cat) for cat in encoder.categories_]
1335+
).fit(X_encoded, y)
1336+
1337+
for class_sample, n_samples in self.sampling_strategy_.items():
1338+
if n_samples == 0:
1339+
continue
1340+
target_class_indices = np.flatnonzero(y == class_sample)
1341+
X_class = _safe_indexing(X_encoded, target_class_indices)
1342+
1343+
X_class_dist = vdm.pairwise(X_class)
1344+
self.nn_k_.fit(X_class_dist)
1345+
# should countain the point itself
1346+
nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False)
1347+
X_new, y_new = self._make_samples(
1348+
X_class, class_sample, y.dtype, nn_indices, n_samples
1349+
)
1350+
1351+
X_new = encoder.inverse_transform(X_new)
1352+
X_resampled.append(X_new)
1353+
y_resampled.append(y_new)
1354+
1355+
if sparse.issparse(X):
1356+
X_resampled = sparse.vstack(X_resampled, format=X.format)
1357+
else:
1358+
X_resampled = np.vstack(X_resampled)
1359+
y_resampled = np.hstack(y_resampled)
1360+
1361+
return X_resampled, y_resampled

0 commit comments

Comments
 (0)