@@ -32,8 +32,8 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
 
-/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
-private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
+/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
+private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
     with HasInputCols with HasOutputCols {
 
   /**
@@ -62,6 +62,35 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
   /** @group getParam */
   @Since("2.3.0")
   def getDropLast: Boolean = $(dropLast)
+
+  protected def checkParamsValidity(schema: StructType): Unit = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+    val existingFields = schema.fields
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+        s"output columns ${outputColNames.length}.")
+
+    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
+      require(schema(inputColName).dataType.isInstanceOf[NumericType],
+        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
+      require(!existingFields.exists(_.name == outputColName),
+        s"Output column $outputColName already exists.")
+    }
+  }
+
+  /** Prepares output columns with proper attributes by examining input columns. */
+  protected def prepareSchemaWithOutputField(schema: StructType): StructType = {
+    val inputFields = $(inputCols).map(schema(_))
+    val outputColNames = $(outputCols)
+
+    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
+      OneHotEncoderCommon.transformOutputColumnSchema(
+        inputField, $(dropLast), outputColName)
+    }
+    StructType(schema.fields ++ outputFields)
+  }
 }
 
 /**
@@ -80,7 +109,7 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
  */
 @Since("2.3.0")
 class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
-  extends Estimator[OneHotEncoderModel] with OneHotEncoderParams with DefaultParamsWritable {
+  extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable {
 
   @Since("2.3.0")
   def this() = this(Identifiable.randomUID("oneHotEncoder"))
@@ -103,14 +132,8 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
 
   @Since("2.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputColNames = $(inputCols)
-    val outputColNames = $(outputCols)
-
-    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
-
-    val outputFields = OneHotEncoderEstimator.prepareOutputFields(
-      inputColNames.map(schema(_)), outputColNames, $(dropLast))
-    StructType(schema.fields ++ outputFields)
+    checkParamsValidity(schema)
+    prepareSchemaWithOutputField(schema)
   }
 
   @Since("2.3.0")
@@ -158,42 +181,13 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
 
   @Since("2.3.0")
   override def load(path: String): OneHotEncoderEstimator = super.load(path)
-
-  private[feature] def checkParamsValidity(
-      inputColNames: Seq[String],
-      outputColNames: Seq[String],
-      schema: StructType): Unit = {
-
-    val inputFields = schema.fields
-
-    require(inputColNames.length == outputColNames.length,
-      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
-        s"output columns ${outputColNames.length}.")
-
-    inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
-      require(schema(inputColName).dataType.isInstanceOf[NumericType],
-        s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
-      require(!inputFields.exists(_.name == outputColName),
-        s"Output column $outputColName already exists.")
-    }
-  }
-
-  private[feature] def prepareOutputFields(
-      inputCols: Seq[StructField],
-      outputColNames: Seq[String],
-      dropLast: Boolean): Seq[StructField] = {
-    inputCols.zip(outputColNames).map { case (inputCol, outputColName) =>
-      OneHotEncoderCommon.transformOutputColumnSchema(
-        inputCol, dropLast, outputColName)
-    }
-  }
 }
 
 @Since("2.3.0")
 class OneHotEncoderModel private[ml] (
     @Since("2.3.0") override val uid: String,
     @Since("2.3.0") val categorySizes: Array[Int])
-  extends Model[OneHotEncoderModel] with OneHotEncoderParams with MLWritable {
+  extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable {
 
   import OneHotEncoderModel._
 
@@ -241,17 +235,21 @@ class OneHotEncoderModel private[ml] (
     val inputColNames = $(inputCols)
     val outputColNames = $(outputCols)
 
-    OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema)
+    checkParamsValidity(schema)
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
         s"features ${categorySizes.length} during fitting.")
 
-    val outputFields = OneHotEncoderEstimator.prepareOutputFields(
-      inputColNames.map(schema(_)), outputColNames, $(dropLast))
-    verifyNumOfValues(StructType(schema.fields ++ outputFields))
+    val transformedSchema = prepareSchemaWithOutputField(schema)
+    verifyNumOfValues(transformedSchema)
   }
 
+  /**
+   * If the metadata of the input columns also specifies the number of categories, we need to
+   * compare it with the expected category number obtained during fitting. A mismatch will
+   * cause an exception.
+   */
   private def verifyNumOfValues(schema: StructType): StructType = {
     $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
       val inputColName = $(inputCols)(idx)
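For context, here is a minimal usage sketch (not part of the change) showing how the consolidated `checkParamsValidity` / `prepareSchemaWithOutputField` logic in `OneHotEncoderBase` gets exercised through `transformSchema` during fitting and transforming. It assumes a local Spark 2.3+ session; the column names and data are illustrative only.

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

object OneHotEncoderUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("OneHotEncoderUsageSketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Two numeric category-index columns; the names are hypothetical.
    val df = Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0))
      .toDF("categoryIndex1", "categoryIndex2")

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
      .setOutputCols(Array("categoryVec1", "categoryVec2"))

    // Both fit() and transform() call transformSchema, which now delegates to the
    // shared validation and output-schema preparation in OneHotEncoderBase.
    val model = encoder.fit(df)
    model.transform(df).show(false)

    spark.stop()
  }
}
```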