2525from sklearn .cluster import KMeans
2626from sklearn .exceptions import SkipTestWarning
2727from sklearn .preprocessing import label_binarize
28- from sklearn .utils .estimator_checks import _mark_xfail_checks
29- from sklearn .utils .estimator_checks import _set_check_estimator_ids
28+ from sklearn .utils .estimator_checks import _maybe_mark_xfail
29+ from sklearn .utils .estimator_checks import _get_check_estimator_ids
3030from sklearn .utils ._testing import assert_allclose
3131from sklearn .utils ._testing import assert_raises_regex
3232from sklearn .utils .multiclass import type_of_target
@@ -44,7 +44,7 @@ def _set_checking_parameters(estimator):
4444 if name == "ClusterCentroids" :
4545 estimator .set_params (
4646 voting = "soft" ,
47- estimator = KMeans (random_state = 0 , algorithm = "full" ),
47+ estimator = KMeans (random_state = 0 , algorithm = "full" , n_init = 1 ),
4848 )
4949 if name == "KMeansSMOTE" :
5050 estimator .set_params (kmeans_estimator = 12 )
@@ -117,21 +117,19 @@ def parametrize_with_checks(estimators):
117117 ... def test_sklearn_compatible_estimator(estimator, check):
118118 ... check(estimator)
119119 """
120- names = (type (estimator ).__name__ for estimator in estimators )
120+ def checks_generator ():
121+ for estimator in estimators :
122+ name = type (estimator ).__name__
123+ for check in _yield_all_checks (estimator ):
124+ check = partial (check , name )
125+ yield _maybe_mark_xfail (estimator , check , pytest )
121126
122- checks_generator = ((clone (estimator ), partial (check , name ))
123- for name , estimator in zip (names , estimators )
124- for check in _yield_all_checks (estimator ))
127+ return pytest .mark .parametrize ("estimator, check" , checks_generator (),
128+ ids = _get_check_estimator_ids )
125129
126- checks_with_marks = (
127- _mark_xfail_checks (estimator , check , pytest )
128- for estimator , check in checks_generator )
129130
130- return pytest .mark .parametrize ("estimator, check" , checks_with_marks ,
131- ids = _set_check_estimator_ids )
132-
133-
134- def check_target_type (name , estimator ):
131+ def check_target_type (name , estimator_orig ):
132+ estimator = clone (estimator_orig )
135133 # should raise warning if the target is continuous (we cannot raise error)
136134 X = np .random .random ((20 , 2 ))
137135 y = np .linspace (0 , 1 , 20 )
@@ -148,7 +146,8 @@ def check_target_type(name, estimator):
148146 )
149147
150148
151- def check_samplers_one_label (name , sampler ):
149+ def check_samplers_one_label (name , sampler_orig ):
150+ sampler = clone (sampler_orig )
152151 error_string_fit = "Sampler can't balance when only one class is present."
153152 X = np .random .random ((20 , 2 ))
154153 y = np .zeros (20 )
@@ -168,7 +167,8 @@ def check_samplers_one_label(name, sampler):
168167 raise AssertionError (error_string_fit )
169168
170169
171- def check_samplers_fit (name , sampler ):
170+ def check_samplers_fit (name , sampler_orig ):
171+ sampler = clone (sampler_orig )
172172 np .random .seed (42 ) # Make this test reproducible
173173 X = np .random .random ((30 , 2 ))
174174 y = np .array ([1 ] * 20 + [0 ] * 10 )
@@ -178,7 +178,8 @@ def check_samplers_fit(name, sampler):
178178 ), "No fitted attribute sampling_strategy_"
179179
180180
181- def check_samplers_fit_resample (name , sampler ):
181+ def check_samplers_fit_resample (name , sampler_orig ):
182+ sampler = clone (sampler_orig )
182183 X , y = make_classification (
183184 n_samples = 1000 ,
184185 n_classes = 3 ,
@@ -213,7 +214,8 @@ def check_samplers_fit_resample(name, sampler):
213214 )
214215
215216
216- def check_samplers_sampling_strategy_fit_resample (name , sampler ):
217+ def check_samplers_sampling_strategy_fit_resample (name , sampler_orig ):
218+ sampler = clone (sampler_orig )
217219 # in this test we will force all samplers to not change the class 1
218220 X , y = make_classification (
219221 n_samples = 1000 ,
@@ -240,7 +242,8 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler):
240242 assert Counter (y_res )[1 ] == expected_stat
241243
242244
243- def check_samplers_sparse (name , sampler ):
245+ def check_samplers_sparse (name , sampler_orig ):
246+ sampler = clone (sampler_orig )
244247 # check that sparse matrices can be passed through the sampler leading to
245248 # the same results as dense
246249 X , y = make_classification (
@@ -252,14 +255,16 @@ def check_samplers_sparse(name, sampler):
252255 )
253256 X_sparse = sparse .csr_matrix (X )
254257 X_res_sparse , y_res_sparse = sampler .fit_resample (X_sparse , y )
258+ sampler = clone (sampler )
255259 X_res , y_res = sampler .fit_resample (X , y )
256260 assert sparse .issparse (X_res_sparse )
257- assert_allclose (X_res_sparse .A , X_res )
261+ assert_allclose (X_res_sparse .A , X_res , rtol = 1e-5 )
258262 assert_allclose (y_res_sparse , y_res )
259263
260264
261- def check_samplers_pandas (name , sampler ):
265+ def check_samplers_pandas (name , sampler_orig ):
262266 pd = pytest .importorskip ("pandas" )
267+ sampler = clone (sampler_orig )
263268 # Check that the samplers handle pandas dataframe and pandas series
264269 X , y = make_classification (
265270 n_samples = 1000 ,
@@ -290,7 +295,8 @@ def check_samplers_pandas(name, sampler):
290295 assert_allclose (y_res_s .to_numpy (), y_res )
291296
292297
293- def check_samplers_list (name , sampler ):
298+ def check_samplers_list (name , sampler_orig ):
299+ sampler = clone (sampler_orig )
294300 # Check that the samplers can handle simple lists
295301 X , y = make_classification (
296302 n_samples = 1000 ,
@@ -312,7 +318,8 @@ def check_samplers_list(name, sampler):
312318 assert_allclose (y_res , y_res_list )
313319
314320
315- def check_samplers_multiclass_ova (name , sampler ):
321+ def check_samplers_multiclass_ova (name , sampler_orig ):
322+ sampler = clone (sampler_orig )
316323 # Check that multiclass targets lead to the same results as OVA encoding
317324 X , y = make_classification (
318325 n_samples = 1000 ,
@@ -329,7 +336,8 @@ def check_samplers_multiclass_ova(name, sampler):
329336 assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
330337
331338
332- def check_samplers_2d_target (name , sampler ):
339+ def check_samplers_2d_target (name , sampler_orig ):
340+ sampler = clone (sampler_orig )
333341 X , y = make_classification (
334342 n_samples = 100 ,
335343 n_classes = 3 ,
@@ -342,7 +350,8 @@ def check_samplers_2d_target(name, sampler):
342350 sampler .fit_resample (X , y )
343351
344352
345- def check_samplers_preserve_dtype (name , sampler ):
353+ def check_samplers_preserve_dtype (name , sampler_orig ):
354+ sampler = clone (sampler_orig )
346355 X , y = make_classification (
347356 n_samples = 1000 ,
348357 n_classes = 3 ,
@@ -358,7 +367,8 @@ def check_samplers_preserve_dtype(name, sampler):
358367 assert y .dtype == y_res .dtype , "y dtype is not preserved"
359368
360369
361- def check_samplers_sample_indices (name , sampler ):
370+ def check_samplers_sample_indices (name , sampler_orig ):
371+ sampler = clone (sampler_orig )
362372 X , y = make_classification (
363373 n_samples = 1000 ,
364374 n_classes = 3 ,
@@ -374,17 +384,21 @@ def check_samplers_sample_indices(name, sampler):
374384 assert not hasattr (sampler , "sample_indices_" )
375385
376386
377- def check_classifier_on_multilabel_or_multioutput_targets (name , estimator ):
387+ def check_classifier_on_multilabel_or_multioutput_targets (
388+ name , estimator_orig
389+ ):
390+ estimator = clone (estimator_orig )
378391 X , y = make_multilabel_classification (n_samples = 30 )
379392 msg = "Multilabel and multioutput targets are not supported."
380393 with pytest .raises (ValueError , match = msg ):
381394 estimator .fit (X , y )
382395
383396
384- def check_classifiers_with_encoded_labels (name , classifier ):
397+ def check_classifiers_with_encoded_labels (name , classifier_orig ):
385398 # Non-regression test for #709
386399 # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
387400 pytest .importorskip ("pandas" )
401+ classifier = clone (classifier_orig )
388402 df , y = fetch_openml ("iris" , version = 1 , as_frame = True , return_X_y = True )
389403 df , y = make_imbalance (
390404 df , y , sampling_strategy = {
0 commit comments