diff --git a/imblearn/base.py b/imblearn/base.py index 650288169..cf7cbacb3 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -130,6 +130,9 @@ def _check_X_y(self, X, y, accept_sparse=None): X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse) return X, y, binarize_y + def _more_tags(self): + return {"X_types": ["2darray", "sparse", "dataframe"]} + def _identity(X, y): return X, y diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index 1801e258f..4d2795d7e 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -241,7 +241,7 @@ def _fit_resample(self, X, y): def _more_tags(self): return { - "X_types": ["2darray", "string"], + "X_types": ["2darray", "string", "sparse", "dataframe"], "sample_indices": True, "allow_nan": True, } diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index 858dc8c21..6a57659fb 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -108,7 +108,10 @@ def _fit_resample(self, X, y): index_target_class = slice(None) idx_under = np.concatenate( - (idx_under, np.flatnonzero(y == target_class)[index_target_class],), + ( + idx_under, + np.flatnonzero(y == target_class)[index_target_class], + ), axis=0, ) @@ -118,7 +121,7 @@ def _fit_resample(self, X, y): def _more_tags(self): return { - "X_types": ["2darray", "string"], + "X_types": ["2darray", "string", "sparse", "dataframe"], "sample_indices": True, "allow_nan": True, } diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index e8a72db62..d693e24ed 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -51,16 +51,22 @@ def _set_checking_parameters(estimator): def _yield_sampler_checks(sampler): + tags = sampler._get_tags() yield check_target_type yield check_samplers_one_label yield check_samplers_fit yield check_samplers_fit_resample yield check_samplers_sampling_strategy_fit_resample - yield check_samplers_sparse - yield check_samplers_pandas + if "sparse" in tags["X_types"]: + yield check_samplers_sparse + if "dataframe" in tags["X_types"]: + yield check_samplers_pandas yield check_samplers_list yield check_samplers_multiclass_ova yield check_samplers_preserve_dtype + # we don't filter samplers based on their tag here because we want to make + # sure that the fitted attribute does not exist if the tag is not + # stipulated yield check_samplers_sample_indices yield check_samplers_2d_target @@ -75,7 +81,8 @@ def _yield_all_checks(estimator): tags = estimator._get_tags() if tags["_skip_test"]: warnings.warn( - f"Explicit SKIP via _skip_test tag for estimator {name}.", SkipTestWarning, + f"Explicit SKIP via _skip_test tag for estimator {name}.", + SkipTestWarning, ) return # trigger our checks if this is a SamplerMixin @@ -116,6 +123,7 @@ def parametrize_with_checks(estimators): ... def test_sklearn_compatible_estimator(estimator, check): ... check(estimator) """ + def checks_generator(): for estimator in estimators: name = type(estimator).__name__ @@ -124,9 +132,7 @@ def checks_generator(): yield _maybe_mark_xfail(estimator, check, pytest) return pytest.mark.parametrize( - "estimator, check", - checks_generator(), - ids=_get_check_estimator_ids + "estimator, check", checks_generator(), ids=_get_check_estimator_ids ) @@ -137,14 +143,22 @@ def check_target_type(name, estimator_orig): y = np.linspace(0, 1, 20) msg = "Unknown label type: 'continuous'" assert_raises_regex( - ValueError, msg, estimator.fit_resample, X, y, + ValueError, + msg, + estimator.fit_resample, + X, + y, ) # if the target is multilabel then we should raise an error rng = np.random.RandomState(42) y = rng.randint(2, size=(20, 3)) msg = "Multilabel and multioutput targets are not supported." assert_raises_regex( - ValueError, msg, estimator.fit_resample, X, y, + ValueError, + msg, + estimator.fit_resample, + X, + y, ) @@ -385,9 +399,7 @@ def check_samplers_sample_indices(name, sampler_orig): assert not hasattr(sampler, "sample_indices_") -def check_classifier_on_multilabel_or_multioutput_targets( - name, estimator_orig -): +def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig): estimator = clone(estimator_orig) X, y = make_multilabel_classification(n_samples=30) msg = "Multilabel and multioutput targets are not supported."