@@ -60,6 +60,7 @@ def _local_parallel_build_trees(
     class_weight=None,
     n_samples_bootstrap=None,
     forest=None,
+    missing_values_in_feature_mask=None,
 ):
     # resample before to fit the tree
     X_resampled, y_resampled = sampler.fit_resample(X, y)
@@ -68,33 +69,34 @@ def _local_parallel_build_trees(
     if _get_n_samples_bootstrap is not None:
         n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])

-    if sklearn_version >= parse_version("1.1"):
-        tree = _parallel_build_trees(
-            tree,
-            bootstrap,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
+    params_parallel_build_trees = {
+        "tree": tree,
+        "X": X_resampled,
+        "y": y_resampled,
+        "sample_weight": sample_weight,
+        "tree_idx": tree_idx,
+        "n_trees": n_trees,
+        "verbose": verbose,
+        "class_weight": class_weight,
+        "n_samples_bootstrap": n_samples_bootstrap,
+    }
+
+    if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        # support for missing values
+        params_parallel_build_trees[
+            "missing_values_in_feature_mask"
+        ] = missing_values_in_feature_mask
+
+    # TODO: remove when the minimum supported version of scikit-learn will be 1.1
+    # change of signature in scikit-learn 1.1
+    if parse_version(sklearn_version.base_version) >= parse_version("1.1"):
+        params_parallel_build_trees["bootstrap"] = bootstrap
     else:
-        # TODO: remove when the minimum version of scikit-learn supported is 1.1
-        tree = _parallel_build_trees(
-            tree,
-            forest,
-            X_resampled,
-            y_resampled,
-            sample_weight,
-            tree_idx,
-            n_trees,
-            verbose=verbose,
-            class_weight=class_weight,
-            n_samples_bootstrap=n_samples_bootstrap,
-        )
+        params_parallel_build_trees["forest"] = forest
+
+    tree = _parallel_build_trees(**params_parallel_build_trees)
+
     return sampler, tree

@@ -305,6 +307,25 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
         .. versionadded:: 0.6
           Added in `scikit-learn` in 0.22

+    monotonic_cst : array-like of int of shape (n_features), default=None
+        Indicates the monotonicity constraint to enforce on each feature.
+          - 1: monotonic increase
+          - 0: no constraint
+          - -1: monotonic decrease
+
+        If monotonic_cst is None, no constraints are applied.
+
+        Monotonicity constraints are not supported for:
+          - multiclass classifications (i.e. when `n_classes > 2`),
+          - multioutput classifications (i.e. when `n_outputs_ > 1`),
+          - classifications trained on data with missing values.
+
+        The constraints hold over the probability of the positive class.
+
+        .. versionadded:: 0.12
+           Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
+           `ValueError` is raised.
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
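Read together with the constructor and `fit` changes further down, the docstring above corresponds to usage along the following lines. This is only an illustrative sketch, not part of the patch: it assumes scikit-learn >= 1.4 and an imbalanced-learn release that ships this change, and the synthetic dataset and constraint vector are made up.

    # Illustrative usage of the new monotonic_cst parameter.
    from sklearn.datasets import make_classification
    from imblearn.ensemble import BalancedRandomForestClassifier

    # Synthetic, imbalanced binary problem with three features.
    X, y = make_classification(
        n_samples=500,
        n_features=3,
        n_informative=3,
        n_redundant=0,
        weights=[0.9, 0.1],
        random_state=0,
    )

    # One entry per feature: 1 = increasing, 0 = unconstrained, -1 = decreasing.
    clf = BalancedRandomForestClassifier(
        n_estimators=50, monotonic_cst=[1, 0, -1], random_state=0
    )
    clf.fit(X, y)

Because the constraints are defined over the probability of the positive class, the sketch sticks to a binary target.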
@@ -415,7 +436,7 @@ class labels (multi-output problem).
     """

     # make a deepcopy to not modify the original dictionary
-    if sklearn_version >= parse_version("1.3"):
+    if sklearn_version >= parse_version("1.4"):
         _parameter_constraints = deepcopy(RandomForestClassifier._parameter_constraints)
     else:
         _parameter_constraints = deepcopy(
@@ -459,27 +480,42 @@ def __init__(
         class_weight=None,
         ccp_alpha=0.0,
         max_samples=None,
+        monotonic_cst=None,
     ):
-        super().__init__(
-            criterion=criterion,
-            max_depth=max_depth,
-            n_estimators=n_estimators,
-            bootstrap=bootstrap,
-            oob_score=oob_score,
-            n_jobs=n_jobs,
-            random_state=random_state,
-            verbose=verbose,
-            warm_start=warm_start,
-            class_weight=class_weight,
-            min_samples_split=min_samples_split,
-            min_samples_leaf=min_samples_leaf,
-            min_weight_fraction_leaf=min_weight_fraction_leaf,
-            max_features=max_features,
-            max_leaf_nodes=max_leaf_nodes,
-            min_impurity_decrease=min_impurity_decrease,
-            ccp_alpha=ccp_alpha,
-            max_samples=max_samples,
-        )
+        params_random_forest = {
+            "criterion": criterion,
+            "max_depth": max_depth,
+            "n_estimators": n_estimators,
+            "bootstrap": bootstrap,
+            "oob_score": oob_score,
+            "n_jobs": n_jobs,
+            "random_state": random_state,
+            "verbose": verbose,
+            "warm_start": warm_start,
+            "class_weight": class_weight,
+            "min_samples_split": min_samples_split,
+            "min_samples_leaf": min_samples_leaf,
+            "min_weight_fraction_leaf": min_weight_fraction_leaf,
+            "max_features": max_features,
+            "max_leaf_nodes": max_leaf_nodes,
+            "min_impurity_decrease": min_impurity_decrease,
+            "ccp_alpha": ccp_alpha,
+            "max_samples": max_samples,
+        }
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            # use scikit-learn support for monotonic constraints
+            params_random_forest["monotonic_cst"] = monotonic_cst
+        else:
+            if monotonic_cst is not None:
+                raise ValueError(
+                    "Monotonic constraints are not supported for scikit-learn "
+                    "version < 1.4."
+                )
+            # create an attribute for compatibility with other scikit-learn tools such
+            # as HTML representation.
+            self.monotonic_cst = monotonic_cst
+        super().__init__(**params_random_forest)

         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
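The `else` branch above is only a guard for older environments. A small illustrative sketch, not part of the patch, of what it implies when scikit-learn < 1.4 is installed:

    # Illustrative: under scikit-learn < 1.4, a non-None monotonic_cst is
    # rejected at construction time by the check added above.
    from imblearn.ensemble import BalancedRandomForestClassifier

    try:
        BalancedRandomForestClassifier(monotonic_cst=[1, -1])
    except ValueError as exc:
        # "Monotonic constraints are not supported for scikit-learn version < 1.4."
        print(exc)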
@@ -591,11 +627,41 @@ def fit(self, X, y, sample_weight=None):
         # Validate or convert input data
         if issparse(y):
             raise ValueError("sparse multilabel-indicator for y is not supported.")
+
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        # Support for missing values
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            force_all_finite = False
+        else:
+            force_all_finite = True
+
         X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
+            X,
+            y,
+            multi_output=True,
+            accept_sparse="csc",
+            dtype=DTYPE,
+            force_all_finite=force_all_finite,
         )
+
+        # TODO: remove when the minimum supported version of scikit-learn will be 1.4
+        if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
+            # _compute_missing_values_in_feature_mask checks if X has missing values and
+            # will raise an error if the underlying tree base estimator can't handle
+            # missing values. Only the criterion is required to determine if the tree
+            # supports missing values.
+            estimator = type(self.estimator)(criterion=self.criterion)
+            missing_values_in_feature_mask = (
+                estimator._compute_missing_values_in_feature_mask(
+                    X, estimator_name=self.__class__.__name__
+                )
+            )
+        else:
+            missing_values_in_feature_mask = None
+
         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X)
+
         self._n_features = X.shape[1]

         if issparse(X):
@@ -713,6 +779,7 @@ def fit(self, X, y, sample_weight=None):
                     class_weight=self.class_weight,
                     n_samples_bootstrap=n_samples_bootstrap,
                     forest=self,
+                    missing_values_in_feature_mask=missing_values_in_feature_mask,
                 )
                 for i, (s, t) in enumerate(zip(samplers, trees))
             )
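Combined with the mask computed in `fit` and forwarded here to `_local_parallel_build_trees`, the estimator accepts `NaN` in `X` when scikit-learn >= 1.4 is installed. A minimal sketch under that assumption, with made-up toy data, not part of the patch; it also relies on the default `RandomUnderSampler` passing `NaN` through, since it resamples by index:

    # Minimal sketch: training with missing values (assumes scikit-learn >= 1.4
    # and an imbalanced-learn release that includes this change).
    import numpy as np
    from imblearn.ensemble import BalancedRandomForestClassifier

    X = np.array([[1.0, 2.0], [np.nan, 3.0], [4.0, np.nan], [5.0, 6.0]] * 25)
    y = np.array([0, 0, 0, 1] * 25)

    clf = BalancedRandomForestClassifier(n_estimators=10, random_state=0)
    clf.fit(X, y)                     # NaN in X no longer raises
    print(clf.predict([[5.0, 6.0]]))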