
Commit 994065d

viirya authored and jkbradley committed
[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator
## What changes were proposed in this pull request?

This patch adds a new class `OneHotEncoderEstimator` which extends `Estimator`. The `fit` method returns `OneHotEncoderModel`.

Common methods shared by the existing `OneHotEncoder` and the new `OneHotEncoderEstimator`, such as transforming the schema, are extracted into `OneHotEncoderCommon` to reduce code duplication.

### Multi-column support

`OneHotEncoderEstimator` adds simpler multi-column support because it is a new API and is free from backward-compatibility constraints.

### handleInvalid Param support

`OneHotEncoderEstimator` supports the `handleInvalid` Param. It supports `error` and `keep`.

## How was this patch tested?

Added new test suite `OneHotEncoderEstimatorSuite`.

Author: Liang-Chi Hsieh <[email protected]>

Closes #19527 from viirya/SPARK-13030.
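As a rough illustration of the API described above (not part of the patch itself), here is a minimal usage sketch in Scala. It assumes an existing `SparkSession` named `spark`; the column names and data are made up, and the `handleInvalid = "keep"` comment reflects the behavior described in the PR description:

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator

// Illustrative input: two categorical-index columns, e.g. produced by StringIndexer.
val df = spark.createDataFrame(Seq(
  (0.0, 1.0),
  (1.0, 0.0),
  (2.0, 1.0)
)).toDF("categoryIndex1", "categoryIndex2")

// Multi-column encoding in a single stage. handleInvalid = "keep" maps invalid
// (e.g. unseen) category indices to an extra category instead of failing;
// the default, "error", throws.
val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex1", "categoryIndex2"))
  .setOutputCols(Array("categoryVec1", "categoryVec2"))
  .setHandleInvalid("keep")

// fit returns a OneHotEncoderModel, which then transforms the data.
val model = encoder.fit(df)
val encoded = model.transform(df)
encoded.show(false)
```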
1 parent 5955a2d commit 994065d

File tree

3 files changed

+960 -66 lines changed


mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 17 additions & 66 deletions
@@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
  * The output vectors are sparse.
  *
  * @see `StringIndexer` for converting categorical values into category indices
+ * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
+ * will be removed in 3.0.0.
  */
 @Since("1.4.0")
+@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" +
+  " will be removed in 3.0.0.", "2.3.0")
 class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer
   with HasInputCol with HasOutputCol with DefaultParamsWritable {

@@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
   override def transformSchema(schema: StructType): StructType = {
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
+    val inputFields = schema.fields
 
     require(schema(inputColName).dataType.isInstanceOf[NumericType],
       s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
-    val inputFields = schema.fields
     require(!inputFields.exists(_.name == outputColName),
       s"Output column $outputColName already exists.")
 
-    val inputAttr = Attribute.fromStructField(schema(inputColName))
-    val outputAttrNames: Option[Array[String]] = inputAttr match {
-      case nominal: NominalAttribute =>
-        if (nominal.values.isDefined) {
-          nominal.values
-        } else if (nominal.numValues.isDefined) {
-          nominal.numValues.map(n => Array.tabulate(n)(_.toString))
-        } else {
-          None
-        }
-      case binary: BinaryAttribute =>
-        if (binary.values.isDefined) {
-          binary.values
-        } else {
-          Some(Array.tabulate(2)(_.toString))
-        }
-      case _: NumericAttribute =>
-        throw new RuntimeException(
-          s"The input column $inputColName cannot be numeric.")
-      case _ =>
-        None // optimistic about unknown attributes
-    }
-
-    val filteredOutputAttrNames = outputAttrNames.map { names =>
-      if ($(dropLast)) {
-        require(names.length > 1,
-          s"The input column $inputColName should have at least two distinct values.")
-        names.dropRight(1)
-      } else {
-        names
-      }
-    }
-
-    val outputAttrGroup = if (filteredOutputAttrNames.isDefined) {
-      val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name =>
-        BinaryAttribute.defaultAttr.withName(name)
-      }
-      new AttributeGroup($(outputCol), attrs)
-    } else {
-      new AttributeGroup($(outputCol))
-    }
-
-    val outputFields = inputFields :+ outputAttrGroup.toStructField()
+    val outputField = OneHotEncoderCommon.transformOutputColumnSchema(
+      schema(inputColName), outputColName, $(dropLast))
+    val outputFields = inputFields :+ outputField
     StructType(outputFields)
   }

@@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
     // schema transformation
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
-    val shouldDropLast = $(dropLast)
-    var outputAttrGroup = AttributeGroup.fromStructField(
+
+    val outputAttrGroupFromSchema = AttributeGroup.fromStructField(
       transformSchema(dataset.schema)(outputColName))
-    if (outputAttrGroup.size < 0) {
-      // If the number of attributes is unknown, we check the values from the input column.
-      val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0))
-        .treeAggregate(0.0)(
-          (m, x) => {
-            assert(x <= Int.MaxValue,
-              s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x")
-            assert(x >= 0.0 && x == x.toInt,
-              s"Values from column $inputColName must be indices, but got $x.")
-            math.max(m, x)
-          },
-          (m0, m1) => {
-            math.max(m0, m1)
-          }
-        ).toInt + 1
-      val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
-      val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
-      val outputAttrs: Array[Attribute] =
-        filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
-      outputAttrGroup = new AttributeGroup(outputColName, outputAttrs)
+
+    val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
+      OneHotEncoderCommon.getOutputAttrGroupFromData(
+        dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0)
+    } else {
+      outputAttrGroupFromSchema
     }
+
     val metadata = outputAttrGroup.toMetadata()
 
     // data transformation
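For context on the refactored `transform` path above: when the input column carries no ML attribute metadata, the number of categories is inferred from the data (previously via the inlined `treeAggregate`, now via `OneHotEncoderCommon.getOutputAttrGroupFromData`). Below is a hedged sketch of the observable behavior of the existing single-column `OneHotEncoder`, assuming a `SparkSession` named `spark`; the column name and values are illustrative only:

```scala
import org.apache.spark.ml.feature.OneHotEncoder

import spark.implicits._

// No ML attribute metadata is attached here, so the encoder must infer the
// number of categories from the data: max index + 1 = 3.
val indexed = Seq(0.0, 1.0, 2.0).toDF("categoryIndex")

val encoded = new OneHotEncoder()   // deprecated in favor of OneHotEncoderEstimator
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
  .transform(indexed)

// With the default dropLast = true, output vectors have size 2 and index 2 maps
// to the all-zeros vector: (2,[0],[1.0]), (2,[1],[1.0]), (2,[],[]).
encoded.show(false)
```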
