Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def _check_X_y(self, X, y, accept_sparse=None):
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
return X, y, binarize_y

def _more_tags(self):
    """Return the additional estimator tags declared by this class.

    Returns
    -------
    tags : dict
        A single ``"X_types"`` entry listing the kinds of input the
        estimator declares for ``X``: ``"2darray"``, ``"sparse"`` and
        ``"dataframe"``.
    """
    accepted_inputs = ["2darray", "sparse", "dataframe"]
    return {"X_types": accepted_inputs}


def _identity(X, y):
    """Return both arguments untouched.

    Parameters
    ----------
    X : object
        First value, handed back as-is.
    y : object
        Second value, handed back as-is.

    Returns
    -------
    X, y : tuple
        The very same objects that were passed in, as a 2-tuple.
    """
    unchanged = (X, y)
    return unchanged
Expand Down
2 changes: 1 addition & 1 deletion imblearn/over_sampling/_random_over_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _fit_resample(self, X, y):

def _more_tags(self):
    """Return the additional estimator tags declared by this sampler.

    Returns
    -------
    tags : dict
        ``"X_types"`` lists the kinds of input declared for ``X``
        (``"2darray"``, ``"string"``, ``"sparse"`` and ``"dataframe"``);
        ``"sample_indices"`` and ``"allow_nan"`` are both set to ``True``.
    """
    return {
        "X_types": ["2darray", "string", "sparse", "dataframe"],
        "sample_indices": True,
        "allow_nan": True,
    }
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def _fit_resample(self, X, y):
index_target_class = slice(None)

idx_under = np.concatenate(
(idx_under, np.flatnonzero(y == target_class)[index_target_class],),
(
idx_under,
np.flatnonzero(y == target_class)[index_target_class],
),
axis=0,
)

Expand All @@ -118,7 +121,7 @@ def _fit_resample(self, X, y):

def _more_tags(self):
    """Return the additional estimator tags declared by this sampler.

    Returns
    -------
    tags : dict
        ``"X_types"`` lists the kinds of input declared for ``X``
        (``"2darray"``, ``"string"``, ``"sparse"`` and ``"dataframe"``);
        ``"sample_indices"`` and ``"allow_nan"`` are both set to ``True``.
    """
    return {
        "X_types": ["2darray", "string", "sparse", "dataframe"],
        "sample_indices": True,
        "allow_nan": True,
    }
34 changes: 23 additions & 11 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,22 @@ def _set_checking_parameters(estimator):


def _yield_sampler_checks(sampler):
tags = sampler._get_tags()
yield check_target_type
yield check_samplers_one_label
yield check_samplers_fit
yield check_samplers_fit_resample
yield check_samplers_sampling_strategy_fit_resample
yield check_samplers_sparse
yield check_samplers_pandas
if "sparse" in tags["X_types"]:
yield check_samplers_sparse
if "dataframe" in tags["X_types"]:
yield check_samplers_pandas
yield check_samplers_list
yield check_samplers_multiclass_ova
yield check_samplers_preserve_dtype
# we don't filter samplers based on their tag here because we want to make
# sure that the fitted attribute does not exist if the tag is not
# stipulated
yield check_samplers_sample_indices
yield check_samplers_2d_target

Expand All @@ -75,7 +81,8 @@ def _yield_all_checks(estimator):
tags = estimator._get_tags()
if tags["_skip_test"]:
warnings.warn(
f"Explicit SKIP via _skip_test tag for estimator {name}.", SkipTestWarning,
f"Explicit SKIP via _skip_test tag for estimator {name}.",
SkipTestWarning,
)
return
# trigger our checks if this is a SamplerMixin
Expand Down Expand Up @@ -116,6 +123,7 @@ def parametrize_with_checks(estimators):
... def test_sklearn_compatible_estimator(estimator, check):
... check(estimator)
"""

def checks_generator():
for estimator in estimators:
name = type(estimator).__name__
Expand All @@ -124,9 +132,7 @@ def checks_generator():
yield _maybe_mark_xfail(estimator, check, pytest)

return pytest.mark.parametrize(
"estimator, check",
checks_generator(),
ids=_get_check_estimator_ids
"estimator, check", checks_generator(), ids=_get_check_estimator_ids
)


Expand All @@ -137,14 +143,22 @@ def check_target_type(name, estimator_orig):
y = np.linspace(0, 1, 20)
msg = "Unknown label type: 'continuous'"
assert_raises_regex(
ValueError, msg, estimator.fit_resample, X, y,
ValueError,
msg,
estimator.fit_resample,
X,
y,
)
# if the target is multilabel then we should raise an error
rng = np.random.RandomState(42)
y = rng.randint(2, size=(20, 3))
msg = "Multilabel and multioutput targets are not supported."
assert_raises_regex(
ValueError, msg, estimator.fit_resample, X, y,
ValueError,
msg,
estimator.fit_resample,
X,
y,
)


Expand Down Expand Up @@ -385,9 +399,7 @@ def check_samplers_sample_indices(name, sampler_orig):
assert not hasattr(sampler, "sample_indices_")


def check_classifier_on_multilabel_or_multioutput_targets(
name, estimator_orig
):
def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
estimator = clone(estimator_orig)
X, y = make_multilabel_classification(n_samples=30)
msg = "Multilabel and multioutput targets are not supported."
Expand Down