|
| 1 | +from collections import Counter |
| 2 | + |
| 3 | +import pytest |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from imblearn.over_sampling import ( |
| 7 | + ADASYN, |
| 8 | + BorderlineSMOTE, |
| 9 | + KMeansSMOTE, |
| 10 | + SMOTE, |
| 11 | + SMOTEN, |
| 12 | + SMOTENC, |
| 13 | + SVMSMOTE, |
| 14 | +) |
| 15 | +from imblearn.utils.testing import CustomNearestNeighbors |
| 16 | + |
| 17 | + |
| 18 | +@pytest.fixture |
| 19 | +def numerical_data(): |
| 20 | + rng = np.random.RandomState(0) |
| 21 | + X = rng.randn(100, 2) |
| 22 | + y = np.repeat([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0], 5) |
| 23 | + |
| 24 | + return X, y |
| 25 | + |
| 26 | + |
| 27 | +@pytest.fixture |
| 28 | +def categorical_data(): |
| 29 | + rng = np.random.RandomState(0) |
| 30 | + |
| 31 | + feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30 |
| 32 | + feature_2 = ["A"] * 40 + ["B"] * 20 |
| 33 | + feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10 |
| 34 | + X = np.array([feature_1, feature_2, feature_3], dtype=object).T |
| 35 | + rng.shuffle(X) |
| 36 | + y = np.array([0] * 20 + [1] * 40, dtype=np.int32) |
| 37 | + y_labels = np.array(["not apple", "apple"], dtype=object) |
| 38 | + y = y_labels[y] |
| 39 | + return X, y |
| 40 | + |
| 41 | + |
| 42 | +@pytest.fixture |
| 43 | +def heterogeneous_data(): |
| 44 | + rng = np.random.RandomState(42) |
| 45 | + X = np.empty((30, 4), dtype=object) |
| 46 | + X[:, :2] = rng.randn(30, 2) |
| 47 | + X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object) |
| 48 | + X[:, 3] = rng.randint(3, size=30) |
| 49 | + y = np.array([0] * 10 + [1] * 20) |
| 50 | + return X, y, [2, 3] |
| 51 | + |
| 52 | + |
| 53 | +@pytest.mark.parametrize( |
| 54 | + "smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"] |
| 55 | +) |
| 56 | +def test_smote_m_neighbors(numerical_data, smote): |
| 57 | + # check that m_neighbors is properly set. Regression test for: |
| 58 | + # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568 |
| 59 | + X, y = numerical_data |
| 60 | + _ = smote.fit_resample(X, y) |
| 61 | + assert smote.nn_k_.n_neighbors == 6 |
| 62 | + assert smote.nn_m_.n_neighbors == 11 |
| 63 | + |
| 64 | + |
| 65 | +@pytest.mark.parametrize( |
| 66 | + "smote, neighbor_estimator_name", |
| 67 | + [ |
| 68 | + (ADASYN(random_state=0), "n_neighbors"), |
| 69 | + (BorderlineSMOTE(random_state=0), "k_neighbors"), |
| 70 | + (KMeansSMOTE(random_state=1), "k_neighbors"), |
| 71 | + (SMOTE(random_state=0), "k_neighbors"), |
| 72 | + (SVMSMOTE(random_state=0), "k_neighbors"), |
| 73 | + ], |
| 74 | + ids=["adasyn", "borderline", "kmeans", "smote", "svm"], |
| 75 | +) |
| 76 | +def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name): |
| 77 | + X, y = numerical_data |
| 78 | + params = { |
| 79 | + neighbor_estimator_name: CustomNearestNeighbors(n_neighbors=5), |
| 80 | + } |
| 81 | + smote.set_params(**params) |
| 82 | + X_res, _ = smote.fit_resample(X, y) |
| 83 | + |
| 84 | + assert X_res.shape[0] >= 120 |
| 85 | + |
| 86 | + |
| 87 | +def test_categorical_smote_k_custom_nn(categorical_data): |
| 88 | + X, y = categorical_data |
| 89 | + smote = SMOTEN(k_neighbors=CustomNearestNeighbors(n_neighbors=5)) |
| 90 | + X_res, y_res = smote.fit_resample(X, y) |
| 91 | + |
| 92 | + assert X_res.shape == (80, 3) |
| 93 | + assert Counter(y_res) == {"apple": 40, "not apple": 40} |
| 94 | + |
| 95 | + |
| 96 | +def test_heterogeneous_smote_k_custom_nn(heterogeneous_data): |
| 97 | + X, y, categorical_features = heterogeneous_data |
| 98 | + smote = SMOTENC( |
| 99 | + categorical_features, k_neighbors=CustomNearestNeighbors(n_neighbors=5) |
| 100 | + ) |
| 101 | + X_res, y_res = smote.fit_resample(X, y) |
| 102 | + |
| 103 | + assert X_res.shape == (40, 4) |
| 104 | + assert Counter(y_res) == {0: 20, 1: 20} |
| 105 | + |
| 106 | + |
| 107 | +@pytest.mark.parametrize( |
| 108 | + "smote", |
| 109 | + [BorderlineSMOTE(random_state=0), SVMSMOTE(random_state=0)], |
| 110 | + ids=["borderline", "svm"], |
| 111 | +) |
| 112 | +def test_numerical_smote_extra_custom_nn(numerical_data, smote): |
| 113 | + X, y = numerical_data |
| 114 | + smote.set_params(m_neighbors=CustomNearestNeighbors(n_neighbors=5)) |
| 115 | + X_res, y_res = smote.fit_resample(X, y) |
| 116 | + |
| 117 | + assert X_res.shape == (120, 2) |
| 118 | + assert Counter(y_res) == {0: 60, 1: 60} |
0 commit comments