8 changes: 8 additions & 0 deletions doc/whats_new/v0.12.rst
@@ -18,6 +18,14 @@ Bug fixes
the number of samples in the minority class.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

Compatibility
.............

- :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now supports missing
  values and monotonic constraints if scikit-learn >= 1.4 is installed.
- :class:`~imblearn.pipeline.Pipeline` now supports metadata routing if
  scikit-learn >= 1.4 is installed.
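
A minimal sketch of the new missing-value support (requires scikit-learn >= 1.4;
the parameter choices mirror the new tests further below)::

    import numpy as np
    from imblearn.ensemble import BalancedRandomForestClassifier

    X = np.array([[1.0, 2.0], [np.nan, 3.0], [4.0, np.nan], [5.0, 6.0]] * 10)
    y = np.array([0, 0, 1, 1] * 10)

    clf = BalancedRandomForestClassifier(
        sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
    )
    clf.fit(X, y)  # no imputation step needed: NaN is handled by the trees
    print(clf.predict(X[:4]))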

Deprecations
............

1 change: 1 addition & 0 deletions imblearn/ensemble/_common.py
@@ -101,4 +101,5 @@ def check(self):
list,
None,
],
"monotonic_cst": ["array-like", None],
}
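
The new entry is declarative: scikit-learn's parameter validation reads it at
`fit` time and accepts an array-like (one constraint per feature) or `None` for
`monotonic_cst`. A hedged sketch of both paths, assuming scikit-learn >= 1.4 is
installed (the rejection comes as scikit-learn's `InvalidParameterError`, a
`ValueError` subclass):

    import numpy as np
    from imblearn.ensemble import BalancedRandomForestClassifier

    X, y = np.array([[0.0], [1.0], [2.0], [3.0]]), np.array([0, 0, 1, 1])
    common = dict(
        sampling_strategy="all", replacement=True, bootstrap=False,
        n_estimators=10, random_state=0,
    )

    # array-like with one entry per feature: accepted
    BalancedRandomForestClassifier(monotonic_cst=[1], **common).fit(X, y)

    # a scalar is neither array-like nor None: rejected when `fit` validates
    try:
        BalancedRandomForestClassifier(monotonic_cst=1, **common).fit(X, y)
    except ValueError as exc:
        print(exc)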
163 changes: 115 additions & 48 deletions imblearn/ensemble/_forest.py
@@ -60,6 +60,7 @@ def _local_parallel_build_trees(
class_weight=None,
n_samples_bootstrap=None,
forest=None,
missing_values_in_feature_mask=None,
):
# resample before to fit the tree
X_resampled, y_resampled = sampler.fit_resample(X, y)
@@ -68,33 +69,34 @@ def _local_parallel_build_trees(
if _get_n_samples_bootstrap is not None:
n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])

-    if sklearn_version >= parse_version("1.1"):
-        tree = _parallel_build_trees(
-            tree,
-            bootstrap,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
-    else:
-        # TODO: remove when the minimum version of scikit-learn supported is 1.1
-        tree = _parallel_build_trees(
-            tree,
-            forest,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
+    params_parallel_build_trees = {
+        "tree": tree,
+        "X": X_resampled,
+        "y": y_resampled,
+        "sample_weight": sample_weight,
+        "tree_idx": tree_idx,
+        "n_trees": n_trees,
+        "verbose": verbose,
+        "class_weight": class_weight,
+        "n_samples_bootstrap": n_samples_bootstrap,
+    }
+
+    if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        # support for missing values
+        params_parallel_build_trees[
+            "missing_values_in_feature_mask"
+        ] = missing_values_in_feature_mask
+
+    # TODO: remove when the minimum supported version of scikit-learn will be 1.1
+    # change of signature in scikit-learn 1.1
+    if parse_version(sklearn_version.base_version) >= parse_version("1.1"):
+        params_parallel_build_trees["bootstrap"] = bootstrap
+    else:
+        params_parallel_build_trees["forest"] = forest
+
+    tree = _parallel_build_trees(**params_parallel_build_trees)

return sampler, tree
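
The refactor above replaces two near-identical call sites with a single kwargs
dict that grows version-gated keys before one final call. A standalone sketch
of the pattern (the function and key names are illustrative, not part of
imbalanced-learn):

    # Build the kwargs once, add keys per installed version, call once.
    import sklearn
    from sklearn.utils.fixes import parse_version

    def gated_call(func):
        sklearn_version = parse_version(
            parse_version(sklearn.__version__).base_version
        )
        params = {"always_supported": 1}
        if sklearn_version >= parse_version("1.4"):
            params["only_on_recent_sklearn"] = 2  # hypothetical keyword
        return func(**params)

    print(gated_call(lambda **kwargs: sorted(kwargs)))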


@@ -305,6 +307,25 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
.. versionadded:: 0.6
Added in `scikit-learn` in 0.22

monotonic_cst : array-like of int of shape (n_features), default=None
Indicates the monotonicity constraint to enforce on each feature.
- 1: monotonic increase
- 0: no constraint
- -1: monotonic decrease

If monotonic_cst is None, no constraints are applied.

Monotonicity constraints are not supported for:
- multiclass classifications (i.e. when `n_classes > 2`),
- multioutput classifications (i.e. when `n_outputs_ > 1`),
- classifications trained on data with missing values.

The constraints hold over the probability of the positive class.

.. versionadded:: 0.12
Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
`ValueError` is raised.
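
A short sketch of the constraint's effect, assuming scikit-learn >= 1.4 is
installed: with `monotonic_cst=[1]`, every tree's estimate of the positive
class can only increase along feature 0, so forest probabilities evaluated on
a sorted grid are non-decreasing.

    import numpy as np
    from imblearn.ensemble import BalancedRandomForestClassifier

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(200, 1))
    y = (X[:, 0] + 0.1 * rng.standard_normal(200) > 0.5).astype(int)

    clf = BalancedRandomForestClassifier(
        monotonic_cst=[1],  # feature 0: monotonic increase
        sampling_strategy="all", replacement=True, bootstrap=False,
        random_state=0,
    ).fit(X, y)

    grid = np.linspace(0.0, 1.0, num=11).reshape(-1, 1)
    proba = clf.predict_proba(grid)[:, 1]
    assert np.all(np.diff(proba) >= -1e-12)  # non-decreasing, up to float fuzz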

Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
@@ -415,7 +436,7 @@ class labels (multi-output problem).
"""

# make a deepcopy to not modify the original dictionary
-    if sklearn_version >= parse_version("1.3"):
+    if sklearn_version >= parse_version("1.4"):
_parameter_constraints = deepcopy(RandomForestClassifier._parameter_constraints)
else:
_parameter_constraints = deepcopy(
@@ -459,27 +480,42 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
max_samples=None,
monotonic_cst=None,
):
-        super().__init__(
-            criterion=criterion,
-            max_depth=max_depth,
-            n_estimators=n_estimators,
-            bootstrap=bootstrap,
-            oob_score=oob_score,
-            n_jobs=n_jobs,
-            random_state=random_state,
-            verbose=verbose,
-            warm_start=warm_start,
-            class_weight=class_weight,
-            min_samples_split=min_samples_split,
-            min_samples_leaf=min_samples_leaf,
-            min_weight_fraction_leaf=min_weight_fraction_leaf,
-            max_features=max_features,
-            max_leaf_nodes=max_leaf_nodes,
-            min_impurity_decrease=min_impurity_decrease,
-            ccp_alpha=ccp_alpha,
-            max_samples=max_samples,
-        )
+        params_random_forest = {
+            "criterion": criterion,
+            "max_depth": max_depth,
+            "n_estimators": n_estimators,
+            "bootstrap": bootstrap,
+            "oob_score": oob_score,
+            "n_jobs": n_jobs,
+            "random_state": random_state,
+            "verbose": verbose,
+            "warm_start": warm_start,
+            "class_weight": class_weight,
+            "min_samples_split": min_samples_split,
+            "min_samples_leaf": min_samples_leaf,
+            "min_weight_fraction_leaf": min_weight_fraction_leaf,
+            "max_features": max_features,
+            "max_leaf_nodes": max_leaf_nodes,
+            "min_impurity_decrease": min_impurity_decrease,
+            "ccp_alpha": ccp_alpha,
+            "max_samples": max_samples,
+        }
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            # use scikit-learn support for monotonic constraints
+            params_random_forest["monotonic_cst"] = monotonic_cst
+        else:
+            if monotonic_cst is not None:
+                raise ValueError(
+                    "Monotonic constraints are not supported for scikit-learn "
+                    "version < 1.4."
+                )
+            # create an attribute for compatibility with other scikit-learn tools such
+            # as HTML representation.
+            self.monotonic_cst = monotonic_cst
+        super().__init__(**params_random_forest)

self.sampling_strategy = sampling_strategy
self.replacement = replacement
@@ -591,11 +627,41 @@ def fit(self, X, y, sample_weight=None):
# Validate or convert input data
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")

# TODO: remove when the minimum supported version of scikit-learn will be 1.4
# Support for missing values
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
force_all_finite = False
else:
force_all_finite = True

        X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=force_all_finite,
        )
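
For context, `force_all_finite=False` only relaxes input validation so that NaN
reaches the trees instead of raising an error. A sketch with scikit-learn's
public `check_array` helper, which `_validate_data` builds on:

    import numpy as np
    from sklearn.utils import check_array

    X = np.array([[1.0, np.nan], [3.0, 4.0]])
    check_array(X, force_all_finite=False)  # passes, NaN preserved for the trees
    try:
        check_array(X, force_all_finite=True)  # the default behaviour
    except ValueError as exc:
        print(exc)  # "Input contains NaN" (exact wording varies by version)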

# TODO: remove when the minimum supported version of scikit-learn will be 1.4
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
# _compute_missing_values_in_feature_mask checks if X has missing values and
# will raise an error if the underlying tree base estimator can't handle
# missing values. Only the criterion is required to determine if the tree
# supports missing values.
estimator = type(self.estimator)(criterion=self.criterion)
missing_values_in_feature_mask = (
estimator._compute_missing_values_in_feature_mask(
X, estimator_name=self.__class__.__name__
)
)
else:
missing_values_in_feature_mask = None

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

self._n_features = X.shape[1]

if issparse(X):
@@ -713,6 +779,7 @@ def fit(self, X, y, sample_weight=None):
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
forest=self,
missing_values_in_feature_mask=missing_values_in_feature_mask,
)
for i, (s, t) in enumerate(zip(samplers, trees))
)
97 changes: 97 additions & 0 deletions imblearn/ensemble/tests/test_forest.py
@@ -258,3 +258,100 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
)
with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
estimator.fit(*imbalanced_dataset)


@pytest.mark.skipif(
parse_version(sklearn_version.base_version) < parse_version("1.4"),
reason="scikit-learn should be >= 1.4",
)
def test_missing_values_is_resilient():
"""Check that forest can deal with missing values and has decent performance."""

rng = np.random.RandomState(0)
n_samples, n_features = 1000, 10
X, y = make_classification(
n_samples=n_samples, n_features=n_features, random_state=rng
)

# Create dataset with missing values
X_missing = X.copy()
X_missing[rng.choice([False, True], size=X.shape, p=[0.95, 0.05])] = np.nan
assert np.isnan(X_missing).any()

X_missing_train, X_missing_test, y_train, y_test = train_test_split(
X_missing, y, random_state=0
)

# Train forest with missing values
forest_with_missing = BalancedRandomForestClassifier(
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=rng,
n_estimators=50,
)
forest_with_missing.fit(X_missing_train, y_train)
score_with_missing = forest_with_missing.score(X_missing_test, y_test)

# Train forest without missing values
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = BalancedRandomForestClassifier(
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=rng,
n_estimators=50,
)
forest.fit(X_train, y_train)
score_without_missing = forest.score(X_test, y_test)

# The score with missing values should be at least 80 percent of the score
# obtained on the data without missing values
assert score_with_missing >= 0.80 * score_without_missing


@pytest.mark.skipif(
parse_version(sklearn_version.base_version) < parse_version("1.4"),
reason="scikit-learn should be >= 1.4",
)
def test_missing_value_is_predictive():
"""Check that the forest learns when missing values are only present for
a predictive feature."""
rng = np.random.RandomState(0)
n_samples = 300

X_non_predictive = rng.standard_normal(size=(n_samples, 10))
y = rng.randint(0, high=2, size=n_samples)

# Create a predictive feature using `y` and with some noise
X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
y_mask = y.astype(bool)
y_mask[X_random_mask] = ~y_mask[X_random_mask]

predictive_feature = rng.standard_normal(size=n_samples)
predictive_feature[y_mask] = np.nan
assert np.isnan(predictive_feature).any()

X_predictive = X_non_predictive.copy()
X_predictive[:, 5] = predictive_feature

(
X_predictive_train,
X_predictive_test,
X_non_predictive_train,
X_non_predictive_test,
y_train,
y_test,
) = train_test_split(X_predictive, X_non_predictive, y, random_state=0)
forest_predictive = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
).fit(X_predictive_train, y_train)
forest_non_predictive = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, bootstrap=False, random_state=0
).fit(X_non_predictive_train, y_train)

predictive_test_score = forest_predictive.score(X_predictive_test, y_test)

assert predictive_test_score >= 0.75
assert predictive_test_score >= forest_non_predictive.score(
X_non_predictive_test, y_test
)