@@ -51,16 +51,22 @@ def _set_checking_parameters(estimator):
5151
5252
5353def _yield_sampler_checks (sampler ):
54+ tags = sampler ._get_tags ()
5455 yield check_target_type
5556 yield check_samplers_one_label
5657 yield check_samplers_fit
5758 yield check_samplers_fit_resample
5859 yield check_samplers_sampling_strategy_fit_resample
59- yield check_samplers_sparse
60- yield check_samplers_pandas
60+ if "sparse" in tags ["X_types" ]:
61+ yield check_samplers_sparse
62+ if "dataframe" in tags ["X_types" ]:
63+ yield check_samplers_pandas
6164 yield check_samplers_list
6265 yield check_samplers_multiclass_ova
6366 yield check_samplers_preserve_dtype
67+ # we don't filter samplers based on their tag here because we want to make
68+ # sure that the fitted attribute does not exist if the tag is not
69+ # stipulated
6470 yield check_samplers_sample_indices
6571 yield check_samplers_2d_target
6672
@@ -75,7 +81,8 @@ def _yield_all_checks(estimator):
7581 tags = estimator ._get_tags ()
7682 if tags ["_skip_test" ]:
7783 warnings .warn (
78- f"Explicit SKIP via _skip_test tag for estimator { name } ." , SkipTestWarning ,
84+ f"Explicit SKIP via _skip_test tag for estimator { name } ." ,
85+ SkipTestWarning ,
7986 )
8087 return
8188 # trigger our checks if this is a SamplerMixin
@@ -116,6 +123,7 @@ def parametrize_with_checks(estimators):
116123 ... def test_sklearn_compatible_estimator(estimator, check):
117124 ... check(estimator)
118125 """
126+
119127 def checks_generator ():
120128 for estimator in estimators :
121129 name = type (estimator ).__name__
@@ -124,9 +132,7 @@ def checks_generator():
124132 yield _maybe_mark_xfail (estimator , check , pytest )
125133
126134 return pytest .mark .parametrize (
127- "estimator, check" ,
128- checks_generator (),
129- ids = _get_check_estimator_ids
135+ "estimator, check" , checks_generator (), ids = _get_check_estimator_ids
130136 )
131137
132138
@@ -137,14 +143,22 @@ def check_target_type(name, estimator_orig):
137143 y = np .linspace (0 , 1 , 20 )
138144 msg = "Unknown label type: 'continuous'"
139145 assert_raises_regex (
140- ValueError , msg , estimator .fit_resample , X , y ,
146+ ValueError ,
147+ msg ,
148+ estimator .fit_resample ,
149+ X ,
150+ y ,
141151 )
142152 # if the target is multilabel then we should raise an error
143153 rng = np .random .RandomState (42 )
144154 y = rng .randint (2 , size = (20 , 3 ))
145155 msg = "Multilabel and multioutput targets are not supported."
146156 assert_raises_regex (
147- ValueError , msg , estimator .fit_resample , X , y ,
157+ ValueError ,
158+ msg ,
159+ estimator .fit_resample ,
160+ X ,
161+ y ,
148162 )
149163
150164
@@ -385,9 +399,7 @@ def check_samplers_sample_indices(name, sampler_orig):
385399 assert not hasattr (sampler , "sample_indices_" )
386400
387401
388- def check_classifier_on_multilabel_or_multioutput_targets (
389- name , estimator_orig
390- ):
402+ def check_classifier_on_multilabel_or_multioutput_targets (name , estimator_orig ):
391403 estimator = clone (estimator_orig )
392404 X , y = make_multilabel_classification (n_samples = 30 )
393405 msg = "Multilabel and multioutput targets are not supported."
0 commit comments