Commit e024120

Move common methods to reduce method parameters.
1 parent a9e9262 commit e024120

File tree

1 file changed: +43 -45 lines changed


mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala

Lines changed: 43 additions & 45 deletions
@@ -32,8 +32,8 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
 
-/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
-private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
+/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
+private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
     with HasInputCols with HasOutputCols {
 
   /**
@@ -62,6 +62,35 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
   /** @group getParam */
   @Since("2.3.0")
   def getDropLast: Boolean = $(dropLast)
+
+  protected def checkParamsValidity(schema: StructType): Unit = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+    val existingFields = schema.fields
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+        s"output columns ${outputColNames.length}.")
+
+    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      require(schema(inputColName).dataType.isInstanceOf[NumericType],
+        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
+      require(!existingFields.exists(_.name == outputColName),
+        s"Output column $outputColName already exists.")
+    }
+  }
+
+  /** Prepares output columns with proper attributes by examining input columns. */
+  protected def prepareSchemaWithOutputField(schema: StructType): StructType = {
+    val inputFields = $(inputCols).map(schema(_))
+    val outputColNames = $(outputCols)
+
+    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
+      OneHotEncoderCommon.transformOutputColumnSchema(
+        inputField, $(dropLast), outputColName)
+    }
+    StructType(schema.fields ++ outputFields)
+  }
 }
 
 /**
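The two helpers added above read $(inputCols), $(outputCols), and $(dropLast) directly from the shared param map, which is what lets every caller pass only the schema. The following standalone sketch illustrates the same pattern outside Spark's Params machinery; the Schema and Column types and the Encoder class are invented for illustration and are not part of this commit:

// Simplified, hypothetical sketch of the refactoring pattern: shared state lives
// in a trait, so a helper method only needs the argument that varies per call.
case class Column(name: String, isNumeric: Boolean)

case class Schema(fields: Seq[Column]) {
  def apply(name: String): Column = fields.find(_.name == name).get
}

trait EncoderBase {
  def inputCols: Seq[String]   // stands in for $(inputCols)
  def outputCols: Seq[String]  // stands in for $(outputCols)

  // Before the move, this logic was a companion-object helper that took
  // (inputColNames, outputColNames, schema); now only the schema is passed in.
  protected def checkParamsValidity(schema: Schema): Unit = {
    require(inputCols.length == outputCols.length,
      s"The number of input columns ${inputCols.length} must be the same as " +
        s"the number of output columns ${outputCols.length}.")
    inputCols.zip(outputCols).foreach { case (in, out) =>
      require(schema(in).isNumeric, s"Input column $in must be numeric.")
      require(!schema.fields.exists(_.name == out), s"Output column $out already exists.")
    }
  }
}

// A caller supplies only the schema, mirroring the simplified transformSchema below.
class Encoder(val inputCols: Seq[String], val outputCols: Seq[String]) extends EncoderBase {
  def transformSchema(schema: Schema): Schema = {
    checkParamsValidity(schema)
    Schema(schema.fields ++ outputCols.map(name => Column(name, isNumeric = true)))
  }
}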
@@ -80,7 +109,7 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
  */
 @Since("2.3.0")
 class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
-  extends Estimator[OneHotEncoderModel] with OneHotEncoderParams with DefaultParamsWritable {
+  extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
 
   @Since("2.3.0")
   def this() = this(Identifiable.randomUID("oneHotEncoder"))
@@ -103,14 +132,8 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
 
   @Since("2.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputColNames = $(inputCols)
-    val outputColNames = $(outputCols)
-
-    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
-
-    val outputFields = OneHotEncoderEstimator.prepareOutputFields(
-      inputColNames.map(schema(_)), outputColNames, $(dropLast))
-    StructType(schema.fields ++ outputFields)
+    checkParamsValidity(schema)
+    prepareSchemaWithOutputField(schema)
   }
 
   @Since("2.3.0")
@@ -158,42 +181,13 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
 
   @Since("2.3.0")
   override def load(path: String): OneHotEncoderEstimator = super.load(path)
-
-  private[feature] def checkParamsValidity(
-      inputColNames: Seq[String],
-      outputColNames: Seq[String],
-      schema: StructType): Unit = {
-
-    val inputFields = schema.fields
-
-    require(inputColNames.length == outputColNames.length,
-      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
-        s"output columns ${outputColNames.length}.")
-
-    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
-      require(schema(inputColName).dataType.isInstanceOf[NumericType],
-        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
-      require(!inputFields.exists(_.name == outputColName),
-        s"Output column $outputColName already exists.")
-    }
-  }
-
-  private[feature] def prepareOutputFields(
-      inputCols: Seq[StructField],
-      outputColNames: Seq[String],
-      dropLast: Boolean): Seq[StructField] = {
-    inputCols.zip(outputColNames).map { case (inputCol, outputColName) =>
-      OneHotEncoderCommon.transformOutputColumnSchema(
-        inputCol, dropLast, outputColName)
-    }
-  }
 }
 
 @Since("2.3.0")
 class OneHotEncoderModel private[ml] (
     @Since("2.3.0") override val uid: String,
     @Since("2.3.0") val categorySizes: Array[Int])
-  extends Model[OneHotEncoderModel] with OneHotEncoderParams with MLWritable {
+  extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable {
 
   import OneHotEncoderModel._
 
@@ -241,17 +235,21 @@ class OneHotEncoderModel private[ml] (
     val inputColNames = $(inputCols)
     val outputColNames = $(outputCols)
 
-    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
+    checkParamsValidity(schema)
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
        s"features ${categorySizes.length} during fitting.")
 
-    val outputFields = OneHotEncoderEstimator.prepareOutputFields(
-      inputColNames.map(schema(_)), outputColNames, $(dropLast))
-    verifyNumOfValues(StructType(schema.fields ++ outputFields))
+    val transformedSchema = prepareSchemaWithOutputField(schema)
+    verifyNumOfValues(transformedSchema)
   }
 
+  /**
+   * If the metadata of input columns also specifies the number of categories, we need to
+   * compare with expected category number obtained during fitting. Mismatched numbers will
+   * cause exception.
+   */
   private def verifyNumOfValues(schema: StructType): StructType = {
     $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
       val inputColName = $(inputCols)(idx)
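The public behavior is unchanged by this refactoring: schema validation is still exercised through fit and transform. Below is a minimal usage sketch, written for the spark-shell REPL and assuming a local Spark 2.3 environment; the column names and sample data are illustrative and not taken from this commit:

import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

// Minimal usage sketch; assumes a local Spark 2.3 environment and made-up data.
val spark = SparkSession.builder().master("local[*]").appName("one-hot-example").getOrCreate()
import spark.implicits._

val df = Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0)).toDF("cat1", "cat2")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("cat1", "cat2"))
  .setOutputCols(Array("cat1_vec", "cat2_vec"))

// fit() runs the schema validation now provided by the shared trait and learns
// the category sizes; transform() appends the encoded vector columns.
val model = encoder.fit(df)
model.transform(df).show()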
