From 8fd4677fd0e729d99d8777010e78bb5cfea3cf86 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 18 Oct 2017 07:31:32 +0000 Subject: [PATCH 01/13] Add OneHotEncoderEstimator and related tests. --- .../spark/ml/feature/OneHotEncoder.scala | 79 +--- .../ml/feature/OneHotEncoderEstimator.scala | 439 ++++++++++++++++++ .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 15 + .../feature/OneHotEncoderEstimatorSuite.scala | 293 ++++++++++++ 5 files changed, 761 insertions(+), 66 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index a669da183e2c8..d4813dcf2663d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -78,56 +78,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val outputColName = $(outputCol) + val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") - val inputAttr = Attribute.fromStructField(schema(inputColName)) - val outputAttrNames: Option[Array[String]] = inputAttr match { - case nominal: NominalAttribute => - if (nominal.values.isDefined) { - nominal.values - } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(_.toString)) - } else { - None - } - case binary: BinaryAttribute => - if (binary.values.isDefined) { - binary.values - } else { - Some(Array.tabulate(2)(_.toString)) - } - case _: NumericAttribute => - throw new RuntimeException( - s"The input column $inputColName cannot be numeric.") - case _ => - None // optimistic about unknown attributes - } - - val filteredOutputAttrNames = outputAttrNames.map { names => - if ($(dropLast)) { - require(names.length > 1, - s"The input column $inputColName should have at least two distinct values.") - names.dropRight(1) - } else { - names - } - } - - val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { - val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => - BinaryAttribute.defaultAttr.withName(name) - } - new AttributeGroup($(outputCol), attrs) - } else { - new AttributeGroup($(outputCol)) - } - - val outputFields = inputFields :+ outputAttrGroup.toStructField() + val outputField = OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), $(dropLast), outputColName) + val outputFields = inputFields :+ outputField StructType(outputFields) } @@ -136,30 +96,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) - val shouldDropLast = $(dropLast) - var outputAttrGroup = AttributeGroup.fromStructField( + + val outputAttrGroupFromSchema = AttributeGroup.fromStructField( transformSchema(dataset.schema)(outputColName)) - if (outputAttrGroup.size < 0) { - // If the number of attributes 
is unknown, we check the values from the input column. - val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) - .treeAggregate(0.0)( - (m, x) => { - assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") - assert(x >= 0.0 && x == x.toInt, - s"Values from column $inputColName must be indices, but got $x.") - math.max(m, x) - }, - (m0, m1) => { - math.max(m0, m1) - } - ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames - val outputAttrs: Array[Attribute] = - filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) - outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + + val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, $(dropLast), inputColName, outputColName) + } else { + outputAttrGroupFromSchema } + val metadata = outputAttrGroup.toMetadata() // data transformation diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala new file mode 100644 index 0000000000000..017bc6be9aeb8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model, Transformer} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} + +/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */ +private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid + with HasInputCols with HasOutputCols { + + /** + * Param for how to handle invalid data. + * Options are 'skip' (filter out rows with invalid data) or 'error' (throw an error). 
+ * Default: "error" + * @group param + */ + @Since("2.3.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data " + + "Options are 'skip' (filter out rows with invalid data) or error (throw an error).", + ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) + + setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) + + /** + * Whether to drop the last category in the encoded vector (default: true) + * @group param + */ + @Since("2.3.0") + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) + + /** @group getParam */ + @Since("2.3.0") + def getDropLast: Boolean = $(dropLast) +} + +/** + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via `dropLast`), + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see `StringIndexer` for converting categorical values into category indices + */ +@Since("2.3.0") +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) + extends Estimator[OneHotEncoderModel] with OneHotEncoderParams with DefaultParamsWritable { + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("oneHotEncoder")) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + val inputFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), $(dropLast), outputColName) + } + StructType(inputFields ++ outputFields) + } + + @Since("2.3.0") + override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + val transformedSchema = transformSchema(dataset.schema) + + val categorySizes = $(outputCols).zipWithIndex.map { case (outputColName, idx) => + val outputAttrGroupFromSchema = AttributeGroup.fromStructField( + transformedSchema(outputColName)) + + val outputAttrGroup = if 
(outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, $(dropLast), $(inputCols)(idx), outputColName) + } else { + outputAttrGroupFromSchema + } + + outputAttrGroup.size + } + + val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) + copyValues(model) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra) +} + +@Since("2.3.0") +object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { + + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID) + + @Since("2.3.0") + override def load(path: String): OneHotEncoderEstimator = super.load(path) +} + +@Since("2.3.0") +class OneHotEncoderModel private[ml] ( + @Since("2.3.0") override val uid: String, + @Since("2.3.0") val categorySizes: Array[Int]) + extends Model[OneHotEncoderModel] with OneHotEncoderParams with MLWritable { + + import OneHotEncoderModel._ + + private def encoders: Array[UserDefinedFunction] = { + val oneValue = Array(1.0) + val emptyValues = Array.empty[Double] + val emptyIndices = Array.empty[Int] + val dropLast = getDropLast + val handleInvalid = getHandleInvalid + + categorySizes.map { size => + udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) + } else if (label == size && dropLast) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else { + if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.SKIP_INVALID}.") + } else { + Vectors.sparse(size, emptyIndices, emptyValues) + } + } + } + } + } + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + val inputFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + require(inputColNames.length == categorySizes.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"features ${categorySizes.length} during fitting.") + + val inputOutputPairs = inputColNames.zip(outputColNames) + val outputFields = inputOutputPairs.map { case (inputColName, outputColName) => + + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), $(dropLast), outputColName) + } + verifyNumOfValues(StructType(inputFields ++ outputFields)) + } + + private def 
verifyNumOfValues(schema: StructType): StructType = { + $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = $(inputCols)(idx) + val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) + + // If the input metadata specifies number of category, + // compare with expected category number. + if (attrGroup.attributes.nonEmpty) { + require(attrGroup.size == categorySizes(idx), "OneHotEncoderModel expected " + + s"${categorySizes(idx)} categorical values for input column ${inputColName}, but " + + s"the input column had metadata specifying ${attrGroup.size} values.") + } + } + schema + } + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.SKIP_INVALID) { + throw new IllegalArgumentException("When Param handleInvalid is set to " + + s"${OneHotEncoderEstimator.SKIP_INVALID}, Param dropLast can't be true, " + + "because last category and invalid values will conflict in encoded vector.") + } + + val transformedSchema = transformSchema(dataset.schema, logging = true) + + val encodedColumns = encoders.zipWithIndex.map { case (encoder, idx) => + val inputColName = $(inputCols)(idx) + val outputColName = $(outputCols)(idx) + + val outputAttrGroupFromSchema = + AttributeGroup.fromStructField(transformedSchema(outputColName)) + + val metadata = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, false, + categorySizes(idx)).toMetadata() + } else { + outputAttrGroupFromSchema.toMetadata() + } + + encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata) + } + val allCols = Seq(col("*")) ++ encodedColumns + dataset.select(allCols: _*) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderModel = { + val copied = new OneHotEncoderModel(uid, categorySizes) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.3.0") + override def write: MLWriter = new OneHotEncoderModelWriter(this) +} + +@Since("2.3.0") +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + + private[OneHotEncoderModel] + class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { + + private case class Data(categorySizes: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.categorySizes) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { + + private val className = classOf[OneHotEncoderModel].getName + + override def load(path: String): OneHotEncoderModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + .select("categorySizes") + .head() + val categorySizes = data.getAs[Seq[Int]](0).toArray + val model = new OneHotEncoderModel(metadata.uid, categorySizes) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.3.0") + override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader + + @Since("2.3.0") + override def load(path: String): OneHotEncoderModel = super.load(path) +} + +/** + * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`. 
+ */ +private[feature] object OneHotEncoderCommon { + + private def genOutputAttrNames( + inputCol: StructField, + outputColName: String): Option[Array[String]] = { + val inputAttr = Attribute.fromStructField(inputCol) + inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column ${inputCol.name} cannot be numeric.") + case _ => + None // optimistic about unknown attributes + } + } + + /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ + private def genOutputAttrGroup( + outputAttrNames: Option[Array[String]], + outputColName: String): AttributeGroup = { + outputAttrNames.map { attrNames => + val attrs: Array[Attribute] = attrNames.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup(outputColName, attrs) + }.getOrElse{ + new AttributeGroup(outputColName) + } + } + + /** + * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. + */ + def transformOutputColumnSchema( + inputCol: StructField, + dropLast: Boolean, + outputColName: String): StructField = { + val outputAttrNames = genOutputAttrNames(inputCol, outputColName) + val filteredOutputAttrNames = outputAttrNames.map { names => + if (dropLast) { + require(names.length > 1, + s"The input column ${inputCol.name} should have at least two distinct values.") + names.dropRight(1) + } else { + names + } + } + + genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() + } + + /** + * This method is called when we want to generate `AttributeGroup` from actual data for + * one-hot encoder. + */ + def getOutputAttrGroupFromData( + dataset: Dataset[_], + dropLast: Boolean, + inputColName: String, + outputColName: String): AttributeGroup = { + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) + .treeAggregate(0.0)( + (m, x) => { + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") + assert(x >= 0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + + createAttrGroupForAttrNames(outputColName, dropLast, numAttrs) + } + + /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. 
*/ + def createAttrGroupForAttrNames( + outputColName: String, + dropLast: Boolean, + numAttrs: Int): AttributeGroup = { + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (dropLast) outputAttrNames.dropRight(1) else outputAttrNames + genOutputAttrGroup(Some(filtered), outputColName) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1860fe8361749..64163d3c16039 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6061d9ca0a084..b9f4791b37de1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -230,6 +230,21 @@ private[ml] trait HasOutputCol extends Params { final def getOutputCol: String = $(outputCol) } +/** + * Trait for shared param outputCols. + */ +private[ml] trait HasOutputCols extends Params { + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) +} + /** * Trait for shared param checkpointInterval. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala new file mode 100644 index 0000000000000..621a67c744912 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +class OneHotEncoderEstimatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + def stringIndexed(): DataFrame = stringIndexedMultipleCols().select("id", "label", "labelIndex") + + def stringIndexedMultipleCols(): DataFrame = { + val data = Seq( + (0, "a", "A"), + (1, "b", "B"), + (2, "c", "D"), + (3, "a", "A"), + (4, "a", "B"), + (5, "c", "C")) + val df = data.toDF("id", "label", "label2") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val df2 = indexer.transform(df) + val indexer2 = new StringIndexer() + .setInputCol("label2") + .setOutputCol("labelIndex2") + .fit(df2) + indexer2.transform(df2) + } + + test("params") { + ParamsSuite.checkParams(new OneHotEncoderEstimator) + } + + test("OneHotEncoderEstimator dropLast = false") { + val transformed = stringIndexed() + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("labelIndex")) + .setOutputCols(Array("labelVec")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(transformed) + val encoded = model.transform(transformed) + + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + } + + test("OneHotEncoderEstimator dropLast = true") { + val transformed = stringIndexed() + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("labelIndex")) + .setOutputCols(Array("labelVec")) + + val model = encoder.fit(transformed) + val encoded = model.transform(transformed) + + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) + assert(output === expected) + } + + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") + val encoder = new 
OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } + + test("read/write") { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + testDefaultReadWrite(encoder) + } + + test("OneHotEncoderModel read/write") { + val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.categorySizes === instance.categorySizes) + } + + test("OneHotEncoderEstimator with varying types") { + val df = stringIndexed() + val dfWithTypes = df + .withColumn("shortLabel", df("labelIndex").cast(ShortType)) + .withColumn("longLabel", df("labelIndex").cast(LongType)) + .withColumn("intLabel", df("labelIndex").cast(IntegerType)) + .withColumn("floatLabel", df("labelIndex").cast(FloatType)) + .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) + val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", + "floatLabel", "decimalLabel") + for (col <- cols) { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array(col)) + .setOutputCols(Array("labelVec")) + .setDropLast(false) + val model = encoder.fit(dfWithTypes) + val encoded = model.transform(dfWithTypes) + + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + } + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") { + val transformed = stringIndexedMultipleCols() + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("labelIndex", "labelIndex2")) + .setOutputCols(Array("labelVec", "labelVec2")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(transformed) + val encoded = model.transform(transformed) + + // Verify 1st column. + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + + // Verify 2nd column. 
+ val output2 = encoded.select("id", "labelVec2").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2), vec(3)) + }.collect().toSet + // A -> 1, B -> 0, C -> 3, D -> 2 + val expected2 = Set((0, 0.0, 1.0, 0.0, 0.0), (1, 1.0, 0.0, 0.0, 0.0), (2, 0.0, 0.0, 1.0, 0.0), + (3, 0.0, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0, 0.0), (5, 0.0, 0.0, 0.0, 1.0)) + assert(output2 === expected2) + } + + test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") { + val transformed = stringIndexedMultipleCols() + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("labelIndex", "labelIndex2")) + .setOutputCols(Array("labelVec", "labelVec2")) + + val model = encoder.fit(transformed) + val encoded = model.transform(transformed) + + // Verify 1st column. + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) + assert(output === expected) + + // Verify 2nd column. + val output2 = encoded.select("id", "labelVec2").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // A -> 1, B -> 0, C -> 3, D -> 2 + val expected2 = Set((0, 0.0, 1.0, 0.0), (1, 1.0, 0.0, 0.0), (2, 0.0, 0.0, 1.0), + (3, 0.0, 1.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 0.0, 0.0)) + assert(output2 === expected2) + } + + test("Throw error on invalid values") { + val trainingData = Seq((0, 0), (1, 1), (2, 2)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2), (1, 3)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).show + } + err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + } + + test("Skip on invalid values") { + val trainingData = Seq((0, 0), (1, 1)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + .setHandleInvalid("skip") + .setDropLast(false) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + + val output = encoded.select("id", "encoded").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1)) + }.collect().toSet + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0)) + assert(output === expected) + } + + test("Can't set dropLast as true and skip on invalid values") { + val trainingData = Seq((0, 0), (1, 1)) + val trainingDF = trainingData.toDF("id", "a") + val testData = Seq((0, 0), (1, 2)) + val testDF = testData.toDF("id", "a") + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + .setHandleInvalid("skip") + + val model = encoder.fit(trainingDF) + val err = intercept[IllegalArgumentException] { + model.transform(testDF) + } + err.getMessage.contains("When Param handleInvalid is set to skip, Param dropLast can't be true") + } +} From 48e650890f9c7638e539c3bda5cf74ae6796a1ce Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 19 Oct 2017 03:36:28 +0000 Subject: [PATCH 02/13] Scan multi-column at once to obtain the numbers of values for columns. 
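Rather than launching one job per column, `fit()` now collects the indices of output columns whose metadata lacks a category count and sizes all of them in a single pass over the data. For reference, a minimal, self-contained sketch of that element-wise max aggregation (column names `c1`/`c2` are illustrative, not from the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
import spark.implicits._

val df = Seq((0.0, 1.0), (2.0, 0.0), (1.0, 3.0)).toDF("c1", "c2")
val columns = Seq("c1", "c2").map(c => col(c).cast(DoubleType))
val n = columns.length

// Single scan: the zero value is an all-zeros array, and both the sequence
// and combine operators take element-wise maxima across the scanned columns.
val maxValues = df.select(columns: _*).rdd
  .map(row => Array.tabulate(n)(row.getDouble))
  .treeAggregate(new Array[Double](n))(
    (maxes, values) => Array.tabulate(n)(i => math.max(maxes(i), values(i))),
    (m0, m1) => Array.tabulate(n)(i => math.max(m0(i), m1(i))))

val categorySizes = maxValues.map(_.toInt + 1)  // max index + 1 per column
println(categorySizes.mkString(", "))           // prints: 3, 4
```

The implementation below additionally asserts that every value is a non-negative integral index no greater than `Int.MaxValue`, so a malformed value fails the scan early.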
--- .../spark/ml/feature/OneHotEncoder.scala | 2 +- .../ml/feature/OneHotEncoderEstimator.scala | 71 +++++++++++++------ 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index d4813dcf2663d..23fc209bcc542 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -102,7 +102,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, $(dropLast), inputColName, outputColName) + dataset, $(dropLast), Seq(inputColName), Seq(outputColName))(0) } else { outputAttrGroupFromSchema } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 017bc6be9aeb8..9b26888e727e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -127,19 +127,29 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") override def fit(dataset: Dataset[_]): OneHotEncoderModel = { val transformedSchema = transformSchema(dataset.schema) + val categorySizes = new Array[Int]($(outputCols).length) - val categorySizes = $(outputCols).zipWithIndex.map { case (outputColName, idx) => - val outputAttrGroupFromSchema = AttributeGroup.fromStructField( - transformedSchema(outputColName)) - - val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { - OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, $(dropLast), $(inputCols)(idx), outputColName) + val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val numOfAttrs = AttributeGroup.fromStructField( + transformedSchema(outputColName)).size + if (numOfAttrs < 0) { + Some(idx) } else { - outputAttrGroupFromSchema + categorySizes(idx) = numOfAttrs + None } + } - outputAttrGroup.size + // Some input columns don't have attributes or their attributes don't have necessary info. + // We need to scan the data to get the number of values for each column. + if (columnToScanIndices.length > 0) { + val inputColNames = columnToScanIndices.map($(inputCols)(_)) + val outputColNames = columnToScanIndices.map($(outputCols)(_)) + val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, $(dropLast), inputColNames, outputColNames) + attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => + categorySizes(idx) = attrGroup.size + } } val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) @@ -408,23 +418,38 @@ private[feature] object OneHotEncoderCommon { def getOutputAttrGroupFromData( dataset: Dataset[_], dropLast: Boolean, - inputColName: String, - outputColName: String): AttributeGroup = { - val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) - .treeAggregate(0.0)( - (m, x) => { + inputColNames: Seq[String], + outputColNames: Seq[String]): Seq[AttributeGroup] = { + // The RDD approach has advantage of early-stop if any values are invalid. It seems that + // DataFrame ops don't have equivalent functions. 
+ val columns = inputColNames.map { inputColName => + col(inputColName).cast(DoubleType) + } + val numOfColumns = columns.length + + val numAttrsArray = dataset.select(columns: _*).rdd.map { row => + val array = new Array[Double](numOfColumns) + (0 until numOfColumns).foreach(idx => array(idx) = row.getDouble(idx)) + array + }.treeAggregate(new Array[Double](numOfColumns))( + (maxValues, curValues) => { + (0 until numOfColumns).map { idx => + val x = curValues(idx) assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") assert(x >= 0.0 && x == x.toInt, - s"Values from column $inputColName must be indices, but got $x.") - math.max(m, x) - }, - (m0, m1) => { - math.max(m0, m1) - } - ).toInt + 1 + s"Values from column ${inputColNames(idx)} must be indices, but got $x.") + math.max(maxValues(idx), x) + }.toArray + }, + (m0, m1) => { + (0 until numOfColumns).map(idx => math.max(m0(idx), m1(idx))).toArray + } + ).map(_.toInt + 1) - createAttrGroupForAttrNames(outputColName, dropLast, numAttrs) + outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => + createAttrGroupForAttrNames(outputColName, dropLast, numAttrs) + } } /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */ From b42d175ddc4928ec36718177702059ccf0bfbfea Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Oct 2017 00:32:09 +0000 Subject: [PATCH 03/13] Remove unused import. --- .../org/apache/spark/ml/feature/OneHotEncoderEstimator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 9b26888e727e8..498989b95914c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -21,11 +21,11 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.Since -import org.apache.spark.ml.{Estimator, Model, Transformer} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction From 66d46acaa58ca0e6304878504e94117bbee59d24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Oct 2017 10:20:29 +0000 Subject: [PATCH 04/13] Rename "skip" to "keep". Reduce encoder array to one encoder. Use withColumns. 
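With `handleInvalid = "keep"`, an unseen category index now encodes to an all-zero vector rather than erroring out, and the per-column array of UDF closures collapses into a single UDF that takes the category size as a second, literal argument. A simplified sketch of that encoder (not the exact production code; a guard for negative labels is added here for safety):

```scala
import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.{col, lit, udf}

val dropLast = false
val keepInvalid = true  // handleInvalid == "keep"

// One UDF for every column; the category size arrives as a literal argument
// instead of being captured by a per-column closure.
val encode = udf { (label: Double, size: Int) =>
  if (label >= 0 && label < size) {
    Vectors.sparse(size, Array(label.toInt), Array(1.0))         // one-hot
  } else if (label == size && dropLast) {
    Vectors.sparse(size, Array.empty[Int], Array.empty[Double])  // dropped last category
  } else if (keepInvalid) {
    Vectors.sparse(size, Array.empty[Int], Array.empty[Double])  // unseen -> all zeros
  } else {
    throw new SparkException(s"Unseen value: $label.")
  }
}

// e.g. df.withColumn("aVec", encode(col("a").cast("double"), lit(categorySizes(0))))
```

Because the size travels as a `lit(...)` column expression, one serialized function serves every input/output column pair.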
--- .../ml/feature/OneHotEncoderEstimator.scala | 89 +++++++++---------- .../feature/OneHotEncoderEstimatorSuite.scala | 10 +-- 2 files changed, 49 insertions(+), 50 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 498989b95914c..fe5c7011a93cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutp import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} /** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */ @@ -38,14 +38,14 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid /** * Param for how to handle invalid data. - * Options are 'skip' (filter out rows with invalid data) or 'error' (throw an error). + * Options are 'keep' (invalid data are ignored) or 'error' (throw an error). * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data " + - "Options are 'skip' (filter out rows with invalid data) or error (throw an error).", + "Options are 'keep' (invalid data are ignored) or error (throw an error).", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -107,17 +107,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: val outputColNames = $(outputCols) val inputFields = schema.fields - require(inputColNames.length == outputColNames.length, - s"The number of input columns ${inputColNames.length} must be the same as the number of " + - s"output columns ${outputColNames.length}.") + OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => - - require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - require(!inputFields.exists(_.name == outputColName), - s"Output column $outputColName already exists.") - OneHotEncoderCommon.transformOutputColumnSchema( schema(inputColName), $(dropLast), outputColName) } @@ -163,12 +155,31 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { - private[feature] val SKIP_INVALID: String = "skip" + private[feature] val KEEP_INVALID: String = "keep" private[feature] val ERROR_INVALID: String = "error" - private[feature] val supportedHandleInvalids: Array[String] = Array(SKIP_INVALID, ERROR_INVALID) + private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) @Since("2.3.0") override def load(path: String): OneHotEncoderEstimator = super.load(path) + + private[feature] def checkParamsValidity( + inputColNames: Seq[String], + outputColNames: Seq[String], + schema: StructType): Unit = 
{ + + val inputFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + } + } } @Since("2.3.0") @@ -179,26 +190,24 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - private def encoders: Array[UserDefinedFunction] = { + private def encoder: UserDefinedFunction = { val oneValue = Array(1.0) val emptyValues = Array.empty[Double] val emptyIndices = Array.empty[Int] val dropLast = getDropLast val handleInvalid = getHandleInvalid - categorySizes.map { size => - udf { label: Double => - if (label < size) { - Vectors.sparse(size, Array(label.toInt), oneValue) - } else if (label == size && dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + udf { (label: Double, size: Int) => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) + } else if (label == size && dropLast) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else { + if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") } else { - if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) { - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.SKIP_INVALID}.") - } else { - Vectors.sparse(size, emptyIndices, emptyValues) - } + Vectors.sparse(size, emptyIndices, emptyValues) } } } @@ -226,9 +235,7 @@ class OneHotEncoderModel private[ml] ( val outputColNames = $(outputCols) val inputFields = schema.fields - require(inputColNames.length == outputColNames.length, - s"The number of input columns ${inputColNames.length} must be the same as the number of " + - s"output columns ${outputColNames.length}.") + OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -236,12 +243,6 @@ class OneHotEncoderModel private[ml] ( val inputOutputPairs = inputColNames.zip(outputColNames) val outputFields = inputOutputPairs.map { case (inputColName, outputColName) => - - require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - require(!inputFields.exists(_.name == outputColName), - s"Output column $outputColName already exists.") - OneHotEncoderCommon.transformOutputColumnSchema( schema(inputColName), $(dropLast), outputColName) } @@ -266,15 +267,15 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.SKIP_INVALID) { + if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID) { throw new IllegalArgumentException("When Param handleInvalid is set to " + - s"${OneHotEncoderEstimator.SKIP_INVALID}, Param dropLast can't be true, " + + 
s"${OneHotEncoderEstimator.KEEP_INVALID}, Param dropLast can't be true, " + "because last category and invalid values will conflict in encoded vector.") } val transformedSchema = transformSchema(dataset.schema, logging = true) - val encodedColumns = encoders.zipWithIndex.map { case (encoder, idx) => + val encodedColumns = (0 until $(inputCols).length).map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) @@ -288,10 +289,10 @@ class OneHotEncoderModel private[ml] ( outputAttrGroupFromSchema.toMetadata() } - encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata) + encoder(col(inputColName).cast(DoubleType), lit(categorySizes(idx))) + .as(outputColName, metadata) } - val allCols = Seq(col("*")) ++ encodedColumns - dataset.select(allCols: _*) + dataset.withColumns($(outputCols), encodedColumns) } @Since("2.3.0") @@ -428,9 +429,7 @@ private[feature] object OneHotEncoderCommon { val numOfColumns = columns.length val numAttrsArray = dataset.select(columns: _*).rdd.map { row => - val array = new Array[Double](numOfColumns) - (0 until numOfColumns).foreach(idx => array(idx) = row.getDouble(idx)) - array + (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray }.treeAggregate(new Array[Double](numOfColumns))( (maxValues, curValues) => { (0 until numOfColumns).map { idx => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 621a67c744912..01490cfc0a86a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -250,7 +250,7 @@ class OneHotEncoderEstimatorSuite err.getMessage.contains("Unseen value: 3.0. To handle unseen values") } - test("Skip on invalid values") { + test("Keep on invalid values") { val trainingData = Seq((0, 0), (1, 1)) val trainingDF = trainingData.toDF("id", "a") val testData = Seq((0, 0), (1, 2)) @@ -259,7 +259,7 @@ class OneHotEncoderEstimatorSuite val encoder = new OneHotEncoderEstimator() .setInputCols(Array("a")) .setOutputCols(Array("encoded")) - .setHandleInvalid("skip") + .setHandleInvalid("keep") .setDropLast(false) val model = encoder.fit(trainingDF) @@ -273,7 +273,7 @@ class OneHotEncoderEstimatorSuite assert(output === expected) } - test("Can't set dropLast as true and skip on invalid values") { + test("Can't set dropLast as true and keep on invalid values") { val trainingData = Seq((0, 0), (1, 1)) val trainingDF = trainingData.toDF("id", "a") val testData = Seq((0, 0), (1, 2)) @@ -282,12 +282,12 @@ class OneHotEncoderEstimatorSuite val encoder = new OneHotEncoderEstimator() .setInputCols(Array("a")) .setOutputCols(Array("encoded")) - .setHandleInvalid("skip") + .setHandleInvalid("keep") val model = encoder.fit(trainingDF) val err = intercept[IllegalArgumentException] { model.transform(testDF) } - err.getMessage.contains("When Param handleInvalid is set to skip, Param dropLast can't be true") + err.getMessage.contains("When Param handleInvalid is set to keep, Param dropLast can't be true") } } From a9e9262c2a05174f019cddb8a1ae14d48a92ffba Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Oct 2017 12:51:04 +0000 Subject: [PATCH 05/13] Extract common method for preparing output fields. 
--- .../ml/feature/OneHotEncoderEstimator.scala | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index fe5c7011a93cc..70861815d7381 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -105,15 +105,12 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val inputFields = schema.fields OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) - val outputFields = inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => - OneHotEncoderCommon.transformOutputColumnSchema( - schema(inputColName), $(dropLast), outputColName) - } - StructType(inputFields ++ outputFields) + val outputFields = OneHotEncoderEstimator.prepareOutputFields( + inputColNames.map(schema(_)), outputColNames, $(dropLast)) + StructType(schema.fields ++ outputFields) } @Since("2.3.0") @@ -180,6 +177,16 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat s"Output column $outputColName already exists.") } } + + private[feature] def prepareOutputFields( + inputCols: Seq[StructField], + outputColNames: Seq[String], + dropLast: Boolean): Seq[StructField] = { + inputCols.zip(outputColNames).map { case (inputCol, outputColName) => + OneHotEncoderCommon.transformOutputColumnSchema( + inputCol, dropLast, outputColName) + } + } } @Since("2.3.0") @@ -233,7 +240,6 @@ class OneHotEncoderModel private[ml] ( override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val inputFields = schema.fields OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) @@ -241,12 +247,9 @@ class OneHotEncoderModel private[ml] ( s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"features ${categorySizes.length} during fitting.") - val inputOutputPairs = inputColNames.zip(outputColNames) - val outputFields = inputOutputPairs.map { case (inputColName, outputColName) => - OneHotEncoderCommon.transformOutputColumnSchema( - schema(inputColName), $(dropLast), outputColName) - } - verifyNumOfValues(StructType(inputFields ++ outputFields)) + val outputFields = OneHotEncoderEstimator.prepareOutputFields( + inputColNames.map(schema(_)), outputColNames, $(dropLast)) + verifyNumOfValues(StructType(schema.fields ++ outputFields)) } private def verifyNumOfValues(schema: StructType): StructType = { From e0241200c58a5ec201a0f1abdebc1660878ed49f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 20 Oct 2017 13:11:30 +0000 Subject: [PATCH 06/13] Move common methods to reduce method parameters. 
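With parameter validation and schema preparation moved onto the shared `OneHotEncoderBase` trait, the helper methods need fewer parameters and both classes stay consistent by construction. For context, minimal end-to-end usage of the estimator added in this series (hypothetical data):

```scala
import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

val df = Seq((0.0, 1.0), (1.0, 0.0), (2.0, 1.0)).toDF("catA", "catB")

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("catA", "catB"))
  .setOutputCols(Array("catAVec", "catBVec"))
  .setDropLast(false)

val model = encoder.fit(df)      // one pass to determine category sizes (3 and 2)
model.transform(df).show(false)  // appends sparse one-hot vector columns
```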
--- .../ml/feature/OneHotEncoderEstimator.scala | 88 +++++++++---------- 1 file changed, 43 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 70861815d7381..b8309feec21d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} -/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */ -private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid +/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** @@ -62,6 +62,35 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid /** @group getParam */ @Since("2.3.0") def getDropLast: Boolean = $(dropLast) + + protected def checkParamsValidity(schema: StructType): Unit = { + val inputColNames = $(inputCols) + val outputColNames = $(outputCols) + val existingFields = schema.fields + + require(inputColNames.length == outputColNames.length, + s"The number of input columns ${inputColNames.length} must be the same as the number of " + + s"output columns ${outputColNames.length}.") + + inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") + require(!existingFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + } + } + + /** Prepares output columns with proper attributes by examining input columns. 
*/ + protected def prepareSchemaWithOutputField(schema: StructType): StructType = { + val inputFields = $(inputCols).map(schema(_)) + val outputColNames = $(outputCols) + + val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => + OneHotEncoderCommon.transformOutputColumnSchema( + inputField, $(dropLast), outputColName) + } + StructType(schema.fields ++ outputFields) + } } /** @@ -80,7 +109,7 @@ private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid */ @Since("2.3.0") class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) - extends Estimator[OneHotEncoderModel] with OneHotEncoderParams with DefaultParamsWritable { + extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { @Since("2.3.0") def this() = this(Identifiable.randomUID("oneHotEncoder")) @@ -103,14 +132,8 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { - val inputColNames = $(inputCols) - val outputColNames = $(outputCols) - - OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) - - val outputFields = OneHotEncoderEstimator.prepareOutputFields( - inputColNames.map(schema(_)), outputColNames, $(dropLast)) - StructType(schema.fields ++ outputFields) + checkParamsValidity(schema) + prepareSchemaWithOutputField(schema) } @Since("2.3.0") @@ -158,42 +181,13 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat @Since("2.3.0") override def load(path: String): OneHotEncoderEstimator = super.load(path) - - private[feature] def checkParamsValidity( - inputColNames: Seq[String], - outputColNames: Seq[String], - schema: StructType): Unit = { - - val inputFields = schema.fields - - require(inputColNames.length == outputColNames.length, - s"The number of input columns ${inputColNames.length} must be the same as the number of " + - s"output columns ${outputColNames.length}.") - - inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => - require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - require(!inputFields.exists(_.name == outputColName), - s"Output column $outputColName already exists.") - } - } - - private[feature] def prepareOutputFields( - inputCols: Seq[StructField], - outputColNames: Seq[String], - dropLast: Boolean): Seq[StructField] = { - inputCols.zip(outputColNames).map { case (inputCol, outputColName) => - OneHotEncoderCommon.transformOutputColumnSchema( - inputCol, dropLast, outputColName) - } - } } @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @Since("2.3.0") val categorySizes: Array[Int]) - extends Model[OneHotEncoderModel] with OneHotEncoderParams with MLWritable { + extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { import OneHotEncoderModel._ @@ -241,17 +235,21 @@ class OneHotEncoderModel private[ml] ( val inputColNames = $(inputCols) val outputColNames = $(outputCols) - OneHotEncoderEstimator.checkParamsValidity(inputColNames, outputColNames, schema) + checkParamsValidity(schema) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"features ${categorySizes.length} during fitting.") - val outputFields = OneHotEncoderEstimator.prepareOutputFields( 
- inputColNames.map(schema(_)), outputColNames, $(dropLast)) - verifyNumOfValues(StructType(schema.fields ++ outputFields)) + val transformedSchema = prepareSchemaWithOutputField(schema) + verifyNumOfValues(transformedSchema) } + /** + * If the metadata of input columns also specifies the number of categories, we need to + * compare with expected category number obtained during fitting. Mismatched numbers will + * cause exception. + */ private def verifyNumOfValues(schema: StructType): StructType = { $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) From adc410770528c6c95a3c35de64548362c1b46643 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 24 Oct 2017 00:31:14 +0000 Subject: [PATCH 07/13] Address comments. --- .../org/apache/spark/ml/feature/OneHotEncoderEstimator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index b8309feec21d3..6e9c23edbcef4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -38,14 +38,14 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid /** * Param for how to handle invalid data. - * Options are 'keep' (invalid data are ignored) or 'error' (throw an error). + * Options are 'keep' (invalid data produces a vector of zeros) or 'error' (throw an error). * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data " + - "Options are 'keep' (invalid data are ignored) or error (throw an error).", + "Options are 'keep' (invalid data produces a vector of zeros) or error (throw an error).", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) From ae2ac82b10e457b8beede9dc4a33ce0a578f007d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 26 Oct 2017 01:43:02 +0000 Subject: [PATCH 08/13] Remove unused method parameter. 
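
The removed `outputColName` parameter was never read: the output attribute names are
derived entirely from the input column's ML metadata. A minimal sketch of that
derivation (the nominal column name `size` and its values are hypothetical):

    import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

    // A nominal input column carrying its category values in ML metadata.
    val field = NominalAttribute.defaultAttr
      .withName("size")
      .withValues("small", "medium", "large")
      .toStructField()

    // Attribute names for the encoded output come from the input field alone.
    val names = Attribute.fromStructField(field) match {
      case nominal: NominalAttribute => nominal.values
      case _ => None
    }
    // names contains Array("small", "medium", "large")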
--- .../apache/spark/ml/feature/OneHotEncoderEstimator.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 6e9c23edbcef4..b46323d3af8a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -351,9 +351,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { */ private[feature] object OneHotEncoderCommon { - private def genOutputAttrNames( - inputCol: StructField, - outputColName: String): Option[Array[String]] = { + private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { val inputAttr = Attribute.fromStructField(inputCol) inputAttr match { case nominal: NominalAttribute => @@ -399,7 +397,7 @@ private[feature] object OneHotEncoderCommon { inputCol: StructField, dropLast: Boolean, outputColName: String): StructField = { - val outputAttrNames = genOutputAttrNames(inputCol, outputColName) + val outputAttrNames = genOutputAttrNames(inputCol) val filteredOutputAttrNames = outputAttrNames.map { names => if (dropLast) { require(names.length > 1, From 4c6cc57136a60c577dd0ba2ae90521f7c18ae5d1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 31 Oct 2017 13:37:08 +0000 Subject: [PATCH 09/13] Address comments. --- .../apache/spark/ml/feature/OneHotEncoder.scala | 4 ++++ .../spark/ml/feature/OneHotEncoderEstimator.scala | 14 ++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 23fc209bcc542..9137d5fbb10bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The output vectors are sparse. * * @see `StringIndexer` for converting categorical values into category indices + * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder` + * will be removed in 3.0.0. 
*/ @Since("1.4.0") +@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" + + " will be removed in 3.0.0.", "2.3.0") class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index b46323d3af8a4..e415071388265 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -63,7 +63,7 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid @Since("2.3.0") def getDropLast: Boolean = $(dropLast) - protected def checkParamsValidity(schema: StructType): Unit = { + protected def validateAndTransformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) val existingFields = schema.fields @@ -78,12 +78,9 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid require(!existingFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") } - } - /** Prepares output columns with proper attributes by examining input columns. */ - protected def prepareSchemaWithOutputField(schema: StructType): StructType = { + // Prepares output columns with proper attributes by examining input columns. val inputFields = $(inputCols).map(schema(_)) - val outputColNames = $(outputCols) val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => OneHotEncoderCommon.transformOutputColumnSchema( @@ -132,8 +129,7 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { - checkParamsValidity(schema) - prepareSchemaWithOutputField(schema) + validateAndTransformSchema(schema) } @Since("2.3.0") @@ -235,13 +231,11 @@ class OneHotEncoderModel private[ml] ( val inputColNames = $(inputCols) val outputColNames = $(outputCols) - checkParamsValidity(schema) - require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"features ${categorySizes.length} during fitting.") - val transformedSchema = prepareSchemaWithOutputField(schema) + val transformedSchema = validateAndTransformSchema(schema) verifyNumOfValues(transformedSchema) } From 32318faebd118509bdd0c0100e84c4755182ea27 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 10 Dec 2017 01:34:06 +0000 Subject: [PATCH 10/13] Address comment. --- .../ml/feature/OneHotEncoderEstimator.scala | 71 ++++++++++++------- .../feature/OneHotEncoderEstimatorSuite.scala | 50 +++++-------- 2 files changed, 66 insertions(+), 55 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index e415071388265..b915fbb2f49e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -38,14 +38,16 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid /** * Param for how to handle invalid data. 
- * Options are 'keep' (invalid data produces a vector of zeros) or 'error' (throw an error). + * Options are 'keep' (invalid data presented as an extra categorical feature) or + * 'error' (throw an error). * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data " + - "Options are 'keep' (invalid data produces a vector of zeros) or error (throw an error).", + "Options are 'keep' (invalid data presented as an extra categorical feature) " + + "or error (throw an error).", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -81,10 +83,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid // Prepares output columns with proper attributes by examining input columns. val inputFields = $(inputCols).map(schema(_)) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => OneHotEncoderCommon.transformOutputColumnSchema( - inputField, $(dropLast), outputColName) + inputField, $(dropLast), outputColName, keepInvalid) } StructType(schema.fields ++ outputFields) } @@ -102,6 +105,10 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * + * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is + * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros + * vector. + * * @see `StringIndexer` for converting categorical values into category indices */ @Since("2.3.0") @@ -153,8 +160,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: if (columnToScanIndices.length > 0) { val inputColNames = columnToScanIndices.map($(inputCols)(_)) val outputColNames = columnToScanIndices.map($(outputCols)(_)) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, $(dropLast), inputColNames, outputColNames) + dataset, $(dropLast), inputColNames, outputColNames, keepInvalid) attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => categorySizes(idx) = attrGroup.size } @@ -193,19 +201,29 @@ class OneHotEncoderModel private[ml] ( val emptyIndices = Array.empty[Int] val dropLast = getDropLast val handleInvalid = getHandleInvalid + val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID udf { (label: Double, size: Int) => - if (label < size) { + val numCategory = if (!dropLast && keepInvalid) { + // When `handleInvalid` is 'keep' and `dropLast` is false, the last category is + // for invalid data. + size - 1 + } else { + size + } + + if (label < numCategory) { Vectors.sparse(size, Array(label.toInt), oneValue) - } else if (label == size && dropLast) { + } else if (label == numCategory && dropLast && !keepInvalid) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else if (dropLast && keepInvalid) { Vectors.sparse(size, emptyIndices, emptyValues) + } else if (keepInvalid) { + Vectors.sparse(size, Array(size - 1), oneValue) } else { - if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) { - throw new SparkException(s"Unseen value: $label. 
To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") - } else { - Vectors.sparse(size, emptyIndices, emptyValues) - } + assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") } } } @@ -262,12 +280,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { - if (getDropLast && getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID) { - throw new IllegalArgumentException("When Param handleInvalid is set to " + - s"${OneHotEncoderEstimator.KEEP_INVALID}, Param dropLast can't be true, " + - "because last category and invalid values will conflict in encoded vector.") - } - val transformedSchema = transformSchema(dataset.schema, logging = true) val encodedColumns = (0 until $(inputCols).length).map { idx => @@ -390,13 +402,16 @@ private[feature] object OneHotEncoderCommon { def transformOutputColumnSchema( inputCol: StructField, dropLast: Boolean, - outputColName: String): StructField = { + outputColName: String, + keepInvalid: Boolean = false): StructField = { val outputAttrNames = genOutputAttrNames(inputCol) val filteredOutputAttrNames = outputAttrNames.map { names => - if (dropLast) { + if (dropLast && !keepInvalid) { require(names.length > 1, s"The input column ${inputCol.name} should have at least two distinct values.") names.dropRight(1) + } else if (!dropLast && keepInvalid) { + names ++ Seq("invalidValues") } else { names } @@ -413,7 +428,8 @@ private[feature] object OneHotEncoderCommon { dataset: Dataset[_], dropLast: Boolean, inputColNames: Seq[String], - outputColNames: Seq[String]): Seq[AttributeGroup] = { + outputColNames: Seq[String], + handleInvalid: Boolean = false): Seq[AttributeGroup] = { // The RDD approach has advantage of early-stop if any values are invalid. It seems that // DataFrame ops don't have equivalent functions. 
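    // Concretely, the asserts inside the aggregation below run on every element as it
    // is seen, so a malformed value such as -1.0 or 1.5 fails the scan immediately
    // rather than after a full pass over the data.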
val columns = inputColNames.map { inputColName => @@ -440,7 +456,7 @@ private[feature] object OneHotEncoderCommon { ).map(_.toInt + 1) outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => - createAttrGroupForAttrNames(outputColName, dropLast, numAttrs) + createAttrGroupForAttrNames(outputColName, dropLast, numAttrs, handleInvalid) } } @@ -448,9 +464,16 @@ private[feature] object OneHotEncoderCommon { def createAttrGroupForAttrNames( outputColName: String, dropLast: Boolean, - numAttrs: Int): AttributeGroup = { + numAttrs: Int, + keepInvalid: Boolean = false): AttributeGroup = { val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (dropLast) outputAttrNames.dropRight(1) else outputAttrNames + val filtered = if (dropLast && !keepInvalid) { + outputAttrNames.dropRight(1) + } else if (!dropLast && keepInvalid) { + outputAttrNames ++ Seq("invalidValues") + } else { + outputAttrNames + } genOutputAttrGroup(Some(filtered), outputColName) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 01490cfc0a86a..56be70c9941b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -251,43 +251,31 @@ class OneHotEncoderEstimatorSuite } test("Keep on invalid values") { - val trainingData = Seq((0, 0), (1, 1)) + val trainingData = Seq((0, 0), (1, 1), (2, 2)) val trainingDF = trainingData.toDF("id", "a") - val testData = Seq((0, 0), (1, 2)) + val testData = Seq((0, 0), (1, 1), (2, 3)) val testDF = testData.toDF("id", "a") - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("a")) - .setOutputCols(Array("encoded")) - .setHandleInvalid("keep") - .setDropLast(false) - - val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) + val dropLasts = Seq(false, true) + val expectedOutput = Seq( + Set((0, Seq(1.0, 0.0, 0.0, 0.0)), (1, Seq(0.0, 1.0, 0.0, 0.0)), (2, Seq(0.0, 0.0, 0.0, 1.0))), + Set((0, Seq(1.0, 0.0, 0.0)), (1, Seq(0.0, 1.0, 0.0)), (2, Seq(0.0, 0.0, 0.0)))) - val output = encoded.select("id", "encoded").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0)) - assert(output === expected) - } - - test("Can't set dropLast as true and keep on invalid values") { - val trainingData = Seq((0, 0), (1, 1)) - val trainingDF = trainingData.toDF("id", "a") - val testData = Seq((0, 0), (1, 2)) - val testDF = testData.toDF("id", "a") + dropLasts.zipWithIndex.foreach { case (dropLast, idx) => + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) + .setHandleInvalid("keep") + .setDropLast(dropLast) - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("a")) - .setOutputCols(Array("encoded")) - .setHandleInvalid("keep") + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) - val model = encoder.fit(trainingDF) - val err = intercept[IllegalArgumentException] { - model.transform(testDF) + val output = encoded.select("id", "encoded").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec.toArray.toSeq) + }.collect().toSet + assert(output === expectedOutput(idx)) } - err.getMessage.contains("When Param handleInvalid is set to keep, Param dropLast can't be true") } } From 
144f07d5e92bf5cbc10cb2dc990fc32f15405977 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2017 01:51:23 +0000 Subject: [PATCH 11/13] Address comments. --- .../spark/ml/feature/OneHotEncoder.scala | 4 +- .../ml/feature/OneHotEncoderEstimator.scala | 116 +++-- .../feature/OneHotEncoderEstimatorSuite.scala | 414 ++++++++++++------ 3 files changed, 348 insertions(+), 186 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 9137d5fbb10bf..5ab6c2dde667a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -90,7 +90,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e s"Output column $outputColName already exists.") val outputField = OneHotEncoderCommon.transformOutputColumnSchema( - schema(inputColName), $(dropLast), outputColName) + schema(inputColName), outputColName, $(dropLast)) val outputFields = inputFields :+ outputField StructType(outputFields) } @@ -106,7 +106,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, $(dropLast), Seq(inputColName), Seq(outputColName))(0) + dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0) } else { outputAttrGroupFromSchema } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index b915fbb2f49e3..d489f0a12f96e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -65,7 +65,8 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid @Since("2.3.0") def getDropLast: Boolean = $(dropLast) - protected def validateAndTransformSchema(schema: StructType): StructType = { + protected def validateAndTransformSchema( + schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) val existingFields = schema.fields @@ -74,22 +75,19 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"output columns ${outputColNames.length}.") - inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => - require(schema(inputColName).dataType.isInstanceOf[NumericType], - s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - require(!existingFields.exists(_.name == outputColName), - s"Output column $outputColName already exists.") - } + // Input columns must be NumericType. + inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) // Prepares output columns with proper attributes by examining input columns. 
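    // "Proper attributes" here means an AttributeGroup holding one BinaryAttribute per
    // encoded category, so downstream stages can read the output vector's size and
    // slot names from the schema alone, without touching the data.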
val inputFields = $(inputCols).map(schema(_)) - val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => OneHotEncoderCommon.transformOutputColumnSchema( - inputField, $(dropLast), outputColName, keepInvalid) + inputField, outputColName, dropLast, keepInvalid) + } + outputFields.foldLeft(schema) { case (newSchema, outputField) => + SchemaUtils.appendColumn(newSchema, outputField) } - StructType(schema.fields ++ outputFields) } } @@ -109,6 +107,9 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid * added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros * vector. * + * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols + * come in pairs, specified by the order in the arrays, and each pair is treated independently. + * * @see `StringIndexer` for converting categorical values into category indices */ @Since("2.3.0") @@ -136,7 +137,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + // When fitting data, we want the the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + validateAndTransformSchema(schema, dropLast = false, keepInvalid = false) } @Since("2.3.0") @@ -160,9 +163,11 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: if (columnToScanIndices.length > 0) { val inputColNames = columnToScanIndices.map($(inputCols)(_)) val outputColNames = columnToScanIndices.map($(outputCols)(_)) - val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + + // When fitting data, we want the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( - dataset, $(dropLast), inputColNames, outputColNames, keepInvalid) + dataset, inputColNames, outputColNames, dropLast = false) attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => categorySizes(idx) = attrGroup.size } @@ -195,6 +200,26 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ + // The actual number of categories varies due to different setting of `dropLast` and + // `handleInvalid`. + private def configedCategorySizes: Array[Int] = { + val dropLast = getDropLast + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + + if (!dropLast && keepInvalid) { + // When `handleInvalid` is "keep", an extra category is added as last category + // for invalid data. + categorySizes.map(_ + 1) + } else if (dropLast && !keepInvalid) { + // When `dropLast` is true, the last category is removed. + categorySizes.map(_ - 1) + } else { + // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid + // data is removed. Thus, it is the same as the plain number of categories. + categorySizes + } + } + private def encoder: UserDefinedFunction = { val oneValue = Array(1.0) val emptyValues = Array.empty[Double] @@ -205,21 +230,29 @@ class OneHotEncoderModel private[ml] ( udf { (label: Double, size: Int) => val numCategory = if (!dropLast && keepInvalid) { - // When `handleInvalid` is 'keep' and `dropLast` is false, the last category is + // When `dropLast` is false and `handleInvalid` is "keep", the last category is // for invalid data. 
size - 1 } else { size } - if (label < numCategory) { + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative.") + } else if (label < numCategory) { Vectors.sparse(size, Array(label.toInt), oneValue) } else if (label == numCategory && dropLast && !keepInvalid) { + // When `dropLast` is true and `handleInvalid` is not "keep", + // the last category is removed. Vectors.sparse(size, emptyIndices, emptyValues) } else if (dropLast && keepInvalid) { + // When `dropLast` is true and `handleInvalid` is "keep", + // invalid data is encoded to the removed last category. Vectors.sparse(size, emptyIndices, emptyValues) } else if (keepInvalid) { - Vectors.sparse(size, Array(size - 1), oneValue) + // When `dropLast` is false and `handleInvalid` is "keep", + // invalid data is encoded to the last category. + Vectors.sparse(size, Array(numCategory), oneValue) } else { assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) throw new SparkException(s"Unseen value: $label. To handle unseen values, " + @@ -253,26 +286,29 @@ class OneHotEncoderModel private[ml] ( s"The number of input columns ${inputColNames.length} must be the same as the number of " + s"features ${categorySizes.length} during fitting.") - val transformedSchema = validateAndTransformSchema(schema) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) verifyNumOfValues(transformedSchema) } /** * If the metadata of input columns also specifies the number of categories, we need to - * compare with expected category number obtained during fitting. Mismatched numbers will - * cause exception. + * compare with expected category number with `handleInvalid` and `dropLast` taken into + * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) - // If the input metadata specifies number of category, - // compare with expected category number. + // If the input metadata specifies number of category for output column, + // comparing with expected category number with `handleInvalid` and + // `dropLast` taken into account. 
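      // For example, with 3 categories seen at fit time, dropLast = true and
      // handleInvalid = "error" require metadata describing exactly 2 values, while
      // dropLast = false and handleInvalid = "keep" require metadata describing 4.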
if (attrGroup.attributes.nonEmpty) { - require(attrGroup.size == categorySizes(idx), "OneHotEncoderModel expected " + - s"${categorySizes(idx)} categorical values for input column ${inputColName}, but " + - s"the input column had metadata specifying ${attrGroup.size} values.") + require(attrGroup.size == configedCategorySizes(idx), "OneHotEncoderModel expected " + + s"${configedCategorySizes(idx)} categorical values for input column ${inputColName}, " + + s"but the input column had metadata specifying ${attrGroup.size} values.") } } schema @@ -281,6 +317,7 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID val encodedColumns = (0 until $(inputCols).length).map { idx => val inputColName = $(inputCols)(idx) @@ -290,13 +327,13 @@ class OneHotEncoderModel private[ml] ( AttributeGroup.fromStructField(transformedSchema(outputColName)) val metadata = if (outputAttrGroupFromSchema.size < 0) { - OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, false, - categorySizes(idx)).toMetadata() + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, + categorySizes(idx), $(dropLast), keepInvalid).toMetadata() } else { outputAttrGroupFromSchema.toMetadata() } - encoder(col(inputColName).cast(DoubleType), lit(categorySizes(idx))) + encoder(col(inputColName).cast(DoubleType), lit(configedCategorySizes(idx))) .as(outputColName, metadata) } dataset.withColumns($(outputCols), encodedColumns) @@ -376,7 +413,7 @@ private[feature] object OneHotEncoderCommon { } case _: NumericAttribute => throw new RuntimeException( - s"The input column ${inputCol.name} cannot be numeric.") + s"The input column ${inputCol.name} cannot be continuous-value.") case _ => None // optimistic about unknown attributes } @@ -401,8 +438,8 @@ private[feature] object OneHotEncoderCommon { */ def transformOutputColumnSchema( inputCol: StructField, - dropLast: Boolean, outputColName: String, + dropLast: Boolean, keepInvalid: Boolean = false): StructField = { val outputAttrNames = genOutputAttrNames(inputCol) val filteredOutputAttrNames = outputAttrNames.map { names => @@ -426,10 +463,9 @@ private[feature] object OneHotEncoderCommon { */ def getOutputAttrGroupFromData( dataset: Dataset[_], - dropLast: Boolean, inputColNames: Seq[String], outputColNames: Seq[String], - handleInvalid: Boolean = false): Seq[AttributeGroup] = { + dropLast: Boolean): Seq[AttributeGroup] = { // The RDD approach has advantage of early-stop if any values are invalid. It seems that // DataFrame ops don't have equivalent functions. 
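    // Each column is reduced to its maximum seen index, and the per-column category
    // count is that maximum plus one: e.g. values 0.0, 2.0, 1.0 give max 2.0 and
    // hence 3 categories.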
val columns = inputColNames.map { inputColName => @@ -441,31 +477,35 @@ private[feature] object OneHotEncoderCommon { (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray }.treeAggregate(new Array[Double](numOfColumns))( (maxValues, curValues) => { - (0 until numOfColumns).map { idx => + (0 until numOfColumns).foreach { idx => val x = curValues(idx) assert(x <= Int.MaxValue, s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") assert(x >= 0.0 && x == x.toInt, s"Values from column ${inputColNames(idx)} must be indices, but got $x.") - math.max(maxValues(idx), x) - }.toArray + maxValues(idx) = math.max(maxValues(idx), x) + } + maxValues }, (m0, m1) => { - (0 until numOfColumns).map(idx => math.max(m0(idx), m1(idx))).toArray + (0 until numOfColumns).foreach { idx => + m0(idx) = math.max(m0(idx), m1(idx)) + } + m0 } ).map(_.toInt + 1) outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => - createAttrGroupForAttrNames(outputColName, dropLast, numAttrs, handleInvalid) + createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) } } /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */ def createAttrGroupForAttrNames( outputColName: String, - dropLast: Boolean, numAttrs: Int, - keepInvalid: Boolean = false): AttributeGroup = { + dropLast: Boolean, + keepInvalid: Boolean): AttributeGroup = { val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (dropLast && !keepInvalid) { outputAttrNames.dropRight(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 56be70c9941b5..9b9dc435b54f8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ @@ -32,72 +32,67 @@ class OneHotEncoderEstimatorSuite import testImplicits._ - def stringIndexed(): DataFrame = stringIndexedMultipleCols().select("id", "label", "labelIndex") - - def stringIndexedMultipleCols(): DataFrame = { - val data = Seq( - (0, "a", "A"), - (1, "b", "B"), - (2, "c", "D"), - (3, "a", "A"), - (4, "a", "B"), - (5, "c", "C")) - val df = data.toDF("id", "label", "label2") - val indexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .fit(df) - val df2 = indexer.transform(df) - val indexer2 = new StringIndexer() - .setInputCol("label2") - .setOutputCol("labelIndex2") - .fit(df2) - indexer2.transform(df2) - } - test("params") { ParamsSuite.checkParams(new OneHotEncoderEstimator) } test("OneHotEncoderEstimator dropLast = false") { - val transformed = stringIndexed() + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + 
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("labelIndex")) - .setOutputCols(Array("labelVec")) + .setInputCols(Array("input")) + .setOutputCols(Array("output")) assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(transformed) - val encoded = model.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } } test("OneHotEncoderEstimator dropLast = true") { - val transformed = stringIndexed() + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("labelIndex")) - .setOutputCols(Array("labelVec")) - - val model = encoder.fit(transformed) - val encoded = model.transform(transformed) - - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } } test("input column with ML attribute") { @@ -142,95 +137,109 @@ class OneHotEncoderEstimatorSuite } test("OneHotEncoderEstimator with varying types") { - val df = stringIndexed() + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val dfWithTypes = df - .withColumn("shortLabel", df("labelIndex").cast(ShortType)) - .withColumn("longLabel", df("labelIndex").cast(LongType)) - .withColumn("intLabel", df("labelIndex").cast(IntegerType)) - .withColumn("floatLabel", 
df("labelIndex").cast(FloatType)) - .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) - val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", - "floatLabel", "decimalLabel") + .withColumn("shortInput", df("input").cast(ShortType)) + .withColumn("longInput", df("input").cast(LongType)) + .withColumn("intInput", df("input").cast(IntegerType)) + .withColumn("floatInput", df("input").cast(FloatType)) + .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) + + val cols = Array("input", "shortInput", "longInput", "intInput", + "floatInput", "decimalInput") for (col <- cols) { val encoder = new OneHotEncoderEstimator() .setInputCols(Array(col)) - .setOutputCols(Array("labelVec")) + .setOutputCols(Array("output")) .setDropLast(false) + val model = encoder.fit(dfWithTypes) val encoded = model.transform(dfWithTypes) - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } } } test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") { - val transformed = stringIndexedMultipleCols() + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("labelIndex", "labelIndex2")) - .setOutputCols(Array("labelVec", "labelVec2")) + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) assert(encoder.getDropLast === true) encoder.setDropLast(false) assert(encoder.getDropLast === false) - val model = encoder.fit(transformed) - val encoded = model.transform(transformed) - - // Verify 1st column. - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), - (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) - assert(output === expected) - - // Verify 2nd column. 
- val output2 = encoded.select("id", "labelVec2").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2), vec(3)) - }.collect().toSet - // A -> 1, B -> 0, C -> 3, D -> 2 - val expected2 = Set((0, 0.0, 1.0, 0.0, 0.0), (1, 1.0, 0.0, 0.0, 0.0), (2, 0.0, 0.0, 1.0, 0.0), - (3, 0.0, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0, 0.0), (5, 0.0, 0.0, 0.0, 1.0)) - assert(output2 === expected2) + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } } test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") { - val transformed = stringIndexedMultipleCols() + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())), + Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("labelIndex", "labelIndex2")) - .setOutputCols(Array("labelVec", "labelVec2")) - - val model = encoder.fit(transformed) - val encoded = model.transform(transformed) - - // Verify 1st column. - val output = encoded.select("id", "labelVec").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1)) - }.collect().toSet - // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), - (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) - assert(output === expected) - - // Verify 2nd column. - val output2 = encoded.select("id", "labelVec2").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec(0), vec(1), vec(2)) - }.collect().toSet - // A -> 1, B -> 0, C -> 3, D -> 2 - val expected2 = Set((0, 0.0, 1.0, 0.0), (1, 1.0, 0.0, 0.0), (2, 0.0, 0.0, 1.0), - (3, 0.0, 1.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 0.0, 0.0)) - assert(output2 === expected2) + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3)) + }.collect().foreach { case (vec1, vec2, vec3, vec4) => + assert(vec1 === vec2) + assert(vec3 === vec4) + } } test("Throw error on invalid values") { @@ -250,32 +259,145 @@ class OneHotEncoderEstimatorSuite err.getMessage.contains("Unseen value: 3.0. 
To handle unseen values") } - test("Keep on invalid values") { - val trainingData = Seq((0, 0), (1, 1), (2, 2)) - val trainingDF = trainingData.toDF("id", "a") - val testData = Seq((0, 0), (1, 1), (2, 3)) - val testDF = testData.toDF("id", "a") + test("Can't transform on negative input") { + val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b") + val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b") - val dropLasts = Seq(false, true) - val expectedOutput = Seq( - Set((0, Seq(1.0, 0.0, 0.0, 0.0)), (1, Seq(0.0, 1.0, 0.0, 0.0)), (2, Seq(0.0, 0.0, 0.0, 1.0))), - Set((0, Seq(1.0, 0.0, 0.0)), (1, Seq(0.0, 1.0, 0.0)), (2, Seq(0.0, 0.0, 0.0)))) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("a")) + .setOutputCols(Array("encoded")) - dropLasts.zipWithIndex.foreach { case (dropLast, idx) => - val encoder = new OneHotEncoderEstimator() - .setInputCols(Array("a")) - .setOutputCols(Array("encoded")) - .setHandleInvalid("keep") - .setDropLast(dropLast) - - val model = encoder.fit(trainingDF) - val encoded = model.transform(testDF) - - val output = encoded.select("id", "encoded").rdd.map { r => - val vec = r.getAs[Vector](1) - (r.getInt(0), vec.toArray.toSeq) - }.collect().toSet - assert(output === expectedOutput(idx)) + val model = encoder.fit(trainingDF) + val err = intercept[SparkException] { + model.transform(testDF).collect() + } + err.getMessage.contains("Negative value: -1.0. Input can't be negative") + } + + test("Keep on invalid values: dropLast = false") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(false) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("Keep on invalid values: dropLast = true") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(3, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + .setHandleInvalid("keep") + .setDropLast(true) + + val model = encoder.fit(trainingDF) + val encoded = model.transform(testDF) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes dropLast") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(3, 
Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected1", new VectorUDT), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + + model.setDropLast(false) + val encoded1 = model.transform(df) + encoded1.select("output", "expected1").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) } + + model.setDropLast(true) + val encoded2 = model.transform(df) + encoded2.select("output", "expected2").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderModel changes handleInvalid") { + val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input") + + val testData = Seq( + Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(3.0, Vectors.sparse(4, Seq((3, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val testDF = spark.createDataFrame(sc.parallelize(testData), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(trainingDF) + model.setHandleInvalid("error") + + val err = intercept[SparkException] { + model.transform(testDF).show + } + err.getMessage.contains("Unseen value: 3.0. To handle unseen values") + + model.setHandleInvalid("keep") + model.transform(testDF).collect() } } From 587ad427a6682e98e1fefe592ecf278c674767f3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 26 Dec 2017 02:14:33 +0000 Subject: [PATCH 12/13] Add one more test. --- .../feature/OneHotEncoderEstimatorSuite.scala | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala index 9b9dc435b54f8..1d3f845586426 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -393,11 +393,29 @@ class OneHotEncoderEstimatorSuite model.setHandleInvalid("error") val err = intercept[SparkException] { - model.transform(testDF).show + model.transform(testDF).collect() } err.getMessage.contains("Unseen value: 3.0. 
To handle unseen values") model.setHandleInvalid("keep") model.transform(testDF).collect() } + + test("Transforming on mismatched attributes") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + + val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large") + val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", testAttr.toMetadata())) + val err = intercept[Exception] { + model.transform(testDF).collect() + } + err.getMessage.contains("OneHotEncoderModel expected 2 categorical values") + } } From e94496a5c8b08fdc437e9623dfba2b0d80998263 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 31 Dec 2017 04:45:21 +0000 Subject: [PATCH 13/13] Address comments. --- .../ml/feature/OneHotEncoderEstimator.scala | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index d489f0a12f96e..074622d41e28d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -137,14 +137,19 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { - // When fitting data, we want the the plain number of categories without `handleInvalid` and - // `dropLast` taken into account. - validateAndTransformSchema(schema, dropLast = false, keepInvalid = false) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) } @Since("2.3.0") override def fit(dataset: Dataset[_]): OneHotEncoderModel = { - val transformedSchema = transformSchema(dataset.schema) + transformSchema(dataset.schema) + + // Compute the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, + keepInvalid = false) val categorySizes = new Array[Int]($(outputCols).length) val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => @@ -200,23 +205,23 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // The actual number of categories varies due to different setting of `dropLast` and - // `handleInvalid`. - private def configedCategorySizes: Array[Int] = { + // Returns the category size for a given index with `dropLast` and `handleInvalid` + // taken into account. + private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - categorySizes.map(_ + 1) + orgCategorySize + 1 } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. 
- categorySizes.map(_ - 1) + orgCategorySize - 1 } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - categorySizes + orgCategorySize } } @@ -228,31 +233,28 @@ class OneHotEncoderModel private[ml] ( val handleInvalid = getHandleInvalid val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID - udf { (label: Double, size: Int) => - val numCategory = if (!dropLast && keepInvalid) { - // When `dropLast` is false and `handleInvalid` is "keep", the last category is - // for invalid data. - size - 1 - } else { - size - } + // The udf performed on input data. The first parameter is the input value. The second + // parameter is the index of input. + udf { (label: Double, idx: Int) => + val plainNumCategories = categorySizes(idx) + val size = configedCategorySize(plainNumCategories, idx) if (label < 0) { throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label < numCategory) { - Vectors.sparse(size, Array(label.toInt), oneValue) - } else if (label == numCategory && dropLast && !keepInvalid) { + } else if (label == size && dropLast && !keepInvalid) { // When `dropLast` is true and `handleInvalid` is not "keep", // the last category is removed. Vectors.sparse(size, emptyIndices, emptyValues) - } else if (dropLast && keepInvalid) { - // When `dropLast` is true and `handleInvalid` is "keep", - // invalid data is encoded to the removed last category. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (keepInvalid) { - // When `dropLast` is false and `handleInvalid` is "keep", - // invalid data is encoded to the last category. - Vectors.sparse(size, Array(numCategory), oneValue) + } else if (label >= plainNumCategories && keepInvalid) { + // When `handleInvalid` is "keep", encodes invalid data to last category (and removed + // if `dropLast` is true) + if (dropLast) { + Vectors.sparse(size, emptyIndices, emptyValues) + } else { + Vectors.sparse(size, Array(size - 1), oneValue) + } + } else if (label < plainNumCategories) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) throw new SparkException(s"Unseen value: $label. To handle unseen values, " + @@ -306,8 +308,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - require(attrGroup.size == configedCategorySizes(idx), "OneHotEncoderModel expected " + - s"${configedCategorySizes(idx)} categorical values for input column ${inputColName}, " + + val numCategories = configedCategorySize(categorySizes(idx), idx) + require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + + s"$numCategories categorical values for input column ${inputColName}, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -333,7 +336,7 @@ class OneHotEncoderModel private[ml] ( outputAttrGroupFromSchema.toMetadata() } - encoder(col(inputColName).cast(DoubleType), lit(configedCategorySizes(idx))) + encoder(col(inputColName).cast(DoubleType), lit(idx)) .as(outputColName, metadata) } dataset.withColumns($(outputCols), encodedColumns)
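
An end-to-end sketch of the semantics this final patch settles on. The toy data and
column names below are hypothetical, and a SparkSession with `spark.implicits._` in
scope is assumed:

    import org.apache.spark.ml.feature.OneHotEncoderEstimator

    // Three categories (0, 1, 2) are observed while fitting.
    val train = Seq(0.0, 1.0, 2.0).map(Tuple1.apply).toDF("category")
    // Index 3.0 is unseen at transform time.
    val test = Seq(0.0, 1.0, 3.0).map(Tuple1.apply).toDF("category")

    val model = new OneHotEncoderEstimator()
      .setInputCols(Array("category"))
      .setOutputCols(Array("encoded"))
      .setHandleInvalid("keep") // unseen values map to an extra last category
      .setDropLast(false)       // with dropLast = true they would encode as all zeros
      .fit(train)

    // Output vectors have 4 slots: categories 0..2 plus the invalid bucket.
    model.transform(test).show(false)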