Skip to content

Commit 18b6057

Browse files
committed
iter
1 parent 48d1fd5 commit 18b6057

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from collections import Counter
2+
3+
import pytest
4+
import numpy as np
5+
6+
from imblearn.over_sampling import (
7+
ADASYN,
8+
BorderlineSMOTE,
9+
KMeansSMOTE,
10+
SMOTE,
11+
SMOTEN,
12+
SMOTENC,
13+
SVMSMOTE,
14+
)
15+
from imblearn.utils.testing import CustomNearestNeighbors
16+
17+
18+
@pytest.fixture
19+
def numerical_data():
20+
rng = np.random.RandomState(0)
21+
X = rng.randn(100, 2)
22+
y = np.repeat([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0], 5)
23+
24+
return X, y
25+
26+
27+
@pytest.fixture
28+
def categorical_data():
29+
rng = np.random.RandomState(0)
30+
31+
feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
32+
feature_2 = ["A"] * 40 + ["B"] * 20
33+
feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
34+
X = np.array([feature_1, feature_2, feature_3], dtype=object).T
35+
rng.shuffle(X)
36+
y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
37+
y_labels = np.array(["not apple", "apple"], dtype=object)
38+
y = y_labels[y]
39+
return X, y
40+
41+
42+
@pytest.fixture
43+
def heterogeneous_data():
44+
rng = np.random.RandomState(42)
45+
X = np.empty((30, 4), dtype=object)
46+
X[:, :2] = rng.randn(30, 2)
47+
X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object)
48+
X[:, 3] = rng.randint(3, size=30)
49+
y = np.array([0] * 10 + [1] * 20)
50+
return X, y, [2, 3]
51+
52+
53+
@pytest.mark.parametrize(
54+
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
55+
)
56+
def test_smote_m_neighbors(numerical_data, smote):
57+
# check that m_neighbors is properly set. Regression test for:
58+
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
59+
X, y = numerical_data
60+
_ = smote.fit_resample(X, y)
61+
assert smote.nn_k_.n_neighbors == 6
62+
assert smote.nn_m_.n_neighbors == 11
63+
64+
65+
@pytest.mark.parametrize(
66+
"smote, neighbor_estimator_name",
67+
[
68+
(ADASYN(random_state=0), "n_neighbors"),
69+
(BorderlineSMOTE(random_state=0), "k_neighbors"),
70+
(KMeansSMOTE(random_state=1), "k_neighbors"),
71+
(SMOTE(random_state=0), "k_neighbors"),
72+
(SVMSMOTE(random_state=0), "k_neighbors"),
73+
],
74+
ids=["adasyn", "borderline", "kmeans", "smote", "svm"],
75+
)
76+
def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name):
77+
X, y = numerical_data
78+
params = {
79+
neighbor_estimator_name: CustomNearestNeighbors(n_neighbors=5),
80+
}
81+
smote.set_params(**params)
82+
X_res, _ = smote.fit_resample(X, y)
83+
84+
assert X_res.shape[0] >= 120
85+
86+
87+
def test_categorical_smote_k_custom_nn(categorical_data):
88+
X, y = categorical_data
89+
smote = SMOTEN(k_neighbors=CustomNearestNeighbors(n_neighbors=5))
90+
X_res, y_res = smote.fit_resample(X, y)
91+
92+
assert X_res.shape == (80, 3)
93+
assert Counter(y_res) == {"apple": 40, "not apple": 40}
94+
95+
96+
def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
97+
X, y, categorical_features = heterogeneous_data
98+
smote = SMOTENC(
99+
categorical_features, k_neighbors=CustomNearestNeighbors(n_neighbors=5)
100+
)
101+
X_res, y_res = smote.fit_resample(X, y)
102+
103+
assert X_res.shape == (40, 4)
104+
assert Counter(y_res) == {0: 20, 1: 20}
105+
106+
107+
@pytest.mark.parametrize(
108+
"smote",
109+
[BorderlineSMOTE(random_state=0), SVMSMOTE(random_state=0)],
110+
ids=["borderline", "svm"],
111+
)
112+
def test_numerical_smote_extra_custom_nn(numerical_data, smote):
113+
X, y = numerical_data
114+
smote.set_params(m_neighbors=CustomNearestNeighbors(n_neighbors=5))
115+
X_res, y_res = smote.fit_resample(X, y)
116+
117+
assert X_res.shape == (120, 2)
118+
assert Counter(y_res) == {0: 60, 1: 60}

0 commit comments

Comments
 (0)