Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,30 +1648,30 @@ class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid
at most a single one-value per row that indicates the input category index.
For example with 5 categories, an input value of 2.0 would map to an output vector of
`[0.0, 0.0, 1.0, 0.0]`.
The last category is not included by default (configurable via `dropLast`),
The last category is not included by default (configurable via :py:attr:`dropLast`),
because it makes the vector entries sum up to one, and hence linearly dependent.
So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.

Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
The output vectors are sparse.
.. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
The output vectors are sparse.

When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
vector.
When :py:attr:`handleInvalid` is configured to 'keep', an extra "category" indicating invalid
values is added as last category. So when :py:attr:`dropLast` is true, invalid values are
encoded as all-zeros vector.

Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output
cols come in pairs, specified by the order in the arrays, and each pair is treated
independently.
.. note:: When encoding multi-column by using :py:attr:`inputCols` and
:py:attr:`outputCols` params, input/output cols come in pairs, specified by the order in
the arrays, and each pair is treated independently.

See `StringIndexer` for converting categorical values into category indices
.. seealso:: :py:class:`StringIndexer` for converting categorical values into category indices

>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
>>> ohe = OneHotEncoder(inputCols=["input"], outputCols=["output"])
>>> model = ohe.fit(df)
>>> model.transform(df).head().output
SparseVector(2, {0: 1.0})
>>> ohePath = temp_path + "/oheEstimator"
>>> ohePath = temp_path + "/ohe"
>>> ohe.save(ohePath)
>>> loadedOHE = OneHotEncoder.load(ohePath)
>>> loadedOHE.getInputCols() == ohe.getInputCols()
Expand Down