diff --git a/doc/whats_new/v0.7.rst b/doc/whats_new/v0.7.rst index bca4a680d..4691c112c 100644 --- a/doc/whats_new/v0.7.rst +++ b/doc/whats_new/v0.7.rst @@ -54,6 +54,10 @@ Bug fixes the targeted class. :pr:`769` by :user:`Guillaume Lemaitre `. +- Fix a bug in :class:`imblearn.FunctionSampler` where validation was performed + even with `validate=False` when calling `fit`. + :pr:`790` by :user:`Guillaume Lemaitre `. + Enhancements ............ diff --git a/imblearn/base.py b/imblearn/base.py index 8933b22e6..92c3d02b6 100644 --- a/imblearn/base.py +++ b/imblearn/base.py @@ -220,6 +220,38 @@ def __init__(self, *, func=None, accept_sparse=True, kw_args=None, self.kw_args = kw_args self.validate = validate + def fit(self, X, y): + """Check inputs and statistics of the sampler. + + You should use ``fit_resample`` in all cases. + + Parameters + ---------- + X : {array-like, dataframe, sparse matrix} of shape \ + (n_samples, n_features) + Data array. + + y : array-like of shape (n_samples,) + Target array. + + Returns + ------- + self : object + Return the instance itself. + """ + # we need to overwrite SamplerMixin.fit to bypass the validation + if self.validate: + check_classification_targets(y) + X, y, _ = self._check_X_y( + X, y, accept_sparse=self.accept_sparse + ) + + self.sampling_strategy_ = check_sampling_strategy( + self.sampling_strategy, y, self._sampling_type + ) + + return self + def fit_resample(self, X, y): """Resample the dataset. diff --git a/imblearn/tests/test_base.py b/imblearn/tests/test_base.py index f10910564..9e84e1e72 100644 --- a/imblearn/tests/test_base.py +++ b/imblearn/tests/test_base.py @@ -94,3 +94,18 @@ def dummy_sampler(X, y): y_pred = pipeline.fit(X, y).predict(X) assert type_of_target(y_pred) == 'continuous' + + +def test_function_resampler_fit(): + # Check that the validation is bypass when calling `fit` + # Non-regression test for: + # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/782 + X = np.array([[1, np.nan], [2, 3], [np.inf, 4]]) + y = np.array([0, 1, 1]) + + def func(X, y): + return X[:1], y[:1] + + sampler = FunctionSampler(func=func, validate=False) + sampler.fit(X, y) + sampler.fit_resample(X, y)