@@ -29,7 +29,7 @@ def imbalanced_dataset():
2929
3030def test_balanced_random_forest_error_warning_warm_start (imbalanced_dataset ):
3131 brf = BalancedRandomForestClassifier (
32- n_estimators = 5 , sampling_strategy = "all" , replacement = True
32+ n_estimators = 5 , sampling_strategy = "all" , replacement = True , bootstrap = False
3333 )
3434 brf .fit (* imbalanced_dataset )
3535
@@ -51,6 +51,7 @@ def test_balanced_random_forest(imbalanced_dataset):
5151 random_state = 0 ,
5252 sampling_strategy = "all" ,
5353 replacement = True ,
54+ bootstrap = False ,
5455 )
5556 brf .fit (* imbalanced_dataset )
5657
@@ -68,6 +69,7 @@ def test_balanced_random_forest_attributes(imbalanced_dataset):
6869 random_state = 0 ,
6970 sampling_strategy = "all" ,
7071 replacement = True ,
72+ bootstrap = False ,
7173 )
7274 brf .fit (X , y )
7375
@@ -93,7 +95,11 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
9395 X , y = imbalanced_dataset
9496 sample_weight = rng .rand (y .shape [0 ])
9597 brf = BalancedRandomForestClassifier (
96- n_estimators = 5 , random_state = 0 , sampling_strategy = "all" , replacement = True
98+ n_estimators = 5 ,
99+ random_state = 0 ,
100+ sampling_strategy = "all" ,
101+ replacement = True ,
102+ bootstrap = False ,
97103 )
98104 brf .fit (X , y , sample_weight )
99105
@@ -111,6 +117,7 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
111117 min_samples_leaf = 2 ,
112118 sampling_strategy = "all" ,
113119 replacement = True ,
120+ bootstrap = True ,
114121 )
115122
116123 est .fit (X_train , y_train )
@@ -132,7 +139,9 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
132139
133140
134141def test_balanced_random_forest_grid_search (imbalanced_dataset ):
135- brf = BalancedRandomForestClassifier (sampling_strategy = "all" , replacement = True )
142+ brf = BalancedRandomForestClassifier (
143+ sampling_strategy = "all" , replacement = True , bootstrap = False
144+ )
136145 grid = GridSearchCV (brf , {"n_estimators" : (1 , 2 ), "max_depth" : (1 , 2 )}, cv = 3 )
137146 grid .fit (* imbalanced_dataset )
138147
@@ -150,6 +159,7 @@ def test_little_tree_with_small_max_samples():
150159 max_samples = None ,
151160 sampling_strategy = "all" ,
152161 replacement = True ,
162+ bootstrap = True ,
153163 )
154164
155165 # Second fit with max samples restricted to just 2
@@ -159,6 +169,7 @@ def test_little_tree_with_small_max_samples():
159169 max_samples = 2 ,
160170 sampling_strategy = "all" ,
161171 replacement = True ,
172+ bootstrap = True ,
162173 )
163174
164175 est1 .fit (X , y )
@@ -172,12 +183,14 @@ def test_little_tree_with_small_max_samples():
172183
173184
174185def test_balanced_random_forest_pruning (imbalanced_dataset ):
175- brf = BalancedRandomForestClassifier (sampling_strategy = "all" , replacement = True )
186+ brf = BalancedRandomForestClassifier (
187+ sampling_strategy = "all" , replacement = True , bootstrap = False
188+ )
176189 brf .fit (* imbalanced_dataset )
177190 n_nodes_no_pruning = brf .estimators_ [0 ].tree_ .node_count
178191
179192 brf_pruned = BalancedRandomForestClassifier (
180- ccp_alpha = 0.015 , sampling_strategy = "all" , replacement = True
193+ ccp_alpha = 0.015 , sampling_strategy = "all" , replacement = True , bootstrap = False
181194 )
182195 brf_pruned .fit (* imbalanced_dataset )
183196 n_nodes_pruning = brf_pruned .estimators_ [0 ].tree_ .node_count
@@ -200,6 +213,7 @@ def test_balanced_random_forest_oob_binomial(ratio):
200213 random_state = 42 ,
201214 sampling_strategy = "not minority" ,
202215 replacement = False ,
216+ bootstrap = True ,
203217 )
204218 erf .fit (X , y )
205219 assert np .abs (erf .oob_score_ - 0.5 ) < 0.1
@@ -209,7 +223,7 @@ def test_balanced_bagging_classifier_n_features():
209223 """Check that we raise a FutureWarning when accessing `n_features_`."""
210224 X , y = load_iris (return_X_y = True )
211225 estimator = BalancedRandomForestClassifier (
212- sampling_strategy = "all" , replacement = True
226+ sampling_strategy = "all" , replacement = True , bootstrap = False
213227 ).fit (X , y )
214228 with pytest .warns (FutureWarning , match = "`n_features_` was deprecated" ):
215229 estimator .n_features_
@@ -222,7 +236,7 @@ def test_balanced_random_forest_classifier_base_estimator():
222236 """Check that we raise a FutureWarning when accessing `base_estimator_`."""
223237 X , y = load_iris (return_X_y = True )
224238 estimator = BalancedRandomForestClassifier (
225- sampling_strategy = "all" , replacement = True
239+ sampling_strategy = "all" , replacement = True , bootstrap = False
226240 ).fit (X , y )
227241 with pytest .warns (FutureWarning , match = "`base_estimator_` was deprecated" ):
228242 estimator .base_estimator_
@@ -233,9 +247,14 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
233247 """Check that we raise a change of behaviour for the parameters `sampling_strategy`
234248 and `replacement`.
235249 """
236- estimator = BalancedRandomForestClassifier (sampling_strategy = "all" )
250+ estimator = BalancedRandomForestClassifier (sampling_strategy = "all" , bootstrap = False )
237251 with pytest .warns (FutureWarning , match = "The default of `replacement`" ):
238252 estimator .fit (* imbalanced_dataset )
239- estimator = BalancedRandomForestClassifier (replacement = True )
253+ estimator = BalancedRandomForestClassifier (replacement = True , bootstrap = False )
240254 with pytest .warns (FutureWarning , match = "The default of `sampling_strategy`" ):
241255 estimator .fit (* imbalanced_dataset )
256+ estimator = BalancedRandomForestClassifier (
257+ sampling_strategy = "all" , replacement = True
258+ )
259+ with pytest .warns (FutureWarning , match = "The default of `bootstrap`" ):
260+ estimator .fit (* imbalanced_dataset )
0 commit comments