|
11 | 11 |
|
12 | 12 | import numpy as np |
13 | 13 | from scipy import sparse |
| 14 | +from scipy import stats |
14 | 15 |
|
15 | 16 | from sklearn.base import clone |
16 | 17 | from sklearn.cluster import MiniBatchKMeans |
17 | 18 | from sklearn.metrics import pairwise_distances |
18 | | -from sklearn.preprocessing import OneHotEncoder |
| 19 | +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder |
19 | 20 | from sklearn.svm import SVC |
20 | 21 | from sklearn.utils import check_random_state |
21 | 22 | from sklearn.utils import _safe_indexing |
|
25 | 26 |
|
26 | 27 | from .base import BaseOverSampler |
27 | 28 | from ..exceptions import raise_isinstance_error |
| 29 | +from ..metrics.pairwise import ValueDifferenceMetric |
28 | 30 | from ..utils import check_neighbors_object |
29 | 31 | from ..utils import check_target_type |
30 | 32 | from ..utils import Substitution |
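
The new `ValueDifferenceMetric` import is the distance used later to compare categorical samples: two values of a feature are considered close when the class labels are distributed similarly among the samples carrying each value. The snippet below is only a rough, self-contained sketch of that idea for a single column; it is not the `imblearn.metrics.pairwise.ValueDifferenceMetric` implementation, and the exponent `k=1` and the helper name are assumptions for illustration.

```python
import numpy as np

def vdm_single_feature(col, y, k=1):
    """Sketch of the Value Difference Metric for one categorical column.

    dist(v1, v2) = sum_c |P(c | v1) - P(c | v2)| ** k, where P(c | v) is the
    empirical probability of class c among samples whose feature value is v.
    """
    values, classes = np.unique(col), np.unique(y)
    # conditional class distribution P(c | v) for every feature value v
    p = np.array([[np.mean(y[col == v] == c) for c in classes] for v in values])
    # pairwise distance between every pair of feature values
    dist = (np.abs(p[:, None, :] - p[None, :, :]) ** k).sum(axis=-1)
    return values, dist

col = np.array(["red", "red", "green", "blue", "green", "blue"])
y = np.array([0, 0, 0, 1, 1, 1])
values, dist = vdm_single_feature(col, y)
# "red" only co-occurs with class 0 and "blue" only with class 1, so they end
# up maximally far apart; "green" sits in between.
print(values)
print(dist)
```
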
@@ -1293,3 +1295,67 @@ def _fit_resample(self, X, y): |
1293 | 1295 | y_resampled = np.hstack((y_resampled, y_new)) |
1294 | 1296 |
|
1295 | 1297 | return X_resampled, y_resampled |
| 1298 | + |
| 1299 | + |
| 1300 | +class SMOTEN(SMOTE): |
| 1301 | + def _check_X_y(self, X, y): |
| 1302 | + y, binarize_y = check_target_type(y, indicate_one_vs_all=True) |
| 1303 | + X, y = self._validate_data( |
| 1304 | + X, y, reset=True, dtype=None, accept_sparse=["csr", "csc"] |
| 1305 | + ) |
| 1306 | + return X, y, binarize_y |
| 1307 | + |
| 1308 | + def _validate_estimator(self): |
| 1309 | + super()._validate_estimator() |
| 1310 | + self.nn_k_.set_params(metric="precomputed") |
| 1311 | + |
| 1312 | + def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples): |
| 1313 | + random_state = check_random_state(self.random_state) |
| 1314 | + # randomly pick, with replacement, the existing samples that seed each new sample
| 1315 | + samples_indices = random_state.choice( |
| 1316 | + np.arange(X_class.shape[0]), size=n_samples, replace=True |
| 1317 | + ) |
| 1318 | + X_new = np.empty(shape=(n_samples, X_class.shape[1]), dtype=X_class.dtype) |
| 1319 | + for idx, sample_idx in enumerate(samples_indices): |
| 1320 | + X_new[idx, :] = stats.mode(X_class[nn_indices[sample_idx]], axis=0).mode |
| 1321 | + y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype) |
| 1322 | + return X_new, y_new |
| 1323 | + |
| 1324 | + def _fit_resample(self, X, y): |
| 1325 | + self._validate_estimator() |
| 1326 | + |
| 1327 | + X_resampled = [X.copy()] |
| 1328 | + y_resampled = [y.copy()] |
| 1329 | + |
| 1330 | + encoder = OrdinalEncoder(dtype=np.int32) |
| 1331 | + X_encoded = encoder.fit_transform(X) |
| 1332 | + |
| 1333 | + vdm = ValueDifferenceMetric( |
| 1334 | + n_categories=[len(cat) for cat in encoder.categories_] |
| 1335 | + ).fit(X_encoded, y) |
| 1336 | + |
| 1337 | + for class_sample, n_samples in self.sampling_strategy_.items(): |
| 1338 | + if n_samples == 0: |
| 1339 | + continue |
| 1340 | + target_class_indices = np.flatnonzero(y == class_sample) |
| 1341 | + X_class = _safe_indexing(X_encoded, target_class_indices) |
| 1342 | + |
| 1343 | + X_class_dist = vdm.pairwise(X_class) |
| 1344 | + self.nn_k_.fit(X_class_dist) |
| 1345 | + # the query on the precomputed distances contains the point itself
| 1346 | + nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False) |
| 1347 | + X_new, y_new = self._make_samples( |
| 1348 | + X_class, class_sample, y.dtype, nn_indices, n_samples |
| 1349 | + ) |
| 1350 | + |
| 1351 | + X_new = encoder.inverse_transform(X_new) |
| 1352 | + X_resampled.append(X_new) |
| 1353 | + y_resampled.append(y_new) |
| 1354 | + |
| 1355 | + if sparse.issparse(X): |
| 1356 | + X_resampled = sparse.vstack(X_resampled, format=X.format) |
| 1357 | + else: |
| 1358 | + X_resampled = np.vstack(X_resampled) |
| 1359 | + y_resampled = np.hstack(y_resampled) |
| 1360 | + |
| 1361 | + return X_resampled, y_resampled |
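
A minimal usage sketch of the new sampler follows. It assumes `SMOTEN` is re-exported as `imblearn.over_sampling.SMOTEN` and that it keeps `SMOTE`'s constructor (`k_neighbors=5` by default, `random_state`) and the public `fit_resample` from the over-sampler base class; the toy data below are made up.

```python
import numpy as np
from imblearn.over_sampling import SMOTEN  # assumed export location for the new class

# hypothetical toy data: one categorical feature, imbalanced string target
X = np.array(["green"] * 8 + ["red"] * 16 + ["blue"] * 16, dtype=object).reshape(-1, 1)
y = np.array(["minority"] * 8 + ["majority"] * 32)

# k_neighbors=5 is inherited from SMOTE, so the minority class needs at least
# 6 samples for the nearest-neighbor query on the precomputed distances to succeed
sampler = SMOTEN(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
print(np.unique(y_res, return_counts=True))  # both classes balanced after resampling
```
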