@@ -986,6 +986,15 @@ def _fit_resample(self, X, y):
986986 self .n_features_ = X .shape [1 ]
987987 self ._validate_estimator ()
988988
989+ X_encoded = self ._encode (X , y )
990+
991+ X_resampled , y_resampled = super (SMOTENC , self )._fit_resample (
992+ X_encoded , y )
993+ X_resampled = self ._decode (X , X_resampled )
994+
995+ return X_resampled , y_resampled
996+
997+ def _encode (self , X , y ):
989998 # compute the median of the standard deviation of the minority class
990999 target_stats = Counter (y )
9911000 class_minority = min (target_stats , key = target_stats .get )
@@ -1015,18 +1024,15 @@ def _fit_resample(self, X, y):
10151024 X_ohe = self .ohe_ .fit_transform (
10161025 X_categorical .toarray () if sparse .issparse (X_categorical )
10171026 else X_categorical )
1018-
10191027 # we can replace the 1 entries of the categorical features with the
10201028 # median of the standard deviation. It will ensure that whenever
10211029 # distance is computed between 2 samples, the difference will be equal
10221030 # to the median of the standard deviation as in the original paper.
10231031 X_ohe .data = (np .ones_like (X_ohe .data , dtype = X_ohe .dtype ) *
10241032 self .median_std_ / 2 )
1025- X_encoded = sparse .hstack ((X_continuous , X_ohe ), format = 'csr' )
1026-
1027- X_resampled , y_resampled = super (SMOTENC , self )._fit_resample (
1028- X_encoded , y )
1033+ return sparse .hstack ((X_continuous , X_ohe ), format = 'csr' )
10291034
1035+ def _decode (self , X , X_resampled ):
10301036 # reverse the encoding of the categorical features
10311037 X_res_cat = X_resampled [:, self .continuous_features_ .size :]
10321038 X_res_cat .data = np .ones_like (X_res_cat .data )
@@ -1055,8 +1061,7 @@ def _fit_resample(self, X, y):
10551061 X_resampled .indices = col_indices
10561062 else :
10571063 X_resampled = X_resampled [:, indices_reordered ]
1058-
1059- return X_resampled , y_resampled
1064+ return X_resampled
10601065
10611066 def _generate_sample (self , X , nn_data , nn_num , row , col , step ):
10621067 """Generate a synthetic sample with an additional steps for the
@@ -1095,7 +1100,7 @@ def _generate_sample(self, X, nn_data, nn_num, row, col, step):
10951100# @Substitution(
10961101# sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
10971102# random_state=_random_state_docstring)
1098- class SMOTEN (SMOTE ):
1103+ class SMOTEN (SMOTENC ):
10991104 """Synthetic Minority Over-sampling Technique for Nominal data
11001105 (SMOTE-N).
11011106
@@ -1212,73 +1217,29 @@ class SMOTEN(SMOTE):
12121217
12131218 def __init__ (self , sampling_strategy = 'auto' , kind = 'regular' ,
12141219 random_state = None , k_neighbors = 5 , n_jobs = 1 ):
1215- super (SMOTEN , self ).__init__ (sampling_strategy = sampling_strategy ,
1220+ super (SMOTEN , self ).__init__ (categorical_features = [],
1221+ sampling_strategy = sampling_strategy ,
12161222 random_state = random_state ,
12171223 k_neighbors = k_neighbors ,
1218- ratio = None ,
1219- n_jobs = n_jobs ,
1220- kind = kind )
1221-
1222- @staticmethod
1223- def _check_X_y (X , y ):
1224- """Overwrite the checking to let pass some string for categorical
1225- features.
1226- """
1227- y , binarize_y = check_target_type (y , indicate_one_vs_all = True )
1228- X , y = check_X_y (X , y , accept_sparse = ['csr' , 'csc' ], dtype = None )
1229- return X , y , binarize_y
1224+ n_jobs = n_jobs )
12301225
12311226 def _validate_estimator (self ):
1227+ self .categorical_features = np .asarray (range (self .n_features_ ))
1228+ self .continuous_features_ = np .asarray ([])
12321229 super (SMOTEN , self )._validate_estimator ()
12331230
1234- def _fit_resample (self , X , y ):
1235- self .n_features_ = X .shape [1 ]
1236- self ._validate_estimator ()
1231+ def _decode (self , X , X_resampled ):
1232+ X_unstacked = self .ohe_ .inverse_transform (X_resampled )
1233+ if sparse .issparse (X ):
1234+ X_unstacked = sparse .csr_matrix (X_unstacked )
1235+ return X_unstacked
12371236
1237+ def _encode (self , X , y ):
12381238 self .ohe_ = OneHotEncoder (sparse = True , handle_unknown = 'ignore' ,
12391239 dtype = np .float64 )
12401240 # the input of the OneHotEncoder needs to be dense
1241- X_ohe = self .ohe_ .fit_transform (
1241+ return self .ohe_ .fit_transform (
12421242 X .toarray () if sparse .issparse (X )
12431243 else X )
12441244
1245- X_resampled , y_resampled = super (SMOTEN , self )._fit_resample (
1246- X_ohe , y )
1247- X_resampled = self .ohe_ .inverse_transform (X_resampled )
1248-
1249- if sparse .issparse (X ):
1250- X_resampled = sparse .csr_matrix (X_resampled )
1251-
1252- return X_resampled , y_resampled
1253-
1254- def _generate_sample (self , X , nn_data , nn_num , row , col , step ):
1255- """Generate a synthetic sample with an additional steps for the
1256- categorical features.
12571245
1258- Each new sample is generated the same way than in SMOTE. However, the
1259- categorical features are mapped to the most frequent nearest neighbors
1260- of the majority class.
1261- """
1262- rng = check_random_state (self .random_state )
1263- sample = super (SMOTEN , self )._generate_sample (X , nn_data , nn_num ,
1264- row , col , step )
1265- # To avoid conversion and since there is only few samples used, we
1266- # convert those samples to dense array.
1267- sample = (sample .toarray ().squeeze ()
1268- if sparse .issparse (sample ) else sample )
1269- all_neighbors = nn_data [nn_num [row ]]
1270- all_neighbors = (all_neighbors .toarray ()
1271- if sparse .issparse (all_neighbors ) else all_neighbors )
1272-
1273- categories_size = [cat .size for cat in self .ohe_ .categories_ ]
1274-
1275- for start_idx , end_idx in zip (np .cumsum (categories_size )[:- 1 ],
1276- np .cumsum (categories_size )[1 :]):
1277- col_max = all_neighbors [:, start_idx :end_idx ].sum (axis = 0 )
1278- # tie breaking argmax
1279- col_sel = rng .choice (np .flatnonzero (
1280- np .isclose (col_max , col_max .max ())))
1281- sample [start_idx :end_idx ] = 0
1282- sample [start_idx + col_sel ] = 1
1283-
1284- return sparse .csr_matrix (sample ) if sparse .issparse (X ) else sample
0 commit comments