99import pytest
1010import numpy as np
1111
12- from sklearn .base import BaseEstimator
1312from sklearn .neighbors ._base import KNeighborsMixin
1413from sklearn .neighbors import NearestNeighbors
1514from sklearn .utils ._testing import assert_array_equal
1615
17- from imblearn .utils .testing import warns
1816from imblearn .utils import check_neighbors_object
1917from imblearn .utils import check_sampling_strategy
2018from imblearn .utils import check_target_type
19+ from imblearn .utils .testing import warns , _CustomNearestNeighbors
2120from imblearn .utils ._validation import ArraysTransformer
2221from imblearn .utils ._validation import _deprecate_positional_args
2322
2423multiclass_target = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
2524binary_target = np .array ([1 ] * 25 + [0 ] * 100 )
2625
2726
28- class KNNLikeEstimator (BaseEstimator ):
29- """A class exposing the same KNeighborsMixin API than KNeighborsClassifier."""
30-
31- def kneighbors (self , X ):
32- return np .ones ((len (X ), 1 ))
33-
34- def kneighbors_graph (self , X ):
35- return np .ones ((len (X ), 1 ))
36-
37-
3827def test_check_neighbors_object ():
3928 name = "n_neighbors"
4029 n_neighbors = 1
@@ -47,9 +36,9 @@ def test_check_neighbors_object():
4736 estimator = NearestNeighbors (n_neighbors = n_neighbors )
4837 estimator_cloned = check_neighbors_object (name , estimator )
4938 assert estimator .n_neighbors == estimator_cloned .n_neighbors
50- estimator = KNNLikeEstimator ()
39+ estimator = _CustomNearestNeighbors ()
5140 estimator_cloned = check_neighbors_object (name , estimator )
52- assert isinstance (estimator_cloned , KNNLikeEstimator )
41+ assert isinstance (estimator_cloned , _CustomNearestNeighbors )
5342 n_neighbors = "rnd"
5443 err_msg = (
5544 "n_neighbors must be an interger or an object compatible with the "
0 commit comments