@@ -43,10 +43,7 @@ def test_ros_init():
4343 assert ros .random_state == RND_SEED
4444
4545
46- @pytest .mark .parametrize (
47- "params" ,
48- [{"smoothed_bootstrap" : False }, {"smoothed_bootstrap" : True , "shrinkage" : 0 }]
49- )
46+ @pytest .mark .parametrize ("params" , [{"shrinkage" : None }, {"shrinkage" : 0 }])
5047@pytest .mark .parametrize ("X_type" , ["array" , "dataframe" ])
5148def test_ros_fit_resample (X_type , data , params ):
5249 X , Y = data
@@ -80,16 +77,13 @@ def test_ros_fit_resample(X_type, data, params):
8077 assert_allclose (X_resampled , X_gt )
8178 assert_array_equal (y_resampled , y_gt )
8279
83- if not params ["smoothed_bootstrap" ] :
80+ if params ["shrinkage" ] is None :
8481 assert ros .shrinkage_ is None
8582 else :
8683 assert ros .shrinkage_ == {0 : 0 }
8784
8885
89- @pytest .mark .parametrize (
90- "params" ,
91- [{"smoothed_bootstrap" : False }, {"smoothed_bootstrap" : True , "shrinkage" : 0 }]
92- )
86+ @pytest .mark .parametrize ("params" , [{"shrinkage" : None }, {"shrinkage" : 0 }])
9387def test_ros_fit_resample_half (data , params ):
9488 X , Y = data
9589 sampling_strategy = {0 : 3 , 1 : 7 }
@@ -115,16 +109,13 @@ def test_ros_fit_resample_half(data, params):
115109 assert_allclose (X_resampled , X_gt )
116110 assert_array_equal (y_resampled , y_gt )
117111
118- if not params ["smoothed_bootstrap" ] :
112+ if params ["shrinkage" ] is None :
119113 assert ros .shrinkage_ is None
120114 else :
121115 assert ros .shrinkage_ == {0 : 0 , 1 : 0 }
122116
123117
124- @pytest .mark .parametrize (
125- "params" ,
126- [{"smoothed_bootstrap" : False }, {"smoothed_bootstrap" : True , "shrinkage" : 0 }]
127- )
118+ @pytest .mark .parametrize ("params" , [{"shrinkage" : None }, {"shrinkage" : 0 }])
128119def test_multiclass_fit_resample (data , params ):
129120 # check the random over-sampling with a multiclass problem
130121 X , Y = data
@@ -138,7 +129,7 @@ def test_multiclass_fit_resample(data, params):
138129 assert count_y_res [1 ] == 5
139130 assert count_y_res [2 ] == 5
140131
141- if not params ["smoothed_bootstrap" ] :
132+ if params ["shrinkage" ] is None :
142133 assert ros .shrinkage_ is None
143134 else :
144135 assert ros .shrinkage_ == {0 : 0 , 2 : 0 }
@@ -188,11 +179,8 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
188179 [["xxx" , 1 , 1.0 ], ["yyy" , 2 , 2.0 ], ["zzz" , 3 , 3.0 ]], dtype = object
189180 )
190181 y = np .array ([0 , 0 , 1 ])
191- ros = RandomOverSampler (
192- smoothed_bootstrap = True ,
193- random_state = RND_SEED ,
194- )
195- err_msg = "When smoothed_bootstrap=True, X needs to contain only numerical"
182+ ros = RandomOverSampler (shrinkage = 1 , random_state = RND_SEED )
183+ err_msg = "When shrinkage is not None, X needs to contain only numerical"
196184 with pytest .raises (ValueError , match = err_msg ):
197185 ros .fit_resample (X_hetero , y )
198186
@@ -201,7 +189,7 @@ def test_random_over_sampling_heterogeneous_data_smoothed_bootstrap():
201189def test_random_over_sampler_smoothed_bootstrap (X_type , data ):
202190 # check that smoothed bootstrap is working for numerical array
203191 X , y = data
204- sampler = RandomOverSampler (smoothed_bootstrap = True , shrinkage = 1 )
192+ sampler = RandomOverSampler (shrinkage = 1 )
205193 X = _convert_container (X , X_type )
206194 X_res , y_res = sampler .fit_resample (X , y )
207195
@@ -217,10 +205,8 @@ def test_random_over_sampler_equivalence_shrinkage(data):
217205 # bootstrap
218206 X , y = data
219207
220- ros_not_shrink = RandomOverSampler (
221- smoothed_bootstrap = True , shrinkage = 0 , random_state = 0
222- )
223- ros_hard_bootstrap = RandomOverSampler (smoothed_bootstrap = False , random_state = 0 )
208+ ros_not_shrink = RandomOverSampler (shrinkage = 0 , random_state = 0 )
209+ ros_hard_bootstrap = RandomOverSampler (shrinkage = None , random_state = 0 )
224210
225211 X_res_not_shrink , y_res_not_shrink = ros_not_shrink .fit_resample (X , y )
226212 X_res , y_res = ros_hard_bootstrap .fit_resample (X , y )
@@ -240,7 +226,7 @@ def test_random_over_sampler_shrinkage_behaviour(data):
240226 # should also be larger.
241227 X , y = data
242228
243- ros = RandomOverSampler (smoothed_bootstrap = True , shrinkage = 1 , random_state = 0 )
229+ ros = RandomOverSampler (shrinkage = 1 , random_state = 0 )
244230 X_res_shink_1 , y_res_shrink_1 = ros .fit_resample (X , y )
245231
246232 ros .set_params (shrinkage = 5 )
@@ -252,12 +238,18 @@ def test_random_over_sampler_shrinkage_behaviour(data):
252238 assert disperstion_shrink_1 < disperstion_shrink_5
253239
254240
255- def test_random_over_sampler_shrinkage_error (data ):
256- # check that we raise proper error when shrinkage do not contain the
257- # necessary information
241+ @pytest .mark .parametrize (
242+ "shrinkage, err_msg" ,
243+ [
244+ ({}, "`shrinkage` should contain a shrinkage factor for each class" ),
245+ (- 1 , "The shrinkage factor needs to be >= 0" ),
246+ ({0 : - 1 }, "The shrinkage factor needs to be >= 0" ),
247+ ([1 , ], "`shrinkage` should either be a positive floating number or" )
248+ ]
249+ )
250+ def test_random_over_sampler_shrinkage_error (data , shrinkage , err_msg ):
251+ # check the validation of the shrinkage parameter
258252 X , y = data
259- shrinkage = {}
260- ros = RandomOverSampler (smoothed_bootstrap = True , shrinkage = shrinkage )
261- err_msg = "`shrinkage` should contain a shrinkage factor for each class"
253+ ros = RandomOverSampler (shrinkage = shrinkage )
262254 with pytest .raises (ValueError , match = err_msg ):
263255 ros .fit_resample (X , y )
0 commit comments