From b227e3b2908b8d236f3739b30c6ba369d2fdf867 Mon Sep 17 00:00:00 2001
From: WeichenXu
Date: Tue, 31 Oct 2017 22:45:34 +0800
Subject: [PATCH 1/8] init pr

---
 .../spark/ml/feature/StringIndexer.scala      | 244 ++++++++++++------
 .../ml/param/shared/SharedParamsCodeGen.scala |   1 +
 .../spark/ml/param/shared/sharedParams.scala  |  15 ++
 .../spark/ml/feature/StringIndexerSuite.scala |   6 +-
 4 files changed, 185 insertions(+), 81 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 2679ec310c47..2196f5137194 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,18 +26,20 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols,
+  HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.util.VersionUtils.majorMinorVersion
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
 private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol
-  with HasOutputCol {
+  with HasOutputCol with HasInputCols with HasOutputCols {
 
   /**
    * Param for how to handle invalid data (unseen labels or NULL values).
@@ -79,19 +81,39 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
   @Since("2.3.0")
   def getStringOrderType: String = $(stringOrderType)
 
+  private[feature] def getInOutCols: (Array[String], Array[String]) = {
+
+    require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
+      (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
+      "Only allow to set either inputCol/outputCol, or inputCols/outputCols"
+    )
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "The number of inputCols does not match the number of outputCols")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputColName = $(inputCol) - val inputDataType = schema(inputColName).dataType - require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], - s"The input column $inputColName must be either string type or numeric type, " + - s"but got $inputDataType.") - val inputFields = schema.fields - val outputColName = $(outputCol) - require(inputFields.forall(_.name != outputColName), - s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val (inputColNames, outputColNames) = getInOutCols + + val outputFields = for (i <- 0 until inputColNames.length) yield { + val inputColName = inputColNames(i) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") + val inputFields = schema.fields + val outputColName = outputColNames(i) + require(inputFields.forall(_.name != outputColName), + s"Output column $outputColName already exists.") + val attr = NominalAttribute.defaultAttr.withName($(outputCol)) + attr.toStructField() + } StructType(outputFields) } } @@ -130,21 +152,33 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val values = dataset.na.drop(Array($(inputCol))) - .select(col($(inputCol)).cast(StringType)) - .rdd.map(_.getString(0)) - val labels = $(stringOrderType) match { - case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) - .map(_._1).toArray - case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) - .map(_._1).toArray - case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) - case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) + + val labelsArray = for (inputCol <- getInOutCols._1) yield { + val values = dataset.na.drop(Array(inputCol)) + .select(col(inputCol).cast(StringType)) + .rdd.map(_.getString(0)) + $(stringOrderType) match { + case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) + .map(_._1).toArray + case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) + .map(_._1).toArray + case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) + case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) + } } - copyValues(new StringIndexerModel(uid, labels).setParent(this)) + + copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) } @Since("1.4.0") @@ -177,7 +211,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * @param labels Ordered list of labels, corresponding to indices to be assigned. + * @param labelsArray Array of Ordered list of labels, corresponding to indices to be assigned + * for each input column. 
* * @note During transformation, if the input column does not exist, * `StringIndexerModel.transform` would return the input dataset unmodified. @@ -186,23 +221,36 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { @Since("1.4.0") class StringIndexerModel ( @Since("1.4.0") override val uid: String, - @Since("1.5.0") val labels: Array[String]) + @Since("2.3.0") val labelsArray: Array[Array[String]]) extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ @Since("1.5.0") - def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) - - private val labelToIndex: OpenHashMap[String, Double] = { - val n = labels.length - val map = new OpenHashMap[String, Double](n) - var i = 0 - while (i < n) { - map.update(labels(i), i) - i += 1 + def this(labels: Array[String]) = + this(Identifiable.randomUID("strIdx"), Array(labels)) + + @Since("1.5.0") + def labels: Array[String] = { + require(labelsArray.length == 1) + labelsArray(0) + } + + @Since("2.3.0") + def this(labelsArray: Array[Array[String]]) = + this(Identifiable.randomUID("strIdx"), labelsArray) + + private val labelToIndexArray: Array[OpenHashMap[String, Double]] = { + for (labels <- labelsArray) yield { + val n = labels.length + val map = new OpenHashMap[String, Double](n) + var i = 0 + while (i < n) { + map.update(labels(i), i) + i += 1 + } + map } - map } /** @group setParam */ @@ -217,54 +265,83 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - if (!dataset.schema.fieldNames.contains($(inputCol))) { - logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + - "Skip StringIndexerModel.") - return dataset.toDF - } transformSchema(dataset.schema, logging = true) - val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_INVALID => labels :+ "__unknown" - case _ => labels - } + val (inputColNames, outputColNames) = getInOutCols + + val outputColumns = new Array[Column](outputColNames.length) - val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(filteredLabels).toMetadata() + var filteredDataset = dataset // If we are skipping invalid records, filter them out. - val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_INVALID => + if (getHandleInvalid == StringIndexer.SKIP_INVALID) { + filteredDataset = dataset.na.drop(inputColNames) + for (i <- 0 until inputColNames.length) { + val inputColName = inputColNames(i) + val labelToIndex = labelToIndexArray(i) val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) + filteredDataset = filteredDataset.where(filterer(dataset(inputColName))) + } } - val indexer = udf { label: String => - if (label == null) { - if (keepInvalid) { - labels.length - } else { - throw new SparkException("StringIndexer encountered NULL value. 
To handle or skip " + - "NULLS, try setting StringIndexer.handleInvalid.") - } + for (i <- 0 until outputColNames.length) { + val inputColName = inputColNames(i) + val outputColName = outputColNames(i) + val labelToIndex = labelToIndexArray(i) + val labels = labelsArray(i) + + if (!dataset.schema.fieldNames.contains(inputColName)) { + logInfo(s"Input column ${inputColName} does not exist during transformation. " + + "Skip this column StringIndexerModel transform.") + outputColNames(i) = null } else { - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length - } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown" + case _ => labelsArray(i) + } + + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(filteredLabels).toMetadata() + + val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) + + val indexer = udf { label: String => + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } + } else { + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } + } } + + outputColumns(i) = indexer(dataset(inputColName).cast(StringType)) + .as(outputColName, metadata) } } - - filteredDataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) + filteredDataset.withColumns(outputColNames.filter(_ != null), + outputColumns.filter(_ != null)) } @Since("1.4.0") @@ -279,7 +356,7 @@ class StringIndexerModel ( @Since("1.4.1") override def copy(extra: ParamMap): StringIndexerModel = { - val copied = new StringIndexerModel(uid, labels) + val copied = new StringIndexerModel(uid, labelsArray) copyValues(copied, extra).setParent(parent) } @@ -293,11 +370,11 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private[StringIndexerModel] class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { - private case class Data(labels: Array[String]) + private case class Data(labelsArray: Array[Array[String]]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sc) - val data = Data(instance.labels) + val data = Data(instance.labelsArray) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } @@ -310,11 +387,22 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sparkSession.read.parquet(dataPath) - .select("labels") - .head() - val labels = data.getAs[Seq[String]](0).toArray - val model = new StringIndexerModel(metadata.uid, labels) + + val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) + val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) { + // Spark 2.2 and before + val data = sparkSession.read.parquet(dataPath) 
+ .select("labels") + .head() + val labels = data.getAs[Seq[String]](0).toArray + Array(labels) + } else { + val data = sparkSession.read.parquet(dataPath) + .select("labelsArray") + .head() + data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray + } + val model = new StringIndexerModel(metadata.uid, labelsArray) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1860fe836174..64163d3c1603 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6061d9ca0a08..b9f4791b37de 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -230,6 +230,21 @@ private[ml] trait HasOutputCol extends Params { final def getOutputCol: String = $(outputCol) } +/** + * Trait for shared param outputCols. + */ +private[ml] trait HasOutputCols extends Params { + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) +} + /** * Trait for shared param checkpointInterval. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 027b1fbc6657..c8b83e1783e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -33,7 +33,7 @@ class StringIndexerSuite test("params") { ParamsSuite.checkParams(new StringIndexer) - val model = new StringIndexerModel("indexer", Array("a", "b")) + val model = new StringIndexerModel("indexer", Array(Array("a", "b"))) val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) ParamsSuite.checkParams(modelWithoutUid) @@ -167,7 +167,7 @@ class StringIndexerSuite } test("StringIndexerModel should keep silent if the input column does not exist.") { - val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + val indexerModel = new StringIndexerModel("indexer", Array(Array("a", "b", "c"))) .setInputCol("label") .setOutputCol("labelIndex") val df = spark.range(0L, 10L).toDF() @@ -202,7 +202,7 @@ class StringIndexerSuite } test("StringIndexerModel read/write") { - val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c")) + val instance = new StringIndexerModel("myStringIndexerModel", Array(Array("a", "b", "c"))) .setInputCol("myInputCol") .setOutputCol("myOutputCol") .setHandleInvalid("skip") From 6a176178504a8c7171c7a209e258f4a2998f0933 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 2 Nov 2017 12:55:45 +0800 Subject: [PATCH 2/8] optimize fit & add UT --- .../spark/ml/feature/StringIndexer.scala | 92 +++++++++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 28 ++++++ 2 files changed, 84 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 2196f5137194..6e80a8be7c36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -20,16 +20,15 @@ package org.apache.spark.ml.feature import scala.language.existentials import org.apache.hadoop.fs.Path - import org.apache.spark.SparkException + import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, - HasOutputCol, HasOutputCols} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Column, DataFrame, Dataset} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.VersionUtils.majorMinorVersion @@ -98,23 +97,32 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi } /** Validates and transforms the input schema. 
*/
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val (inputColNames, outputColNames) = getInOutCols
+  protected def validateAndTransformSchema(schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols
 
+    val inputFields = schema.fields
     val outputFields = for (i <- 0 until inputColNames.length) yield {
       val inputColName = inputColNames(i)
-      val inputDataType = schema(inputColName).dataType
-      require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
-        s"The input column $inputColName must be either string type or numeric type, " +
-        s"but got $inputDataType.")
-      val inputFields = schema.fields
-      val outputColName = outputColNames(i)
-      require(inputFields.forall(_.name != outputColName),
-        s"Output column $outputColName already exists.")
-      val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-      attr.toStructField()
+      if (schema.fieldNames.contains(inputColName)) {
+        val inputDataType = schema(inputColName).dataType
+        require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
+          s"The input column $inputColName must be either string type or numeric type, " +
+            s"but got $inputDataType.")
+        val outputColName = outputColNames(i)
+        require(inputFields.forall(_.name != outputColName),
+          s"Output column $outputColName already exists.")
+        val attr = NominalAttribute.defaultAttr.withName(outputColName)
+        attr.toStructField()
+      } else {
+        if (skipNonExistsCol) {
+          null
+        } else {
+          throw new SparkException(s"Input column ${inputColName} does not exist.")
+        }
+      }
     }
-    StructType(outputFields)
+    StructType(inputFields ++ outputFields.filter(_ != null))
   }
 }
@@ -164,20 +172,36 @@ class StringIndexer @Since("1.4.0") (
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
 
-    val labelsArray = for (inputCol <- getInOutCols._1) yield {
-      val values = dataset.na.drop(Array(inputCol))
-        .select(col(inputCol).cast(StringType))
-        .rdd.map(_.getString(0))
+    val inputCols = getInOutCols._1
+
+    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]())
+
+    val countByValueArray = dataset.na.drop(inputCols)
+      .select(inputCols.map(col(_).cast(StringType)): _*)
+      .rdd.aggregate(zeroState)(
+        (state: Array[OpenHashMap[String, Long]], row: Row) => {
+          for (i <- 0 until inputCols.length) {
+            state(i).changeValue(row.getString(i), 1L, _ + 1)
+          }
+          state
+        },
+        (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => {
+          for (i <- 0 until inputCols.length) {
+            state2(i).foreach { case (key: String, count: Long) =>
+              state1(i).changeValue(key, count, _ + count)
+            }
+          }
+          state1
+        }
+      )
+    val labelsArray = countByValueArray.map { countByValue =>
       $(stringOrderType) match {
-        case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
-          .map(_._1).toArray
-        case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
-          .map(_._1).toArray
-        case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
-        case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
+        case StringIndexer.frequencyDesc => countByValue.toSeq.sortBy(-_._2).map(_._1).toArray
+        case StringIndexer.frequencyAsc => countByValue.toSeq.sortBy(_._2).map(_._1).toArray
+        case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
+        case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray
      }
    }
-
 
copyValues(new StringIndexerModel(uid, labelsArray).setParent(this)) } @@ -277,14 +301,15 @@ class StringIndexerModel ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val (inputColNames, outputColNames) = getInOutCols + var (inputColNames, outputColNames) = getInOutCols val outputColumns = new Array[Column](outputColNames.length) var filteredDataset = dataset // If we are skipping invalid records, filter them out. if (getHandleInvalid == StringIndexer.SKIP_INVALID) { - filteredDataset = dataset.na.drop(inputColNames) + filteredDataset = dataset.na.drop(inputColNames.filter( + dataset.schema.fieldNames.contains(_))) for (i <- 0 until inputColNames.length) { val inputColName = inputColNames(i) val labelToIndex = labelToIndexArray(i) @@ -346,12 +371,7 @@ class StringIndexerModel ( @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - if (schema.fieldNames.contains($(inputCol))) { - validateAndTransformSchema(schema) - } else { - // If the input column does not exist during transformation, we skip StringIndexerModel. - schema - } + validateAndTransformSchema(schema, skipNonExistsCol = true) } @Since("1.4.1") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index c8b83e1783e6..1ef780cc8aef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -61,6 +61,34 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer multiple input columns") { + val data = Seq((0, "a", "e"), (1, "b", "f"), (2, "c", "e"), + (3, "a", "f"), (4, "a", "f"), (5, "c", "f")) + val df = data.toDF("id", "label1", "label2") + val indexer = new StringIndexer() + .setInputCols(Array("label1", "label2")) + .setOutputCols(Array("labelIndex1", "labelIndex2")) + val indexerModel = indexer.fit(df) + + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) + + val transformed = indexerModel.transform(df) + val attr1 = Attribute.fromStructField(transformed.schema("labelIndex1")) + .asInstanceOf[NominalAttribute] + assert(attr1.values.get === Array("a", "c", "b")) + val attr2 = Attribute.fromStructField(transformed.schema("labelIndex2")) + .asInstanceOf[NominalAttribute] + assert(attr2.values.get === Array("f", "e")) + val output = transformed.select("id", "labelIndex1", "labelIndex2").rdd.map { r => + (r.getInt(0), r.getDouble(1), r.getDouble(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + // e -> 1, f -> 0 + val expected = Set((0, 0.0, 1.0), (1, 2.0, 0.0), (2, 1.0, 1.0), + (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + assert(output === expected) + } + test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) From 8e71b45ad4dbdbb7638febfd927fc76db25869d4 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 2 Nov 2017 15:25:09 +0800 Subject: [PATCH 3/8] fix style --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 6e80a8be7c36..a17ff9f2ddf4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ 
-20,8 +20,8 @@ package org.apache.spark.ml.feature import scala.language.existentials import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException +import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} From 77bea32984b167894be79736f56601a44baaaa99 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 15 Nov 2017 11:06:29 +0800 Subject: [PATCH 4/8] fix_mima --- project/MimaExcludes.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 915c7e2e2fda..dc94549cc0a5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -82,7 +82,15 @@ object MimaExcludes { // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.this"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.outputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.getOutputCols"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_=") ) // Exclude rules for 2.2.x From e5db190f9b76e22ea8f665456cba60fd31cc9cf0 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 21 Nov 2017 18:52:41 +0800 Subject: [PATCH 5/8] address failed RFormula tests --- .../spark/ml/feature/StringIndexer.scala | 12 ++++++--- .../spark/ml/feature/RFormulaSuite.scala | 25 +++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 39762482cf1f..7e90ef3b479b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -235,7 +235,7 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * @param labelsArray Array of Ordered list of labels, corresponding to indices to be assigned + * @param labelsArray Array of ordered list of labels, corresponding to indices to be assigned * for each input column. 
* * @note During transformation, if the input column does not exist, @@ -365,8 +365,14 @@ class StringIndexerModel ( .as(outputColName, metadata) } } - filteredDataset.withColumns(outputColNames.filter(_ != null), - outputColumns.filter(_ != null)) + val filteredOutputColNames = outputColNames.filter(_ != null) + val filteredOutputColumns = outputColumns.filter(_ != null) + + if (filteredOutputColNames.length > 0) { + filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns) + } else { + filteredDataset.toDF() + } } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90ec6df..d56eb50d905d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -114,7 +114,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), + (5, "bar", 6), (6, "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -123,7 +124,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), - (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0) + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0), + (5, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 5.0), + (6, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 6.0) ).toDF("id", "a", "b", "features", "label") assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -299,7 +302,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = - Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) + Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5), + ("female", "bar", 6), ("female", "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -307,7 +311,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0), - ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0) + ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0), + ("female", "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0), + ("female", "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 0.0) ).toDF("id", "a", "b", "features", "label") // assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) @@ -316,7 +322,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("force to index label even it is numeric type") { val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true) val original = spark.createDataFrame( - Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5)) + Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), 
(1.0, "baz", 5), + (1.0, "bar", 6), (0.0, "foo", 6)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) @@ -325,14 +332,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0), (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0), - (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0)) + (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0), + (1.0, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0), + (0.0, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 1.0) + ) ).toDF("id", "a", "b", "features", "label") assert(result.collect() === expected.collect()) } test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), + (1, "bar", 6), (0, "foo", 6)) .toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) From 031f53fbd1c112d8f0b37bb29e847cd3184498c6 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 22 Nov 2017 12:38:59 +0800 Subject: [PATCH 6/8] fix pyspark tests --- python/pyspark/ml/classification.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 27ad1e80aa0d..0d2712196200 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -914,7 +914,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") @@ -1050,7 +1051,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) @@ -1188,7 +1190,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", + ... 
stringOrderType="alphabetAsc") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) From 66d054a7daca8a82fd1022fe05e766e7f7285028 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 23 Nov 2017 16:48:05 +0800 Subject: [PATCH 7/8] make frequency order result stable --- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7e90ef3b479b..7b69d411b77e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -178,7 +178,7 @@ class StringIndexer @Since("1.4.0") ( val countByValueArray = dataset.na.drop(inputCols) .select(inputCols.map(col(_).cast(StringType)): _*) - .rdd.aggregate(zeroState)( + .rdd.treeAggregate(zeroState)( (state: Array[OpenHashMap[String, Long]], row: Row) => { for (i <- 0 until inputCols.length) { state(i).changeValue(row.getString(i), 1L, _ + 1) @@ -196,8 +196,10 @@ class StringIndexer @Since("1.4.0") ( ) val labelsArray = countByValueArray.map { countByValue => $(stringOrderType) match { - case StringIndexer.frequencyDesc => countByValue.toSeq.sortBy(-_._2).map(_._1).toArray - case StringIndexer.frequencyAsc => countByValue.toSeq.sortBy(_._2).map(_._1).toArray + case StringIndexer.frequencyDesc => + countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray + case StringIndexer.frequencyAsc => + countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray } From bb209c80395cce84466a8fb8e0c58ca151b791ab Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 15 Dec 2017 18:35:08 +0800 Subject: [PATCH 8/8] address comments --- .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/feature/StringIndexerSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 7b69d411b77e..405a6ee9fd9f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -84,7 +84,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), - "Only allow to set either inputCol/outputCol, or inputCols/outputCols" + "StringIndexer only supports setting either inputCol/outputCol or inputCols/outputCols." 
) if (isSet(inputCol)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2cf7bd06d889..317d69677bcf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -37,6 +37,29 @@ class StringIndexerSuite val modelWithoutUid = new StringIndexerModel(Array("a", "b")) ParamsSuite.checkParams(model) ParamsSuite.checkParams(modelWithoutUid) + + val stringIndexerSingleCol = new StringIndexer() + .setInputCol("in").setOutputCol("out") + val inOutCols1 = stringIndexerSingleCol.getInOutCols + assert(inOutCols1._1 === Array("in")) + assert(inOutCols1._2 === Array("out")) + + val stringIndexerMultiCol = new StringIndexer() + .setInputCols(Array("in1", "in2")).setOutputCols(Array("out1", "out2")) + val inOutCols2 = stringIndexerMultiCol.getInOutCols + assert(inOutCols2._1 === Array("in1", "in2")) + assert(inOutCols2._2 === Array("out1", "out2")) + + intercept[IllegalArgumentException] { + new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).getInOutCols + } + intercept[IllegalArgumentException] { + new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").getInOutCols + } + intercept[IllegalArgumentException] { + new StringIndexer().setInputCols(Array("in1", "in2")) + .setOutputCols(Array("out1", "out2", "out3")).getInOutCols + } } test("StringIndexer") {
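
For reviewers trying the series end to end, here is a minimal usage sketch of the new multi-column API. It is illustrative only: it assumes a build that includes these patches, an active SparkSession named `spark`, and made-up column names.

    import org.apache.spark.ml.feature.StringIndexer

    val df = spark.createDataFrame(Seq(
      (0, "a", "e"), (1, "b", "f"), (2, "c", "e"), (3, "a", "f")
    )).toDF("id", "label1", "label2")

    // One estimator, one pass over the data, two indexed output columns.
    val model = new StringIndexer()
      .setInputCols(Array("label1", "label2"))
      .setOutputCols(Array("labelIndex1", "labelIndex2"))
      .fit(df)

    // labelsArray holds one ordered label list per input column; the old
    // `labels` accessor still works when exactly one column was indexed.
    model.labelsArray.foreach(labels => println(labels.mkString(", ")))

    model.transform(df).show()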
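A note on the param contract enforced by getInOutCols (introduced in patch 1, message reworded in patch 8): inputCol/outputCol and inputCols/outputCols are mutually exclusive, and the two arrays must be the same length. Both violations surface as IllegalArgumentException from require, as the new tests assert; a quick sketch with hypothetical column names:

    // Rejected: mixing the single-column and multi-column APIs.
    new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).fit(df)

    // Rejected: two input columns against three output columns.
    new StringIndexer()
      .setInputCols(Array("in1", "in2"))
      .setOutputCols(Array("out1", "out2", "out3"))
      .fit(df)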
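The heart of the optimized fit (patch 2, refined in patch 7) is a single aggregate over the rows that keeps one frequency map per input column, followed by a two-key sort so that equal counts come out in a deterministic order. The shape of that computation can be sketched in plain Scala, with scala.collection.mutable.HashMap standing in for Spark's internal OpenHashMap (names here are illustrative):

    import scala.collection.mutable

    type State = Array[mutable.HashMap[String, Long]]

    // seqOp: fold one row (one string value per indexed column) into the state.
    def seqOp(state: State, row: Seq[String]): State = {
      for (i <- state.indices) {
        state(i).update(row(i), state(i).getOrElse(row(i), 0L) + 1L)
      }
      state
    }

    // combOp: merge the partial counts computed on two partitions.
    def combOp(s1: State, s2: State): State = {
      for (i <- s1.indices; (key, count) <- s2(i)) {
        s1(i).update(key, s1(i).getOrElse(key, 0L) + count)
      }
      s1
    }

    // frequencyDesc with stable ties, as in patch 7: sort by label first, then
    // a stable sort by descending count preserves the alphabetical order among
    // labels that have equal frequency.
    def labelsFor(counts: mutable.HashMap[String, Long]): Array[String] =
      counts.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray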