55# License: MIT
66
77import copy
8- import inspect
98import numbers
109import warnings
1110
1514from sklearn .ensemble import BaggingClassifier
1615from sklearn .ensemble ._bagging import _parallel_decision_function
1716from sklearn .ensemble ._base import _partition_estimators
17+ from sklearn .exceptions import NotFittedError
1818from sklearn .tree import DecisionTreeClassifier
1919from sklearn .utils import parse_version
2020from sklearn .utils .validation import check_is_fitted
@@ -121,30 +121,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
121121
122122 .. versionadded:: 0.8
123123
124- base_estimator : estimator object, default=None
125- The base estimator to fit on random subsets of the dataset.
126- If None, then the base estimator is a decision tree.
127-
128- .. deprecated:: 0.10
129- `base_estimator` was renamed to `estimator` in version 0.10 and
130- will be removed in 0.12.
131-
132124 Attributes
133125 ----------
134126 estimator_ : estimator
135127 The base estimator from which the ensemble is grown.
136128
137129 .. versionadded:: 0.10
138130
139- base_estimator_ : estimator
140- The base estimator from which the ensemble is grown.
141-
142- .. deprecated:: 1.2
143- `base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
144- removed in 1.4. Use `estimator_` instead. When the minimum version
145- of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
146- this attribute will be removed.
147-
148131 n_features_ : int
149132 The number of features when `fit` is performed.
150133
@@ -266,7 +249,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
266249 """
267250
268251 # make a deepcopy to not modify the original dictionary
269- if sklearn_version >= parse_version ("1.3 " ):
252+ if sklearn_version >= parse_version ("1.4 " ):
270253 _parameter_constraints = copy .deepcopy (BaggingClassifier ._parameter_constraints )
271254 else :
272255 _parameter_constraints = copy .deepcopy (_bagging_parameter_constraints )
@@ -283,6 +266,9 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
283266 "sampler" : [HasMethods (["fit_resample" ]), None ],
284267 }
285268 )
269+ # TODO: remove when minimum supported version of scikit-learn is 1.4
270+ if "base_estimator" in _parameter_constraints :
271+ del _parameter_constraints ["base_estimator" ]
286272
287273 def __init__ (
288274 self ,
@@ -301,18 +287,8 @@ def __init__(
301287 random_state = None ,
302288 verbose = 0 ,
303289 sampler = None ,
304- base_estimator = "deprecated" ,
305290 ):
306- # TODO: remove when supporting scikit-learn>=1.2
307- bagging_classifier_signature = inspect .signature (super ().__init__ )
308- estimator_params = {"base_estimator" : base_estimator }
309- if "estimator" in bagging_classifier_signature .parameters :
310- estimator_params ["estimator" ] = estimator
311- else :
312- self .estimator = estimator
313-
314291 super ().__init__ (
315- ** estimator_params ,
316292 n_estimators = n_estimators ,
317293 max_samples = max_samples ,
318294 max_features = max_features ,
@@ -324,6 +300,7 @@ def __init__(
324300 random_state = random_state ,
325301 verbose = verbose ,
326302 )
303+ self .estimator = estimator
327304 self .sampling_strategy = sampling_strategy
328305 self .replacement = replacement
329306 self .sampler = sampler
@@ -349,42 +326,17 @@ def _validate_y(self, y):
349326 def _validate_estimator (self , default = DecisionTreeClassifier ()):
350327 """Check the estimator and the n_estimator attribute, set the
351328 `estimator_` attribute."""
352- if self .estimator is not None and (
353- self .base_estimator not in [None , "deprecated" ]
354- ):
355- raise ValueError (
356- "Both `estimator` and `base_estimator` were set. Only set `estimator`."
357- )
358-
359329 if self .estimator is not None :
360- base_estimator = clone (self .estimator )
361- elif self .base_estimator not in [None , "deprecated" ]:
362- warnings .warn (
363- "`base_estimator` was renamed to `estimator` in version 0.10 and "
364- "will be removed in 0.12." ,
365- FutureWarning ,
366- )
367- base_estimator = clone (self .base_estimator )
330+ estimator = clone (self .estimator )
368331 else :
369- base_estimator = clone (default )
332+ estimator = clone (default )
370333
371334 if self .sampler_ ._sampling_type != "bypass" :
372335 self .sampler_ .set_params (sampling_strategy = self ._sampling_strategy )
373336
374- self ._estimator = Pipeline (
375- [("sampler" , self .sampler_ ), ("classifier" , base_estimator )]
337+ self .estimator_ = Pipeline (
338+ [("sampler" , self .sampler_ ), ("classifier" , estimator )]
376339 )
377- try :
378- # scikit-learn < 1.2
379- self .base_estimator_ = self ._estimator
380- except AttributeError :
381- pass
382-
383- # TODO: remove when supporting scikit-learn>=1.4
384- @property
385- def estimator_ (self ):
386- """Estimator used to grow the ensemble."""
387- return self ._estimator
388340
389341 # TODO: remove when supporting scikit-learn>=1.2
390342 @property
@@ -483,6 +435,22 @@ def decision_function(self, X):
483435
484436 return decisions
485437
438+ @property
439+ def base_estimator_ (self ):
440+ """Attribute for older sklearn version compatibility."""
441+ error = AttributeError (
442+ f"{ self .__class__ .__name__ } object has no attribute 'base_estimator_'."
443+ )
444+ if sklearn_version < parse_version ("1.2" ):
445+ # The base class require to have the attribute defined. For scikit-learn
446+ # > 1.2, we are going to raise an error.
447+ try :
448+ check_is_fitted (self )
449+ return self .estimator_
450+ except NotFittedError :
451+ raise error
452+ raise error
453+
486454 def _more_tags (self ):
487455 tags = super ()._more_tags ()
488456 tags_key = "_xfail_checks"
0 commit comments