1313
1414import numpy as np
1515from scipy import sparse
16+ from sklearn .base import clone
1617from sklearn .preprocessing import OneHotEncoder , OrdinalEncoder
1718from sklearn .utils import _safe_indexing , check_array , check_random_state
1819from sklearn .utils .sparsefuncs_fast import (
@@ -393,6 +394,11 @@ class SMOTENC(SMOTE):
393394 - mask array of shape (n_features, ) and ``bool`` dtype for which
394395 ``True`` indicates the categorical features.
395396
397+ categorical_encoder : estimator, default=None
398+ One-hot encoder used to encode the categorical features. If `None`, a
399+ :class:`~sklearn.preprocessing.OneHotEncoder` is used with default parameters
400+ apart from `handle_unknown` which is set to 'ignore'.
401+
396402 {sampling_strategy}
397403
398404 {random_state}
@@ -431,6 +437,13 @@ class SMOTENC(SMOTE):
431437 ohe_ : :class:`~sklearn.preprocessing.OneHotEncoder`
432438 The one-hot encoder used to encode the categorical features.
433439
440+ .. deprecated:: 0.11
441+ `ohe_` is deprecated in 0.11 and will be removed in 0.13. Use
442+ `categorical_encoder_` instead.
443+
444+ categorical_encoder_ : estimator
445+ The encoder used to encode the categorical features.
446+
434447 categorical_features_ : ndarray of shape (n_cat_features,), dtype=np.int64
435448 Indices of the categorical features.
436449
@@ -514,12 +527,17 @@ class SMOTENC(SMOTE):
514527 _parameter_constraints : dict = {
515528 ** SMOTE ._parameter_constraints ,
516529 "categorical_features" : ["array-like" ],
530+ "categorical_encoder" : [
531+ HasMethods (["fit_transform" , "inverse_transform" ]),
532+ None ,
533+ ],
517534 }
518535
519536 def __init__ (
520537 self ,
521538 categorical_features ,
522539 * ,
540+ categorical_encoder = None ,
523541 sampling_strategy = "auto" ,
524542 random_state = None ,
525543 k_neighbors = 5 ,
@@ -532,6 +550,7 @@ def __init__(
532550 n_jobs = n_jobs ,
533551 )
534552 self .categorical_features = categorical_features
553+ self .categorical_encoder = categorical_encoder
535554
536555 def _check_X_y (self , X , y ):
537556 """Overwrite the checking to let pass some string for categorical
@@ -603,17 +622,19 @@ def _fit_resample(self, X, y):
603622 else :
604623 dtype_ohe = np .float64
605624
606- self .ohe_ = OneHotEncoder ( handle_unknown = "ignore" , dtype = dtype_ohe )
607- if hasattr ( self .ohe_ , "sparse_output" ):
608- # scikit-learn >= 1.2
609- self . ohe_ . set_params ( sparse_output = True )
625+ if self .categorical_encoder is None :
626+ self .categorical_encoder_ = OneHotEncoder (
627+ handle_unknown = "ignore" , dtype = dtype_ohe
628+ )
610629 else :
611- self .ohe_ . set_params ( sparse = True )
630+ self .categorical_encoder_ = clone ( self . categorical_encoder )
612631
613632 # the input of the OneHotEncoder needs to be dense
614- X_ohe = self .ohe_ .fit_transform (
633+ X_ohe = self .categorical_encoder_ .fit_transform (
615634 X_categorical .toarray () if sparse .issparse (X_categorical ) else X_categorical
616635 )
636+ if not sparse .issparse (X_ohe ):
637+ X_ohe = sparse .csr_matrix (X_ohe , dtype = dtype_ohe )
617638
618639 # we can replace the 1 entries of the categorical features with the
619640 # median of the standard deviation. It will ensure that whenever
@@ -636,7 +657,7 @@ def _fit_resample(self, X, y):
636657 # reverse the encoding of the categorical features
637658 X_res_cat = X_resampled [:, self .continuous_features_ .size :]
638659 X_res_cat .data = np .ones_like (X_res_cat .data )
639- X_res_cat_dec = self .ohe_ .inverse_transform (X_res_cat )
660+ X_res_cat_dec = self .categorical_encoder_ .inverse_transform (X_res_cat )
640661
641662 if sparse .issparse (X ):
642663 X_resampled = sparse .hstack (
@@ -695,7 +716,7 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
695716 all_neighbors = nn_data [nn_num [rows ]]
696717
697718 categories_size = [self .continuous_features_ .size ] + [
698- cat .size for cat in self .ohe_ .categories_
719+ cat .size for cat in self .categorical_encoder_ .categories_
699720 ]
700721
701722 for start_idx , end_idx in zip (
@@ -714,6 +735,16 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
714735
715736 return X_new
716737
738+ @property
739+ def ohe_ (self ):
740+ """One-hot encoder used to encode the categorical features."""
741+ warnings .warn (
742+ "'ohe_' attribute has been deprecated in 0.11 and will be removed "
743+ "in 0.13. Use 'categorical_encoder_' instead." ,
744+ FutureWarning ,
745+ )
746+ return self .categorical_encoder_
747+
717748
718749@Substitution (
719750 sampling_strategy = BaseOverSampler ._sampling_strategy_docstring ,
0 commit comments