 from ..base import _ParamsValidationMixin
 from ..pipeline import make_pipeline
 from ..under_sampling import RandomUnderSampler
-from ..under_sampling.base import BaseUnderSampler
 from ..utils import Substitution
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-from ..utils._param_validation import Interval, StrOptions
+from ..utils._param_validation import Hidden, Interval, StrOptions
 from ..utils._validation import check_sampling_strategy
 from ..utils.fixes import _fit_context
 from ._common import _random_forest_classifier_parameter_constraints
@@ -100,7 +99,6 @@ def _local_parallel_build_trees(
 
 
 @Substitution(
-    sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
     n_jobs=_n_jobs_docstring,
     random_state=_random_state_docstring,
 )
@@ -193,11 +191,56 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
         Whether to use out-of-bag samples to estimate
         the generalization accuracy.
 
-    {sampling_strategy}
+    sampling_strategy : float, str, dict, callable, default="auto"
+        Sampling information to sample the data set.
+
+        - When ``float``, it corresponds to the desired ratio of the number of
+          samples in the minority class over the number of samples in the
+          majority class after resampling. Therefore, the ratio is expressed as
+          :math:`\\alpha_{{us}} = N_{{m}} / N_{{rM}}` where :math:`N_{{m}}` is the
+          number of samples in the minority class and
+          :math:`N_{{rM}}` is the number of samples in the majority class
+          after resampling.
+
+          .. warning::
+             ``float`` is only available for **binary** classification. An
+             error is raised for multi-class classification.
+
+        - When ``str``, specify the class targeted by the resampling. The
+          number of samples in the different classes will be equalized.
+          Possible choices are:
+
+            ``'majority'``: resample only the majority class;
+
+            ``'not minority'``: resample all classes but the minority class;
+
+            ``'not majority'``: resample all classes but the majority class;
+
+            ``'all'``: resample all classes;
+
+            ``'auto'``: equivalent to ``'not minority'``.
+
+        - When ``dict``, the keys correspond to the targeted classes. The
+          values correspond to the desired number of samples for each targeted
+          class.
+
+        - When callable, a function taking ``y`` and returning a ``dict``. The
+          keys correspond to the targeted classes. The values correspond to the
+          desired number of samples for each class.
+
+        .. versionchanged:: 0.11
+           The default of `sampling_strategy` will change from `"auto"` to
+           `"all"` in version 0.13. This forces the use of a bootstrap of the
+           minority class as proposed in [1]_.
 
     replacement : bool, default=False
         Whether or not to sample randomly with replacement.
 
+        .. versionchanged:: 0.11
+           The default of `replacement` will change from `False` to `True` in
+           version 0.13. This forces the use of a bootstrap of the minority
+           class drawn with replacement as proposed in [1]_.
+
     {n_jobs}
 
     {random_state}
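
Note (not part of the diff): to make the new parameter documentation concrete, here is a minimal usage sketch of the documented ``sampling_strategy`` forms (string, dict, callable) against the public ``imblearn.ensemble.BalancedRandomForestClassifier`` API. The target counts and the ``equalize_to_minority`` helper are illustrative only, assuming imbalanced-learn 0.11 with the explicit new defaults.

    # Illustrative sketch only; not taken from the patch.
    from collections import Counter

    from sklearn.datasets import make_classification
    from imblearn.ensemble import BalancedRandomForestClassifier

    X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                               weights=[0.2, 0.3, 0.5], random_state=0)

    # String strategy: "all" under-samples every class to the minority size.
    clf = BalancedRandomForestClassifier(
        sampling_strategy="all", replacement=True, random_state=0
    ).fit(X, y)

    # Dict strategy: explicit per-class target counts (each must not exceed
    # the number of samples available in that class).
    clf = BalancedRandomForestClassifier(
        sampling_strategy={0: 150, 1: 150, 2: 150}, replacement=True, random_state=0
    ).fit(X, y)

    # Callable strategy: compute the per-class targets from ``y`` at fit time.
    def equalize_to_minority(y):
        counts = Counter(y)
        smallest = min(counts.values())
        return {label: smallest for label in counts}

    clf = BalancedRandomForestClassifier(
        sampling_strategy=equalize_to_minority, replacement=True, random_state=0
    ).fit(X, y)
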
@@ -351,7 +394,8 @@ class labels (multi-output problem).
     >>> X, y = make_classification(n_samples=1000, n_classes=3,
     ...                            n_informative=4, weights=[0.2, 0.3, 0.5],
     ...                            random_state=0)
-    >>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0)
+    >>> clf = BalancedRandomForestClassifier(
+    ...     sampling_strategy="all", replacement=True, max_depth=2, random_state=0)
     >>> clf.fit(X, y)
     BalancedRandomForestClassifier(...)
     >>> print(clf.feature_importances_)
@@ -376,8 +420,9 @@ class labels (multi-output problem).
                 StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
                 dict,
                 callable,
+                Hidden(StrOptions({"warn"})),
             ],
-            "replacement": ["boolean"],
+            "replacement": ["boolean", Hidden(StrOptions({"warn"}))],
         }
     )
 
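
Aside (not part of the diff): ``Hidden(StrOptions({"warn"}))`` lets the private ``"warn"`` sentinel pass parameter validation without being advertised as a supported value. A rough standalone sketch of the same idea, using hypothetical names and no scikit-learn internals:

    # Hypothetical sketch of the hidden-sentinel pattern; not imblearn/sklearn API.
    ALLOWED = {True, False}      # values documented for `replacement`
    HIDDEN_SENTINELS = {"warn"}  # accepted internally, never documented

    def validate_replacement(value):
        # Documented values and the hidden sentinel both pass validation...
        if value in ALLOWED or value in HIDDEN_SENTINELS:
            return value
        # ...but the error message only advertises documented values, mirroring
        # how Hidden(...) keeps "warn" out of user-facing messages.
        raise ValueError(f"replacement must be a boolean, got {value!r}")

    validate_replacement(True)    # documented value
    validate_replacement("warn")  # internal sentinel, silently accepted
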
@@ -395,8 +440,8 @@ def __init__(
         min_impurity_decrease=0.0,
         bootstrap=True,
         oob_score=False,
-        sampling_strategy="auto",
-        replacement=False,
+        sampling_strategy="warn",
+        replacement="warn",
         n_jobs=None,
         random_state=None,
         verbose=0,
@@ -450,7 +495,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
 
         self.base_sampler_ = RandomUnderSampler(
             sampling_strategy=self._sampling_strategy,
-            replacement=self.replacement,
+            replacement=self._replacement,
         )
 
     def _make_sampler_estimator(self, random_state=None):
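
Note (not part of the diff): conceptually, ``base_sampler_`` means every tree in the forest is trained on a randomly under-sampled draw of the data. A simplified sketch of what a single ensemble member amounts to, using only public imbalanced-learn API and ignoring bootstrap indices, sample weights, and out-of-bag bookkeeping:

    # Simplified sketch; the real estimator builds this machinery internally.
    from imblearn.pipeline import make_pipeline
    from imblearn.under_sampling import RandomUnderSampler
    from sklearn.tree import DecisionTreeClassifier

    one_tree = make_pipeline(
        RandomUnderSampler(
            sampling_strategy="all",  # stands in for the resolved `_sampling_strategy`
            replacement=True,         # stands in for the resolved `_replacement`
            random_state=0,
        ),
        DecisionTreeClassifier(random_state=0),
    )
    # one_tree.fit(X, y) trains one balanced tree; the forest repeats this with
    # different random states and aggregates the resulting trees.
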
@@ -496,6 +541,31 @@ def fit(self, X, y, sample_weight=None):
             The fitted instance.
         """
         self._validate_params()
+        # TODO: remove in 0.13
+        if self.sampling_strategy == "warn":
+            warn(
+                "The default of `sampling_strategy` will change from `'auto'` to "
+                "`'all'` in version 0.13. This change will follow the implementation "
+                "proposed in the original paper. Set to `'all'` to silence this "
+                "warning and adopt the future behaviour.",
+                FutureWarning,
+            )
+            self._sampling_strategy = "auto"
+        else:
+            self._sampling_strategy = self.sampling_strategy
+
+        if self.replacement == "warn":
+            warn(
+                "The default of `replacement` will change from `False` to "
+                "`True` in version 0.13. This change will follow the implementation "
+                "proposed in the original paper. Set to `True` to silence this "
+                "warning and adopt the future behaviour.",
+                FutureWarning,
+            )
+            self._replacement = False
+        else:
+            self._replacement = self.replacement
+
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
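
Note (not part of the diff): a usage sketch of what the sentinel handling above means for callers on imbalanced-learn 0.11. The dataset and the message checks are illustrative; the only assumption is that the two FutureWarning messages contain the parameter names, as in the code above.

    import warnings

    from sklearn.datasets import make_classification
    from imblearn.ensemble import BalancedRandomForestClassifier

    X, y = make_classification(n_samples=200, weights=[0.8, 0.2], random_state=0)

    # Leaving both defaults in place triggers the FutureWarnings emitted above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        BalancedRandomForestClassifier(random_state=0).fit(X, y)
    messages = [str(w.message) for w in caught if issubclass(w.category, FutureWarning)]
    assert any("sampling_strategy" in m for m in messages)
    assert any("replacement" in m for m in messages)

    # Opting in to the future behaviour silences both warnings.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        BalancedRandomForestClassifier(
            sampling_strategy="all", replacement=True, random_state=0
        ).fit(X, y)
    messages = [str(w.message) for w in caught if issubclass(w.category, FutureWarning)]
    assert not any("sampling_strategy" in m or "replacement" in m for m in messages)
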
@@ -533,7 +603,7 @@ def fit(self, X, y, sample_weight=None):
         if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
             y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)
 
-        if isinstance(self.sampling_strategy, dict):
+        if isinstance(self._sampling_strategy, dict):
             self._sampling_strategy = {
                 np.where(self.classes_[0] == key)[0][0]: value
                 for key, value in check_sampling_strategy(
@@ -543,7 +613,7 @@ def fit(self, X, y, sample_weight=None):
                 ).items()
             }
         else:
-            self._sampling_strategy = self.sampling_strategy
+            self._sampling_strategy = self._sampling_strategy
 
         if expanded_class_weight is not None:
             if sample_weight is not None: