Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v0.7.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ Bug fixes
the targeted class.
:pr:`769` by :user:`Guillaume Lemaitre <glemaitre>`.

- Fix a bug in :class:`imblearn.FunctionSampler` where validation was performed
even with `validate=False` when calling `fit`.
:pr:`790` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

Expand Down
32 changes: 32 additions & 0 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,38 @@ def __init__(self, *, func=None, accept_sparse=True, kw_args=None,
self.kw_args = kw_args
self.validate = validate

def fit(self, X, y):
"""Check inputs and statistics of the sampler.

You should use ``fit_resample`` in all cases.

Parameters
----------
X : {array-like, dataframe, sparse matrix} of shape \
(n_samples, n_features)
Data array.

y : array-like of shape (n_samples,)
Target array.

Returns
-------
self : object
Return the instance itself.
"""
# we need to overwrite SamplerMixin.fit to bypass the validation
if self.validate:
check_classification_targets(y)
X, y, _ = self._check_X_y(
X, y, accept_sparse=self.accept_sparse
)

self.sampling_strategy_ = check_sampling_strategy(
self.sampling_strategy, y, self._sampling_type
)

return self

def fit_resample(self, X, y):
"""Resample the dataset.

Expand Down
15 changes: 15 additions & 0 deletions imblearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,18 @@ def dummy_sampler(X, y):
y_pred = pipeline.fit(X, y).predict(X)

assert type_of_target(y_pred) == 'continuous'


def test_function_resampler_fit():
# Check that the validation is bypass when calling `fit`
# Non-regression test for:
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/782
X = np.array([[1, np.nan], [2, 3], [np.inf, 4]])
y = np.array([0, 1, 1])

def func(X, y):
return X[:1], y[:1]

sampler = FunctionSampler(func=func, validate=False)
sampler.fit(X, y)
sampler.fit_resample(X, y)