From 738d2ec860c39b2aeaa0107528e78cf5db589636 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 5 Dec 2022 13:59:14 +0100
Subject: [PATCH 1/2] MAINT be more inclusive regarding dict

---
 imblearn/over_sampling/base.py  | 3 ++-
 imblearn/under_sampling/base.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/imblearn/over_sampling/base.py b/imblearn/over_sampling/base.py
index 4bc08e91a..d4e4a4541 100644
--- a/imblearn/over_sampling/base.py
+++ b/imblearn/over_sampling/base.py
@@ -6,6 +6,7 @@
 # License: MIT
 
 import numbers
+from collections.abc import Mapping
 
 from ..base import BaseSampler
 from ..utils._param_validation import Interval, StrOptions
@@ -61,7 +62,7 @@ class BaseOverSampler(BaseSampler):
         "sampling_strategy": [
             Interval(numbers.Real, 0, 1, closed="right"),
             StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
-            dict,
+            Mapping,
             callable,
         ],
         "random_state": ["random_state"],
diff --git a/imblearn/under_sampling/base.py b/imblearn/under_sampling/base.py
index e36d8c31f..92da45723 100644
--- a/imblearn/under_sampling/base.py
+++ b/imblearn/under_sampling/base.py
@@ -5,6 +5,7 @@
 # License: MIT
 
 import numbers
+from collections.abc import Mapping
 
 from ..base import BaseSampler
 from ..utils._param_validation import Interval, StrOptions
@@ -61,7 +62,7 @@ class BaseUnderSampler(BaseSampler):
         "sampling_strategy": [
             Interval(numbers.Real, 0, 1, closed="right"),
             StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
-            dict,
+            Mapping,
             callable,
         ],
     }

From 15592a35850bd647913b64c1d9f4e9e8b0566bc3 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 5 Dec 2022 14:32:13 +0100
Subject: [PATCH 2/2] TST non-regression test

---
 imblearn/tests/test_common.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/imblearn/tests/test_common.py b/imblearn/tests/test_common.py
index 9ec5764d3..0f0b06494 100644
--- a/imblearn/tests/test_common.py
+++ b/imblearn/tests/test_common.py
@@ -3,6 +3,9 @@
 #          Christos Aridas
 # License: MIT
 
+from collections import OrderedDict
+
+import numpy as np
 import pytest
 from sklearn.base import clone
 from sklearn.exceptions import ConvergenceWarning
@@ -12,7 +15,8 @@
     parametrize_with_checks as parametrize_with_checks_sklearn,
 )
 
-from imblearn.under_sampling import NearMiss
+from imblearn.over_sampling import RandomOverSampler
+from imblearn.under_sampling import NearMiss, RandomUnderSampler
 from imblearn.utils.estimator_checks import (
     _set_checking_parameters,
     check_param_validation,
@@ -73,3 +77,19 @@ def test_check_param_validation(estimator):
     print(name)
     _set_checking_parameters(estimator)
     check_param_validation(name, estimator)
+
+
+@pytest.mark.parametrize("Sampler", [RandomOverSampler, RandomUnderSampler])
+def test_strategy_as_ordered_dict(Sampler):
+    """Non-regression test: an `OrderedDict` is accepted as `sampling_strategy`."""
+    rng = np.random.RandomState(42)
+    X, y = rng.randn(30, 2), np.array([0] * 10 + [1] * 20)  # 10 vs. 20 labels
+    sampler = Sampler(random_state=42)
+    if isinstance(sampler, RandomOverSampler):
+        strategy = OrderedDict({0: 20, 1: 20})  # 20 samples requested per class
+    else:
+        strategy = OrderedDict({0: 10, 1: 10})  # 10 samples requested per class
+    sampler.set_params(sampling_strategy=strategy)
+    X_res, y_res = sampler.fit_resample(X, y)
+    assert X_res.shape[0] == sum(strategy.values())  # total of requested counts
+    assert y_res.shape[0] == sum(strategy.values())