|
3 | 3 | # Christos Aridas |
4 | 4 | # License: MIT |
5 | 5 |
|
| 6 | +from collections import OrderedDict |
| 7 | + |
| 8 | +import numpy as np |
6 | 9 | import pytest |
7 | 10 | from sklearn.base import clone |
8 | 11 | from sklearn.exceptions import ConvergenceWarning |
|
12 | 15 | parametrize_with_checks as parametrize_with_checks_sklearn, |
13 | 16 | ) |
14 | 17 |
|
15 | | -from imblearn.under_sampling import NearMiss |
| 18 | +from imblearn.over_sampling import RandomOverSampler |
| 19 | +from imblearn.under_sampling import NearMiss, RandomUnderSampler |
16 | 20 | from imblearn.utils.estimator_checks import ( |
17 | 21 | _set_checking_parameters, |
18 | 22 | check_param_validation, |
@@ -73,3 +77,19 @@ def test_check_param_validation(estimator): |
73 | 77 | print(name) |
74 | 78 | _set_checking_parameters(estimator) |
75 | 79 | check_param_validation(name, estimator) |
| 80 | + |
| 81 | + |
| 82 | +@pytest.mark.parametrize("Sampler", [RandomOverSampler, RandomUnderSampler]) |
| 83 | +def test_strategy_as_ordered_dict(Sampler): |
| 84 | + """Check that it is possible to pass an `OrderedDict` as strategy.""" |
| 85 | + rng = np.random.RandomState(42) |
| 86 | + X, y = rng.randn(30, 2), np.array([0] * 10 + [1] * 20) |
| 87 | + sampler = Sampler(random_state=42) |
| 88 | + if isinstance(sampler, RandomOverSampler): |
| 89 | + strategy = OrderedDict({0: 20, 1: 20}) |
| 90 | + else: |
| 91 | + strategy = OrderedDict({0: 10, 1: 10}) |
| 92 | + sampler.set_params(sampling_strategy=strategy) |
| 93 | + X_res, y_res = sampler.fit_resample(X, y) |
| 94 | + assert X_res.shape[0] == sum(strategy.values()) |
| 95 | + assert y_res.shape[0] == sum(strategy.values()) |
0 commit comments