4848def test_easy_ensemble_classifier (n_estimators , base_estimator ):
4949 # Check classification for various parameter settings.
5050 X , y = make_imbalance (
51- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
51+ iris .data ,
52+ iris .target ,
53+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
54+ random_state = 0 ,
5255 )
5356 X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
5457
@@ -72,7 +75,10 @@ def test_easy_ensemble_classifier(n_estimators, base_estimator):
7275def test_base_estimator ():
7376 # Check base_estimator and its default values.
7477 X , y = make_imbalance (
75- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
78+ iris .data ,
79+ iris .target ,
80+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
81+ random_state = 0 ,
7682 )
7783 X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
7884
@@ -91,7 +97,10 @@ def test_base_estimator():
9197
9298def test_bagging_with_pipeline ():
9399 X , y = make_imbalance (
94- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
100+ iris .data ,
101+ iris .target ,
102+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
103+ random_state = 0 ,
95104 )
96105 estimator = EasyEnsembleClassifier (
97106 n_estimators = 2 ,
@@ -109,7 +118,9 @@ def test_warm_start(random_state=42):
109118 for n_estimators in [5 , 10 ]:
110119 if clf_ws is None :
111120 clf_ws = EasyEnsembleClassifier (
112- n_estimators = n_estimators , random_state = random_state , warm_start = True ,
121+ n_estimators = n_estimators ,
122+ random_state = random_state ,
123+ warm_start = True ,
113124 )
114125 else :
115126 clf_ws .set_params (n_estimators = n_estimators )
@@ -182,7 +193,10 @@ def test_warm_start_equivalence():
182193)
183194def test_easy_ensemble_classifier_error (n_estimators , msg_error ):
184195 X , y = make_imbalance (
185- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
196+ iris .data ,
197+ iris .target ,
198+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
199+ random_state = 0 ,
186200 )
187201 with pytest .raises (ValueError , match = msg_error ):
188202 eec = EasyEnsembleClassifier (n_estimators = n_estimators )
@@ -191,7 +205,10 @@ def test_easy_ensemble_classifier_error(n_estimators, msg_error):
191205
192206def test_easy_ensemble_classifier_single_estimator ():
193207 X , y = make_imbalance (
194- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
208+ iris .data ,
209+ iris .target ,
210+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
211+ random_state = 0 ,
195212 )
196213 X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
197214
@@ -205,14 +222,19 @@ def test_easy_ensemble_classifier_single_estimator():
205222
206223def test_easy_ensemble_classifier_grid_search ():
207224 X , y = make_imbalance (
208- iris .data , iris .target , sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 }, random_state = 0 ,
225+ iris .data ,
226+ iris .target ,
227+ sampling_strategy = {0 : 20 , 1 : 25 , 2 : 50 },
228+ random_state = 0 ,
209229 )
210230
211231 parameters = {
212232 "n_estimators" : [1 , 2 ],
213233 "base_estimator__n_estimators" : [3 , 4 ],
214234 }
215235 grid_search = GridSearchCV (
216- EasyEnsembleClassifier (base_estimator = AdaBoostClassifier ()), parameters , cv = 5 ,
236+ EasyEnsembleClassifier (base_estimator = AdaBoostClassifier ()),
237+ parameters ,
238+ cv = 5 ,
217239 )
218240 grid_search .fit (X , y )
0 commit comments