Skip to content

Commit 964d082

Browse files
committed
add test no dependent on cupy
1 parent d815e2d commit 964d082

File tree

4 files changed

+123
-58
lines changed

4 files changed

+123
-58
lines changed

imblearn/over_sampling/_smote/filter.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ def _validate_estimator(self):
154154
self.nn_m_ = check_neighbors_object(
155155
"m_neighbors", self.m_neighbors, additional_neighbor=1
156156
)
157-
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
158157
if self.kind not in ("borderline-1", "borderline-2"):
159158
raise ValueError(
160159
f'The possible "kind" of algorithm are '
@@ -382,7 +381,6 @@ def _validate_estimator(self):
382381
self.nn_m_ = check_neighbors_object(
383382
"m_neighbors", self.m_neighbors, additional_neighbor=1
384383
)
385-
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
386384

387385
if self.svm_estimator is None:
388386
self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from collections import Counter
2+
3+
import pytest
4+
import numpy as np
5+
6+
from imblearn.over_sampling import (
7+
BorderlineSMOTE,
8+
KMeansSMOTE,
9+
SMOTE,
10+
SMOTEN,
11+
SMOTENC,
12+
SVMSMOTE,
13+
)
14+
from imblearn.utils.testing import CustomNearestNeighbors
15+
16+
17+
@pytest.fixture
18+
def numerical_data():
19+
rng = np.random.RandomState(0)
20+
X = rng.randn(100, 2)
21+
y = np.repeat([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0], 5)
22+
23+
return X, y
24+
25+
26+
@pytest.fixture
27+
def categorical_data():
28+
rng = np.random.RandomState(0)
29+
30+
feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
31+
feature_2 = ["A"] * 40 + ["B"] * 20
32+
feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
33+
X = np.array([feature_1, feature_2, feature_3], dtype=object).T
34+
rng.shuffle(X)
35+
y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
36+
y_labels = np.array(["not apple", "apple"], dtype=object)
37+
y = y_labels[y]
38+
return X, y
39+
40+
41+
@pytest.fixture
42+
def heterogeneous_data():
43+
rng = np.random.RandomState(42)
44+
X = np.empty((30, 4), dtype=object)
45+
X[:, :2] = rng.randn(30, 2)
46+
X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object)
47+
X[:, 3] = rng.randint(3, size=30)
48+
y = np.array([0] * 10 + [1] * 20)
49+
return X, y, [2, 3]
50+
51+
52+
@pytest.mark.parametrize(
53+
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
54+
)
55+
def test_smote_m_neighbors(numerical_data, smote):
56+
# check that m_neighbors is properly set. Regression test for:
57+
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
58+
X, y = numerical_data
59+
_ = smote.fit_resample(X, y)
60+
assert smote.nn_k_.n_neighbors == 6
61+
assert smote.nn_m_.n_neighbors == 11
62+
63+
64+
@pytest.mark.parametrize(
65+
"smote",
66+
[
67+
BorderlineSMOTE(random_state=0),
68+
KMeansSMOTE(random_state=1),
69+
SMOTE(random_state=0),
70+
SVMSMOTE(random_state=0),
71+
],
72+
ids=["borderline", "kmeans", "smote", "svm"],
73+
)
74+
def test_numerical_smote_k_custom_nn(numerical_data, smote):
75+
X, y = numerical_data
76+
smote.set_params(k_neighbors=CustomNearestNeighbors(n_neighbors=5))
77+
X_res, y_res = smote.fit_resample(X, y)
78+
79+
assert X_res.shape == (120, 2)
80+
assert Counter(y_res) == {0: 60, 1: 60}
81+
82+
83+
def test_categorical_smote_k_custom_nn(categorical_data):
84+
X, y = categorical_data
85+
smote = SMOTEN(k_neighbors=CustomNearestNeighbors(n_neighbors=5))
86+
X_res, y_res = smote.fit_resample(X, y)
87+
88+
assert X_res.shape == (80, 3)
89+
assert Counter(y_res) == {"apple": 40, "not apple": 40}
90+
91+
92+
def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
93+
X, y, categorical_features = heterogeneous_data
94+
smote = SMOTENC(
95+
categorical_features, k_neighbors=CustomNearestNeighbors(n_neighbors=5)
96+
)
97+
X_res, y_res = smote.fit_resample(X, y)
98+
99+
assert X_res.shape == (40, 4)
100+
assert Counter(y_res) == {0: 20, 1: 20}
101+
102+
103+
@pytest.mark.parametrize(
104+
"smote",
105+
[BorderlineSMOTE(random_state=0), SVMSMOTE(random_state=0)],
106+
ids=["borderline", "svm"],
107+
)
108+
def test_numerical_smote_extra_custom_nn(numerical_data, smote):
109+
X, y = numerical_data
110+
smote.set_params(m_neighbors=CustomNearestNeighbors(n_neighbors=5))
111+
X_res, y_res = smote.fit_resample(X, y)
112+
113+
assert X_res.shape == (120, 2)
114+
assert Counter(y_res) == {0: 60, 1: 60}

imblearn/over_sampling/_smote/tests/test_smote.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@
44
# License: MIT
55

66
import numpy as np
7-
import pytest
87

98
from sklearn.utils._testing import assert_allclose
109
from sklearn.utils._testing import assert_array_equal
1110
from sklearn.neighbors import NearestNeighbors
1211

1312
from imblearn.over_sampling import SMOTE
14-
from imblearn.over_sampling import SVMSMOTE
15-
from imblearn.over_sampling import BorderlineSMOTE
1613

1714

1815
RND_SEED = 0
@@ -153,54 +150,3 @@ def test_sample_regular_with_nn():
153150
)
154151
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
155152
assert_array_equal(y_resampled, y_gt)
156-
157-
158-
@pytest.mark.parametrize(
159-
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
160-
)
161-
def test_smote_m_neighbors(smote):
162-
# check that m_neighbors is properly set. Regression test for:
163-
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
164-
_ = smote.fit_resample(X, Y)
165-
assert smote.nn_k_.n_neighbors == 6
166-
assert smote.nn_m_.n_neighbors == 11
167-
168-
169-
def test_sample_cuml_with_nn():
170-
cuml = pytest.importorskip("cuml")
171-
nn_k = cuml.neighbors.NearestNeighbors(n_neighbors=2)
172-
smote = SMOTE(random_state=RND_SEED, k_neighbors=nn_k)
173-
X_resampled, y_resampled = smote.fit_resample(X, Y)
174-
X_gt = np.array(
175-
[
176-
[0.11622591, -0.0317206],
177-
[0.77481731, 0.60935141],
178-
[1.25192108, -0.22367336],
179-
[0.53366841, -0.30312976],
180-
[1.52091956, -0.49283504],
181-
[-0.28162401, -2.10400981],
182-
[0.83680821, 1.72827342],
183-
[0.3084254, 0.33299982],
184-
[0.70472253, -0.73309052],
185-
[0.28893132, -0.38761769],
186-
[1.15514042, 0.0129463],
187-
[0.88407872, 0.35454207],
188-
[1.31301027, -0.92648734],
189-
[-1.11515198, -0.93689695],
190-
[-0.18410027, -0.45194484],
191-
[0.9281014, 0.53085498],
192-
[-0.14374509, 0.27370049],
193-
[-0.41635887, -0.38299653],
194-
[0.08711622, 0.93259929],
195-
[1.70580611, -0.11219234],
196-
[1.10580062, 0.00601499],
197-
[1.60506454, -0.31959815],
198-
[1.40109204, -0.74276846],
199-
[0.38584956, -0.20702218],
200-
]
201-
)
202-
y_gt = np.array(
203-
[0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
204-
)
205-
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
206-
assert_array_equal(y_resampled, y_gt)

imblearn/utils/testing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pathlib import Path
1414
from re import compile
1515

16+
from scipy import sparse
1617
from pytest import warns as _warns
1718

1819
from sklearn.base import BaseEstimator
@@ -168,17 +169,23 @@ def warns(expected_warning, match=None):
168169

169170

170171
class CustomNearestNeighbors(BaseEstimator):
171-
"""Basic implementation of nearest neighbors not relying on scikit-learn."""
172+
"""Basic implementation of nearest neighbors not relying on scikit-learn.
172173
173-
def __init__(self, n_neighbors=1):
174+
`kneighbors_graph` is ignored and `metric` does not have any impact.
175+
"""
176+
177+
def __init__(self, n_neighbors=1, metric="euclidean"):
174178
self.n_neighbors = n_neighbors
179+
self.metric = metric
175180

176181
def fit(self, X, y=None):
182+
X = X.toarray() if sparse.issparse(X) else X
177183
self._kd_tree = KDTree(X)
178184
return self
179185

180186
def kneighbors(self, X, n_neighbors=None, return_distance=True):
181187
n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
188+
X = X.toarray() if sparse.issparse(X) else X
182189
distances, indices = self._kd_tree.query(X, k=n_neighbors)
183190
if return_distance:
184191
return distances, indices

0 commit comments

Comments
 (0)