[SPARK-11215][ML] Add multiple columns support to StringIndexer #19621
Status: Closed. WeichenXu123 wants to merge 10 commits into apache:master from WeichenXu123:multi-col-string-indexer.
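For context, a minimal sketch of the API this patch proposes, assuming a running `SparkSession` named `spark`; the data and column names are illustrative, not taken from the PR:

```scala
import org.apache.spark.ml.feature.StringIndexer

// Illustrative data: two string columns to be indexed in one pass.
val df = spark.createDataFrame(Seq(
  ("a", "x"), ("b", "y"), ("a", "y")
)).toDF("col1", "col2")

// With this patch, a single StringIndexer handles multiple columns:
val indexer = new StringIndexer()
  .setInputCols(Array("col1", "col2"))
  .setOutputCols(Array("col1Idx", "col2Idx"))

indexer.fit(df).transform(df).show()
```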
Commits (10, all by WeichenXu123):
b227e3b  init pr
6a17617  optimize fit & add UT
8e71b45  fix style
b0b14b0  merge 'master' and resolve conflicts
77bea32  fix_mima
e5db190  address failed RFormula tests
031f53f  fix pyspark tests
66d054a  make frequency order result stable
0bd9f66  Merge branch 'master' into multi-col-string-indexer
bb209c8  address comments
Diff (StringIndexer.scala):

@@ -26,18 +26,19 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.util.VersionUtils.majorMinorVersion
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
 private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol
-  with HasOutputCol {
+  with HasOutputCol with HasInputCols with HasOutputCols {
 
   /**
    * Param for how to handle invalid data (unseen labels or NULL values).
@@ -79,20 +80,49 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
   @Since("2.3.0")
   def getStringOrderType: String = $(stringOrderType)
 
+  private[feature] def getInOutCols: (Array[String], Array[String]) = {
+    require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
+      (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
+      "StringIndexer only supports setting either inputCol/outputCol or inputCols/outputCols."
+    )
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "inputCols number do not match outputCols")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   /** Validates and transforms the input schema. */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
-    val inputDataType = schema(inputColName).dataType
-    require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
-      s"The input column $inputColName must be either string type or numeric type, " +
-        s"but got $inputDataType.")
+  protected def validateAndTransformSchema(schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols
     val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    val outputFields = for (i <- 0 until inputColNames.length) yield {
+      val inputColName = inputColNames(i)
+      if (schema.fieldNames.contains(inputColName)) {
+        val inputDataType = schema(inputColName).dataType
+        require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
+          s"The input column $inputColName must be either string type or numeric type, " +
+            s"but got $inputDataType.")
+        val outputColName = outputColNames(i)
+        require(inputFields.forall(_.name != outputColName),
+          s"Output column $outputColName already exists.")
+        val attr = NominalAttribute.defaultAttr.withName($(outputCol))
+        attr.toStructField()
+      } else {
+        if (skipNonExistsCol) {
+          null
+        } else {
+          throw new SparkException(s"Input column ${inputColName} do not exist.")
+        }
+      }
+    }
+    StructType(inputFields ++ outputFields.filter(_ != null))
   }
 }
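As a usage note on the new `getInOutCols` helper: the single-column and multi-column params are mutually exclusive, so mixing them fails at the first `require`. A hedged sketch of what a caller would see (the error surfaces as an `IllegalArgumentException` thrown by `require`):

```scala
// Rejected: inputCol combined with outputCols mixes the two param families.
val bad = new StringIndexer()
  .setInputCol("col1")
  .setOutputCols(Array("col1Idx"))
// bad.fit(df) throws IllegalArgumentException:
//   "StringIndexer only supports setting either inputCol/outputCol
//    or inputCols/outputCols."
```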
@@ -130,21 +160,51 @@ class StringIndexer @Since("1.4.0") (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val values = dataset.na.drop(Array($(inputCol)))
-      .select(col($(inputCol)).cast(StringType))
-      .rdd.map(_.getString(0))
-    val labels = $(stringOrderType) match {
-      case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
-        .map(_._1).toArray
-      case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
-        .map(_._1).toArray
-      case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
-      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
-    }
-    copyValues(new StringIndexerModel(uid, labels).setParent(this))
+    val inputCols = getInOutCols._1
+
+    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]())
+
+    val countByValueArray = dataset.na.drop(inputCols)
+      .select(inputCols.map(col(_).cast(StringType)): _*)
+      .rdd.treeAggregate(zeroState)(
+        (state: Array[OpenHashMap[String, Long]], row: Row) => {
+          for (i <- 0 until inputCols.length) {
+            state(i).changeValue(row.getString(i), 1L, _ + 1)
+          }
+          state
+        },
+        (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => {
+          for (i <- 0 until inputCols.length) {
+            state2(i).foreach { case (key: String, count: Long) =>
+              state1(i).changeValue(key, count, _ + count)
+            }
+          }
+          state1
+        }
+      )
+    val labelsArray = countByValueArray.map { countByValue =>
+      $(stringOrderType) match {
+        case StringIndexer.frequencyDesc =>
+          countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray
+        case StringIndexer.frequencyAsc =>
+          countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray
+        case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
+        case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray
+      }
+    }
+    copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
   }
 
   @Since("1.4.0")
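The core change in `fit` is counting label frequencies for all input columns in a single pass over the data, via `treeAggregate` with one `OpenHashMap` per column, instead of running one Spark job per column. Note also the `sortBy(_._1).sortBy(-_._2)` idiom: a stable sort by label first, then by count, makes the ordering of frequency ties deterministic (commit 66d054a). The same seqOp pattern on a plain local collection, for illustration only (the combOp merges two such states the same way):

```scala
// Single-pass, per-column frequency counting: a local sketch of the
// treeAggregate seqOp used in fit.
val rows = Seq(Array("a", "x"), Array("b", "x"), Array("a", "y"))
val numCols = 2

val state = Array.fill(numCols)(scala.collection.mutable.Map.empty[String, Long])
for (row <- rows; i <- 0 until numCols) {
  state(i)(row(i)) = state(i).getOrElse(row(i), 0L) + 1
}
// state(0) == Map("a" -> 2, "b" -> 1)
// state(1) == Map("x" -> 2, "y" -> 1)
```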
@@ -177,7 +237,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
 /**
  * Model fitted by [[StringIndexer]].
  *
- * @param labels Ordered list of labels, corresponding to indices to be assigned.
+ * @param labelsArray Array of ordered list of labels, corresponding to indices to be assigned
+ *                    for each input column.
  *
  * @note During transformation, if the input column does not exist,
  * `StringIndexerModel.transform` would return the input dataset unmodified.
@@ -186,23 +247,36 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
 @Since("1.4.0")
 class StringIndexerModel (
     @Since("1.4.0") override val uid: String,
-    @Since("1.5.0") val labels: Array[String])
+    @Since("2.3.0") val labelsArray: Array[Array[String]])
   extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {
 
   import StringIndexerModel._
 
   @Since("1.5.0")
-  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
+  def this(labels: Array[String]) =
+    this(Identifiable.randomUID("strIdx"), Array(labels))
 
-  private val labelToIndex: OpenHashMap[String, Double] = {
-    val n = labels.length
-    val map = new OpenHashMap[String, Double](n)
-    var i = 0
-    while (i < n) {
-      map.update(labels(i), i)
-      i += 1
+  @Since("1.5.0")
+  def labels: Array[String] = {
+    require(labelsArray.length == 1)
+    labelsArray(0)
+  }
+
+  @Since("2.3.0")
+  def this(labelsArray: Array[Array[String]]) =
+    this(Identifiable.randomUID("strIdx"), labelsArray)
+
+  private val labelToIndexArray: Array[OpenHashMap[String, Double]] = {
+    for (labels <- labelsArray) yield {
+      val n = labels.length
+      val map = new OpenHashMap[String, Double](n)
+      var i = 0
+      while (i < n) {
+        map.update(labels(i), i)
+        i += 1
+      }
+      map
     }
-    map
   }
 
   /** @group setParam */
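The constructor change above keeps source compatibility: single-column code still constructs the model from an `Array[String]` and reads `labels`, which is now a view over a one-element `labelsArray`. A sketch, grounded in the diff:

```scala
// Old single-column construction still works; labels delegates to labelsArray(0).
val m1 = new StringIndexerModel(Array("a", "b", "c"))
assert(m1.labels.sameElements(Array("a", "b", "c")))

// New multi-column construction (Since 2.3.0):
val m2 = new StringIndexerModel(Array(Array("a", "b"), Array("x", "y")))
// m2.labels would fail its require(labelsArray.length == 1) check here.
```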
@@ -217,69 +291,100 @@ class StringIndexerModel (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    if (!dataset.schema.fieldNames.contains($(inputCol))) {
-      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
-        "Skip StringIndexerModel.")
-      return dataset.toDF
-    }
     transformSchema(dataset.schema, logging = true)
 
-    val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
-      case _ => labels
-    }
+    var (inputColNames, outputColNames) = getInOutCols
 
-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
+    val outputColumns = new Array[Column](outputColNames.length)
 
+    var filteredDataset = dataset
     // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = getHandleInvalid match {
-      case StringIndexer.SKIP_INVALID =>
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
-    }
+    if (getHandleInvalid == StringIndexer.SKIP_INVALID) {
+      filteredDataset = dataset.na.drop(inputColNames.filter(
+        dataset.schema.fieldNames.contains(_)))
+      for (i <- 0 until inputColNames.length) {
+        val inputColName = inputColNames(i)
+        val labelToIndex = labelToIndexArray(i)
+        val filterer = udf { label: String =>
+          labelToIndex.contains(label)
+        }
+        filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
+      }
+    }
 
-    val indexer = udf { label: String =>
-      if (label == null) {
-        if (keepInvalid) {
-          labels.length
-        } else {
-          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
-            "NULLS, try setting StringIndexer.handleInvalid.")
-        }
-      } else {
-        if (labelToIndex.contains(label)) {
-          labelToIndex(label)
-        } else if (keepInvalid) {
-          labels.length
-        } else {
-          throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
-            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
-        }
-      }
-    }.asNondeterministic()
+    for (i <- 0 until outputColNames.length) {
+      val inputColName = inputColNames(i)
+      val outputColName = outputColNames(i)
+      val labelToIndex = labelToIndexArray(i)
+      val labels = labelsArray(i)
 
-    filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
+      if (!dataset.schema.fieldNames.contains(inputColName)) {
+        logInfo(s"Input column ${inputColName} does not exist during transformation. " +
+          "Skip this column StringIndexerModel transform.")
+        outputColNames(i) = null
+      } else {
+        val filteredLabels = getHandleInvalid match {
+          case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown"
+          case _ => labelsArray(i)
+        }
+
+        val metadata = NominalAttribute.defaultAttr
+          .withName(outputColName).withValues(filteredLabels).toMetadata()
+
+        val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
+
+        val indexer = udf { label: String =>
+          if (label == null) {
+            if (keepInvalid) {
+              labels.length
+            } else {
+              throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
+                "NULLS, try setting StringIndexer.handleInvalid.")
+            }
+          } else {
+            if (labelToIndex.contains(label)) {
+              labelToIndex(label)
+            } else if (keepInvalid) {
+              labels.length
+            } else {
+              throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+                s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
+            }
+          }
+        }.asNondeterministic()
+
+        outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
+          .as(outputColName, metadata)
+      }
+    }
+
+    val filteredOutputColNames = outputColNames.filter(_ != null)
+    val filteredOutputColumns = outputColumns.filter(_ != null)
+
+    if (filteredOutputColNames.length > 0) {
+      filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns)
+    } else {
+      filteredDataset.toDF()
+    }
   }
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    if (schema.fieldNames.contains($(inputCol))) {
-      validateAndTransformSchema(schema)
-    } else {
-      // If the input column does not exist during transformation, we skip StringIndexerModel.
-      schema
-    }
+    validateAndTransformSchema(schema, skipNonExistsCol = true)
   }
 
   @Since("1.4.1")
   override def copy(extra: ParamMap): StringIndexerModel = {
-    val copied = new StringIndexerModel(uid, labels)
+    val copied = new StringIndexerModel(uid, labelsArray)
     copyValues(copied, extra).setParent(parent)
   }

Inline review on the `transformSchema(dataset.schema, logging = true)` line. Member: "We can skip …" Author (WeichenXu123): "updated."
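To make the per-column `handleInvalid` behavior above concrete, a hedged usage sketch (`trainDF` and `testDF` are hypothetical, as are the column names):

```scala
val model = new StringIndexer()
  .setInputCols(Array("col1", "col2"))
  .setOutputCols(Array("col1Idx", "col2Idx"))
  .setHandleInvalid("keep")   // unseen/NULL labels map to index labels.length,
                              // surfaced in column metadata as "__unknown"
  .fit(trainDF)

// With "skip", rows holding an unseen label in ANY input column are filtered
// out before indexing; with "error" (the default) they throw SparkException.
model.transform(testDF).show()
```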
@@ -293,11 +398,11 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
   private[StringIndexerModel]
   class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter {
 
-    private case class Data(labels: Array[String])
+    private case class Data(labelsArray: Array[Array[String]])
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.labels)
+      val data = Data(instance.labelsArray)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
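The on-disk layout follows the `Data` case class, so the single `labels` array column becomes a nested array column. A sketch of the inferred parquet schema (my reading of the code above, not output captured from the PR):

```scala
// Data(labelsArray = Array(Array("a", "b"), Array("x", "y"))) persists as a
// one-row parquet file whose inferred schema is roughly:
//   root
//    |-- labelsArray: array
//    |    |-- element: array
//    |    |    |-- element: string
```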
@@ -310,11 +415,22 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-        .select("labels")
-        .head()
-      val labels = data.getAs[Seq[String]](0).toArray
-      val model = new StringIndexerModel(metadata.uid, labels)
+
+      val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion)
+      val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) {
+        // Spark 2.2 and before
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labels")
+          .head()
+        val labels = data.getAs[Seq[String]](0).toArray
+        Array(labels)
+      } else {
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labelsArray")
+          .head()
+        data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray
+      }
+      val model = new StringIndexerModel(metadata.uid, labelsArray)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
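The loader branches on the Spark version recorded in the model metadata, so models written by Spark 2.2 and earlier (a single `labels` column) still load into the widened `labelsArray` representation. The gating logic in isolation, with an illustrative version string:

```scala
import org.apache.spark.util.VersionUtils.majorMinorVersion

val (major, minor) = majorMinorVersion("2.2.0")  // -> (2, 2)
val legacyFormat = major < 2 || (major == 2 && minor <= 2)
// legacyFormat: read the old "labels" column and wrap it as Array(labels);
// otherwise: read the new "labelsArray" column directly.
```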
Review comment: "Should add a test case for this." Author (WeichenXu123): "test added."