1616from sklearn .base import clone
1717from sklearn .exceptions import DataConversionWarning
1818from sklearn .preprocessing import OneHotEncoder , OrdinalEncoder
19- from sklearn .utils import _safe_indexing , check_array , check_random_state
19+ from sklearn .utils import (
20+ _get_column_indices ,
21+ _safe_indexing ,
22+ check_array ,
23+ check_random_state ,
24+ )
2025from sklearn .utils .sparsefuncs_fast import (
2126 csc_mean_variance_axis0 ,
2227 csr_mean_variance_axis0 ,
@@ -390,10 +395,14 @@ class SMOTENC(SMOTE):
390395
391396 Parameters
392397 ----------
393- categorical_features : array-like of shape (n_cat_features,) or (n_features,)
398+ categorical_features : array-like of shape (n_cat_features,) or (n_features,), \
399+ dtype={{bool, int, str}}
394400 Specified which features are categorical. Can either be:
395401
396- - array of indices specifying the categorical features;
402+ - array of `int` corresponding to the indices specifying the categorical
403+ features;
404+ - array of `str` corresponding to the feature names. `X` should be a pandas
405+ :class:`pandas.DataFrame` in this case.
397406 - mask array of shape (n_features, ) and ``bool`` dtype for which
398407 ``True`` indicates the categorical features.
399408
@@ -565,24 +574,16 @@ def _check_X_y(self, X, y):
565574 self ._check_feature_names (X , reset = True )
566575 return X , y , binarize_y
567576
568- def _validate_estimator (self ):
569- super ()._validate_estimator ()
570- categorical_features = np .asarray (self .categorical_features )
571- if categorical_features .dtype .name == "bool" :
572- self .categorical_features_ = np .flatnonzero (categorical_features )
573- else :
574- if any (
575- [cat not in np .arange (self .n_features_ ) for cat in categorical_features ]
576- ):
577- raise ValueError (
578- f"Some of the categorical indices are out of range. Indices"
579- f" should be between 0 and { self .n_features_ - 1 } "
580- )
581- self .categorical_features_ = categorical_features
577+ def _validate_column_types (self , X ):
578+ self .categorical_features_ = np .array (
579+ _get_column_indices (X , self .categorical_features )
580+ )
582581 self .continuous_features_ = np .setdiff1d (
583582 np .arange (self .n_features_ ), self .categorical_features_
584583 )
585584
585+ def _validate_estimator (self ):
586+ super ()._validate_estimator ()
586587 if self .categorical_features_ .size == self .n_features_in_ :
587588 raise ValueError (
588589 "SMOTE-NC is not designed to work only with categorical "
@@ -600,6 +601,7 @@ def _fit_resample(self, X, y):
600601 )
601602
602603 self .n_features_ = _num_features (X )
604+ self ._validate_column_types (X )
603605 self ._validate_estimator ()
604606
605607 # compute the median of the standard deviation of the minority class
0 commit comments