11 changes: 6 additions & 5 deletions README.rst
@@ -168,11 +168,12 @@ Below is a list of the methods currently implemented in this module.
1. Random minority over-sampling with replacement
2. SMOTE - Synthetic Minority Over-sampling Technique [8]_
3. SMOTENC - SMOTE for Nominal Continuous [8]_
Member

SMOTENC - SMOTE for Nominal and Continuous
or
SMOTENC - SMOTE Nominal Continuous

4. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_
5. SVM SMOTE - Support Vectors SMOTE [10]_
6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
7. KMeans-SMOTE [17]_
8. ROSE - Random OverSampling Examples [19]_
4. SMOTEN - SMMOTE for Nominal only [8]_
Member

SMMOTE->SMOTE

5. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_
6. SVM SMOTE - Support Vectors SMOTE [10]_
7. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
8. KMeans-SMOTE [17]_
9. ROSE - Random OverSampling Examples [19]_

* Over-sampling followed by under-sampling
1. SMOTE + Tomek links [12]_
38 changes: 38 additions & 0 deletions doc/over_sampling.rst
@@ -211,6 +211,44 @@ Therefore, it can be seen that the samples generated in the first and last
columns are belonging to the same categories originally presented without any
other extra interpolation.

However, :class:`SMOTENC` is working with data composed of categorical data
Member

However, :class:SMOTENC is working with datasets composed of continuous and categorical features.

only. WHen data are made of only nominal categorical data, one can use the
Member

When

:class:`SMOTEN` variant :cite:`chawla2002smote`. The algorithm changes in
two ways:

* the nearest neighbors search does not rely on the Euclidean distance. Indeed,
the value difference metric (VDM) also implemented in the class
Member

Value Difference Metric

:class:`~imblearn.metrics.ValueDifferenceMetric` is used.
* the new sample generation is based on majority vote per feature to generate
Member

on the majority?

the most common category seen in the neighboring samples.
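
The first change can be sketched for a single feature in plain Python. This is a simplified illustration and not the code of :class:`~imblearn.metrics.ValueDifferenceMetric`: the helper name ``vdm_table``, the exponent ``k`` and the toy data are assumptions made for this example. The distance between two categories is taken as the sum, over the classes, of the absolute difference between their class-conditional frequencies.

```python
from collections import Counter


def vdm_table(values, labels, k=1):
    """Distance between every pair of categories of one nominal feature."""
    classes = sorted(set(labels))
    n_value = Counter(values)               # occurrences of each category
    n_joint = Counter(zip(values, labels))  # occurrences of (category, class)
    # class-conditional frequencies P(class | category)
    cond = {
        a: [n_joint[(a, c)] / n_value[a] for c in classes] for a in n_value
    }
    return {
        (a, b): sum(abs(pa - pb) ** k for pa, pb in zip(cond[a], cond[b]))
        for a in cond
        for b in cond
    }


# hypothetical toy data: one color feature and a binary target
colors = ["green", "green", "red", "red", "blue", "blue"]
fruit = ["apple", "apple", "apple", "not apple", "not apple", "not apple"]
table = vdm_table(colors, fruit)
print(table[("green", "red")], table[("green", "blue")])
```

With this toy data, "green" ends up closer to "red" than to "blue", because "green" and "red" share the "apple" class more often. Two categories are close when they predict the target in a similar way, which is why such a metric can replace the Euclidean distance for nominal data.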

Let's take the following example::

   >>> import numpy as np
   >>> X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7,
   ...              dtype=object).reshape(-1, 1)
   >>> y = np.array(["apple"] * 5 + ["not apple"] * 3 + ["apple"] * 7 +
   ...              ["not apple"] * 5 + ["apple"] * 2, dtype=object)

We generate a dataset that associates a color with being an apple or not.
"green" and "red" are strongly associated with being an apple. Since the
minority class is "not apple", we expect the newly generated samples to belong
to the category "blue"::

   >>> from imblearn.over_sampling import SMOTEN
   >>> sampler = SMOTEN(random_state=0)
   >>> X_res, y_res = sampler.fit_resample(X, y)
   >>> X_res[y.size:]
   array([['blue'],
          ['blue'],
          ['blue'],
          ['blue'],
          ['blue'],
          ['blue']], dtype=object)
   >>> y_res[y.size:]
   array(['not apple', 'not apple', 'not apple', 'not apple', 'not apple',
          'not apple'], dtype=object)
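
The per-feature vote behind this result can be sketched in plain Python. This is a minimal illustration, not the library code: the helper name ``majority_vote``, the neighbor lists and the tie-breaking rule (smallest category wins) are assumptions made for this example. In the data above, the "not apple" class holds 3 "red" and 5 "blue" samples, so a neighborhood drawn from it that contains more "blue" than "red" votes "blue".

```python
from collections import Counter


def majority_vote(neighbors):
    """Build one synthetic row from the most common category of each feature."""
    n_features = len(neighbors[0])
    synthetic = []
    for j in range(n_features):
        counts = Counter(row[j] for row in neighbors)
        top = max(counts.values())
        # assumption of this sketch: ties go to the smallest category
        synthetic.append(min(cat for cat, n in counts.items() if n == top))
    return synthetic


# hypothetical neighborhood of a "not apple" sample: five "blue", one "red"
one_feature = majority_vote([["blue"]] * 5 + [["red"]])

# hypothetical neighborhood with two features, voted independently
two_features = majority_vote([["A", "x"], ["B", "x"], ["A", "y"]])
print(one_feature, two_features)
```

Because each feature is voted on independently, a synthetic row may combine categories that never appeared together in a single original sample.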

Mathematical formulation
========================

1 change: 1 addition & 0 deletions doc/references/over_sampling.rst
@@ -27,6 +27,7 @@ SMOTE algorithms

SMOTE
SMOTENC
SMOTEN
ADASYN
BorderlineSMOTE
KMeansSMOTE
4 changes: 4 additions & 0 deletions doc/whats_new/v0.8.rst
@@ -19,6 +19,10 @@ New features
compute pairwise distances between samples containing only nominal values.
:pr:`796` by :user:`Guillaume Lemaitre <glemaitre>`.

- Add the class :class:`imblearn.over_sampling.SMOTEN` to over-sample data
only containing nominal categorical features.
:pr:`802` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

2 changes: 2 additions & 0 deletions imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
from ._smote import KMeansSMOTE
from ._smote import SVMSMOTE
from ._smote import SMOTENC
from ._smote import SMOTEN

__all__ = [
    "ADASYN",
@@ -19,4 +20,5 @@
    "BorderlineSMOTE",
    "SVMSMOTE",
    "SMOTENC",
    "SMOTEN",
]
14 changes: 14 additions & 0 deletions imblearn/over_sampling/_adasyn.py
@@ -50,6 +50,15 @@ class ADASYN(BaseOverSampler):
--------
SMOTE : Over-sample using SMOTE.

SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

SVMSMOTE : Over-sample using SVM-SMOTE variant.
Member

SVMSMOTE : Over-sample using the SVM-SMOTE variant.


BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
Member

BorderlineSMOTE : Over-sample using the Borderline-SMOTE variant.


Notes
-----
The implementation is based on [1]_.
@@ -169,3 +178,8 @@ def _fit_resample(self, X, y):
        y_resampled = np.hstack(y_resampled)

        return X_resampled, y_resampled

    def _more_tags(self):
        return {
            "X_types": ["2darray"],
        }
3 changes: 3 additions & 0 deletions imblearn/over_sampling/_random_over_sampler.py
@@ -76,6 +76,9 @@ class RandomOverSampler(BaseOverSampler):

SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

SVMSMOTE : Over-sample using SVM-SMOTE variant.

ADASYN : Over-sample using ADASYN.
160 changes: 159 additions & 1 deletion imblearn/over_sampling/_smote.py
@@ -11,11 +11,12 @@

import numpy as np
from scipy import sparse
from scipy import stats

from sklearn.base import clone
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.svm import SVC
from sklearn.utils import check_random_state
from sklearn.utils import _safe_indexing
@@ -25,6 +26,7 @@

from .base import BaseOverSampler
from ..exceptions import raise_isinstance_error
from ..metrics.pairwise import ValueDifferenceMetric
from ..utils import check_neighbors_object
from ..utils import check_target_type
from ..utils import Substitution
@@ -448,6 +450,9 @@ class SVMSMOTE(BaseSMOTE):

SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

BorderlineSMOTE : Over-sample using Borderline-SMOTE.

ADASYN : Over-sample using ADASYN.
@@ -643,6 +648,9 @@ class SMOTE(BaseSMOTE):
--------
SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

BorderlineSMOTE : Over-sample using the borderline-SMOTE variant.

SVMSMOTE : Over-sample using the SVM-SMOTE variant.
@@ -766,6 +774,9 @@ class SMOTENC(SMOTE):
--------
SMOTE : Over-sample using SMOTE.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

SVMSMOTE : Over-sample using SVM-SMOTE variant.

BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
@@ -1055,6 +1066,11 @@ class KMeansSMOTE(BaseSMOTE):
--------
SMOTE : Over-sample using SMOTE.

SMOTENC : Over-sample using SMOTE for continuous and categorical features.

SMOTEN : Over-sample using the SMOTE variable specifically for categorical
Member

SMOTEN : Over-sample using the SMOTE variant specifically for nominal features only.

features only.

SVMSMOTE : Over-sample using SVM-SMOTE variant.

BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
@@ -1248,3 +1264,145 @@ def _fit_resample(self, X, y):
y_resampled = np.hstack((y_resampled, y_new))

return X_resampled, y_resampled


@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
    n_jobs=_n_jobs_docstring,
    random_state=_random_state_docstring,
)
class SMOTEN(SMOTE):
    """Perform SMOTE over-sampling for nominal categorical features only.

    This method is referred to as SMOTEN in [1]_.

    Read more in the :ref:`User Guide <smote_adasyn>`.

    Parameters
    ----------
    {sampling_strategy}

    {random_state}

    k_neighbors : int or object, default=5
        If ``int``, number of nearest neighbours to use to construct synthetic
        samples. If object, an estimator that inherits from
        :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
        find the k_neighbors.

    {n_jobs}

    See Also
    --------
    SMOTE : Over-sample using SMOTE.

    SMOTENC : Over-sample using SMOTE for continuous and categorical features.

    BorderlineSMOTE : Over-sample using the borderline-SMOTE variant.

    SVMSMOTE : Over-sample using the SVM-SMOTE variant.

    ADASYN : Over-sample using ADASYN.

    KMeansSMOTE : Over-sample applying a clustering before to oversample using
Member

KMeansSMOTE : Over-sample by applying a clustering before to oversample using
SMOTE.

        SMOTE.

    Notes
    -----
    See the original paper [1]_ for more details.

    Supports multi-class resampling. A one-vs.-rest scheme is used as
    originally proposed in [1]_.

    References
    ----------
    .. [1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE:
       synthetic minority over-sampling technique," Journal of artificial
       intelligence research, 321-357, 2002.

    Examples
    --------
    >>> import numpy as np
    >>> X = np.array(["A"] * 10 + ["B"] * 20 + ["C"] * 30, dtype=object).reshape(-1, 1)
    >>> y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
    >>> from collections import Counter
    >>> print(f"Original class counts: {{Counter(y)}}")
    Original class counts: Counter({{1: 40, 0: 20}})
    >>> from imblearn.over_sampling import SMOTEN
    >>> sampler = SMOTEN(random_state=0)
    >>> X_res, y_res = sampler.fit_resample(X, y)
    >>> print(f"Class counts after resampling {{Counter(y_res)}}")
    Class counts after resampling Counter({{0: 40, 1: 40}})
    """

    def _check_X_y(self, X, y):
        """Check should accept strings and not sparse matrices."""
        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
        X, y = self._validate_data(
            X,
            y,
            reset=True,
            dtype=None,
            accept_sparse=False,
        )
        return X, y, binarize_y

    def _validate_estimator(self):
        """Force to use precomputed distance matrix."""
        super()._validate_estimator()
        self.nn_k_.set_params(metric="precomputed")

    def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
        random_state = check_random_state(self.random_state)
        # generate sample indices that will be used to generate new samples
        samples_indices = random_state.choice(
            np.arange(X_class.shape[0]), size=n_samples, replace=True
        )
        # for each drawn sample, select its k-neighbors and generate a sample
        # where, for each feature individually, the generated category is the
        # most common category among these neighbors
        X_new = np.squeeze(
            stats.mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
        )
        y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
        return X_new, y_new

    def _fit_resample(self, X, y):
        self._validate_estimator()

        X_resampled = [X.copy()]
        y_resampled = [y.copy()]

        encoder = OrdinalEncoder(dtype=np.int32)
        X_encoded = encoder.fit_transform(X)

        vdm = ValueDifferenceMetric(
            n_categories=[len(cat) for cat in encoder.categories_]
        ).fit(X_encoded, y)

        for class_sample, n_samples in self.sampling_strategy_.items():
            if n_samples == 0:
                continue
            target_class_indices = np.flatnonzero(y == class_sample)
            X_class = _safe_indexing(X_encoded, target_class_indices)

            X_class_dist = vdm.pairwise(X_class)
            self.nn_k_.fit(X_class_dist)
            # the kneighbors search will include the sample itself, which is
            # expected from the original algorithm
            nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False)
            X_new, y_new = self._make_samples(
                X_class, class_sample, y.dtype, nn_indices, n_samples
            )

            X_new = encoder.inverse_transform(X_new)
            X_resampled.append(X_new)
            y_resampled.append(y_new)

        X_resampled = np.vstack(X_resampled)
        y_resampled = np.hstack(y_resampled)

        return X_resampled, y_resampled

    def _more_tags(self):
        return {"X_types": ["2darray", "dataframe", "string"]}
54 changes: 54 additions & 0 deletions imblearn/over_sampling/tests/test_smoten.py
@@ -0,0 +1,54 @@
import numpy as np
import pytest

from imblearn.over_sampling import SMOTEN


@pytest.fixture
def data():
    rng = np.random.RandomState(0)

    feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
    feature_2 = ["A"] * 40 + ["B"] * 20
    feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
    X = np.array([feature_1, feature_2, feature_3], dtype=object).T
    rng.shuffle(X)
    y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
    y_labels = np.array(["not apple", "apple"], dtype=object)
    y = y_labels[y]
    return X, y


def test_smoten(data):
    # overall check for SMOTEN
    X, y = data
    sampler = SMOTEN(random_state=0)
    X_res, y_res = sampler.fit_resample(X, y)

    assert X_res.shape == (80, 3)
    assert y_res.shape == (80,)


def test_smoten_resampling():
    # check that SMOTEN resamples the data as expected
    # we generate data such that "not apple" will be the minority class and
    # samples from this class will be generated. We will force the "blue"
    # category to be associated with this class. Therefore, the newly
    # generated samples should also be from the "blue" category.
    X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object).reshape(
        -1, 1
    )
    y = np.array(
        ["apple"] * 5
        + ["not apple"] * 3
        + ["apple"] * 7
        + ["not apple"] * 5
        + ["apple"] * 2,
        dtype=object,
    )
    sampler = SMOTEN(random_state=0)
    X_res, y_res = sampler.fit_resample(X, y)

    X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
    np.testing.assert_array_equal(X_generated, "blue")
    np.testing.assert_array_equal(y_generated, "not apple")