286 changes: 201 additions & 85 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,18 +26,19 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.VersionUtils.majorMinorVersion
import org.apache.spark.util.collection.OpenHashMap

/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol
with HasOutputCol {
with HasOutputCol with HasInputCols with HasOutputCols {

/**
* Param for how to handle invalid data (unseen labels or NULL values).
@@ -79,20 +80,49 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
@Since("2.3.0")
def getStringOrderType: String = $(stringOrderType)

private[feature] def getInOutCols: (Array[String], Array[String]) = {

require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) ||
(!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)),
"StringIndexer only supports setting either inputCol/outputCol or inputCols/outputCols."
)

if (isSet(inputCol)) {
(Array($(inputCol)), Array($(outputCol)))
} else {
require($(inputCols).length == $(outputCols).length,
"The number of inputCols does not match the number of outputCols.")
($(inputCols), $(outputCols))
}
}

Review comment (Contributor): Should add a test case for this.
Reply (Contributor Author): test added.

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
protected def validateAndTransformSchema(schema: StructType,
skipNonExistsCol: Boolean = false): StructType = {

val (inputColNames, outputColNames) = getInOutCols
val inputFields = schema.fields
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
val outputFields = for (i <- 0 until inputColNames.length) yield {
val inputColName = inputColNames(i)
if (schema.fieldNames.contains(inputColName)) {
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
val outputColName = outputColNames(i)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName(outputColName)
attr.toStructField()
} else {
if (skipNonExistsCol) {
null
} else {
throw new SparkException(s"Input column ${inputColName} does not exist.")
}
}
}
StructType(inputFields ++ outputFields.filter(_ != null))
}
}
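
A brief usage sketch (editor note, not part of this diff; column names are hypothetical). With the params above, a caller sets either the single-column pair or the multi-column pair, never a mix, and the two arrays must have the same length:

    import org.apache.spark.ml.feature.StringIndexer

    // Multi-column configuration added by this PR.
    val multi = new StringIndexer()
      .setInputCols(Array("color", "size"))
      .setOutputCols(Array("colorIndex", "sizeIndex"))

    // Single-column configuration keeps working as before.
    val single = new StringIndexer()
      .setInputCol("color")
      .setOutputCol("colorIndex")

    // Mixing the two styles (e.g. setInputCol together with setOutputCols)
    // fails the require() in getInOutCols when fit/transform is called.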

@@ -130,21 +160,51 @@ class StringIndexer @Since("1.4.0") (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.3.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
val values = dataset.na.drop(Array($(inputCol)))
.select(col($(inputCol)).cast(StringType))
.rdd.map(_.getString(0))
val labels = $(stringOrderType) match {
case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
.map(_._1).toArray
case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
.map(_._1).toArray
case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)

val inputCols = getInOutCols._1

val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]())

val countByValueArray = dataset.na.drop(inputCols)
.select(inputCols.map(col(_).cast(StringType)): _*)
.rdd.treeAggregate(zeroState)(
(state: Array[OpenHashMap[String, Long]], row: Row) => {
for (i <- 0 until inputCols.length) {
state(i).changeValue(row.getString(i), 1L, _ + 1)
}
state
},
(state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => {
for (i <- 0 until inputCols.length) {
state2(i).foreach { case (key: String, count: Long) =>
state1(i).changeValue(key, count, _ + count)
}
}
state1
}
)
val labelsArray = countByValueArray.map { countByValue =>
$(stringOrderType) match {
case StringIndexer.frequencyDesc =>
countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray
case StringIndexer.frequencyAsc =>
countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray
case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray
}
}
copyValues(new StringIndexerModel(uid, labels).setParent(this))
copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
}
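
A note on the ordering logic above: for frequencyDesc/frequencyAsc the counts are sorted by label first and then, stably, by count, so labels with equal frequency end up in alphabetical order. A minimal self-contained sketch of that tie-breaking (plain Scala collections rather than Spark's OpenHashMap):

    val countByValue = Map("b" -> 2L, "a" -> 2L, "c" -> 5L)

    // frequencyDesc: most frequent first; ties fall back to label order
    // because sortBy is a stable sort.
    val frequencyDesc = countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1)
    // => Seq("c", "a", "b")

    val frequencyAsc = countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1)
    // => Seq("a", "b", "c")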

@Since("1.4.0")
@@ -177,7 +237,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
/**
* Model fitted by [[StringIndexer]].
*
* @param labels Ordered list of labels, corresponding to indices to be assigned.
* @param labelsArray Array of ordered lists of labels, corresponding to the indices to be assigned
* for each input column.
*
* @note During transformation, if the input column does not exist,
* `StringIndexerModel.transform` would return the input dataset unmodified.
@@ -186,23 +247,36 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
@Since("1.4.0")
class StringIndexerModel (
@Since("1.4.0") override val uid: String,
@Since("1.5.0") val labels: Array[String])
@Since("2.3.0") val labelsArray: Array[Array[String]])
extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {

import StringIndexerModel._

@Since("1.5.0")
def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)

private val labelToIndex: OpenHashMap[String, Double] = {
val n = labels.length
val map = new OpenHashMap[String, Double](n)
var i = 0
while (i < n) {
map.update(labels(i), i)
i += 1
def this(labels: Array[String]) =
this(Identifiable.randomUID("strIdx"), Array(labels))

@Since("1.5.0")
def labels: Array[String] = {
require(labelsArray.length == 1)
labelsArray(0)
}

@Since("2.3.0")
def this(labelsArray: Array[Array[String]]) =
this(Identifiable.randomUID("strIdx"), labelsArray)

private val labelToIndexArray: Array[OpenHashMap[String, Double]] = {
for (labels <- labelsArray) yield {
val n = labels.length
val map = new OpenHashMap[String, Double](n)
var i = 0
while (i < n) {
map.update(labels(i), i)
i += 1
}
map
}
map
}
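
The per-column maps above simply assign each label its position in the ordered label list. A rough equivalent using standard collections (illustration only; the real code uses Spark's OpenHashMap for memory efficiency):

    val labelsArray = Array(Array("a", "b", "c"), Array("x", "y"))

    val labelToIndexArray: Array[Map[String, Double]] =
      labelsArray.map(_.zipWithIndex.map { case (label, i) => label -> i.toDouble }.toMap)

    // labelToIndexArray(0)("b") == 1.0
    // labelToIndexArray(1)("y") == 1.0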

/** @group setParam */
@@ -217,69 +291,100 @@ class StringIndexerModel (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.3.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
return dataset.toDF
}
transformSchema(dataset.schema, logging = true)
Review comment (Member): We can skip StringIndexerModel too if all input columns don't exist?
Reply (Contributor Author): updated.

val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
case _ => labels
}
var (inputColNames, outputColNames) = getInOutCols

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
val outputColumns = new Array[Column](outputColNames.length)

var filteredDataset = dataset
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
case StringIndexer.SKIP_INVALID =>
if (getHandleInvalid == StringIndexer.SKIP_INVALID) {
filteredDataset = dataset.na.drop(inputColNames.filter(
dataset.schema.fieldNames.contains(_)))
for (i <- 0 until inputColNames.length) {
val inputColName = inputColNames(i)
val labelToIndex = labelToIndexArray(i)
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
(dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
}
}

val indexer = udf { label: String =>
if (label == null) {
if (keepInvalid) {
labels.length
} else {
throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
"NULLS, try setting StringIndexer.handleInvalid.")
}
for (i <- 0 until outputColNames.length) {
val inputColName = inputColNames(i)
val outputColName = outputColNames(i)
val labelToIndex = labelToIndexArray(i)
val labels = labelsArray(i)

if (!dataset.schema.fieldNames.contains(inputColName)) {
logInfo(s"Input column ${inputColName} does not exist during transformation. " +
"Skip this column StringIndexerModel transform.")
outputColNames(i) = null
} else {
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown"
case _ => labelsArray(i)
}

val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(filteredLabels).toMetadata()

val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)

val indexer = udf { label: String =>
if (label == null) {
if (keepInvalid) {
labels.length
} else {
throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
"NULLS, try setting StringIndexer.handleInvalid.")
}
} else {
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
}
}
}.asNondeterministic()

outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
.as(outputColName, metadata)
}
}.asNondeterministic()
}
val filteredOutputColNames = outputColNames.filter(_ != null)
val filteredOutputColumns = outputColumns.filter(_ != null)

filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
if (filteredOutputColNames.length > 0) {
filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns)
} else {
filteredDataset.toDF()
}
}
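
For intuition on the handleInvalid branches above, a hedged end-to-end sketch (column names and data are hypothetical): with "keep", NULLs and unseen labels map to index labels.length; with "skip", those rows are filtered out; with the default "error", a SparkException is thrown.

    import org.apache.spark.ml.feature.StringIndexer
    import spark.implicits._  // assumes an active SparkSession named `spark`

    val train = Seq(("a", "x"), ("b", "x"), ("a", "y")).toDF("c1", "c2")
    val test = Seq(("a", "x"), ("c", "z"), (null, "y")).toDF("c1", "c2")

    val model = new StringIndexer()
      .setInputCols(Array("c1", "c2"))
      .setOutputCols(Array("c1Idx", "c2Idx"))
      .setHandleInvalid("keep")
      .fit(train)

    // "c", "z" and null were never seen during fit, so with "keep" each maps
    // to the extra index (the number of labels seen for that column).
    model.transform(test).show()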

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
if (schema.fieldNames.contains($(inputCol))) {
validateAndTransformSchema(schema)
} else {
// If the input column does not exist during transformation, we skip StringIndexerModel.
schema
}
validateAndTransformSchema(schema, skipNonExistsCol = true)
}

@Since("1.4.1")
override def copy(extra: ParamMap): StringIndexerModel = {
val copied = new StringIndexerModel(uid, labels)
val copied = new StringIndexerModel(uid, labelsArray)
copyValues(copied, extra).setParent(parent)
}

@@ -293,11 +398,11 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
private[StringIndexerModel]
class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter {

private case class Data(labels: Array[String])
private case class Data(labelsArray: Array[Array[String]])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.labels)
val data = Data(instance.labelsArray)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -310,11 +415,22 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
override def load(path: String): StringIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("labels")
.head()
val labels = data.getAs[Seq[String]](0).toArray
val model = new StringIndexerModel(metadata.uid, labels)

val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion)
val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) {
// Spark 2.2 and before
val data = sparkSession.read.parquet(dataPath)
.select("labels")
.head()
val labels = data.getAs[Seq[String]](0).toArray
Array(labels)
} else {
val data = sparkSession.read.parquet(dataPath)
.select("labelsArray")
.head()
data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray
}
val model = new StringIndexerModel(metadata.uid, labelsArray)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
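
A small persistence sketch for the loader above (path hypothetical): models written by this version store labelsArray, while the reader falls back to the old labels column for metadata written by Spark 2.2 and earlier.

    import org.apache.spark.ml.feature.StringIndexerModel

    // `model` is a fitted StringIndexerModel, e.g. from the sketch further up.
    val path = "/tmp/string-indexer-model"
    model.write.overwrite().save(path)

    val reloaded = StringIndexerModel.load(path)
    // Per-column label lists survive the round trip.
    assert(reloaded.labelsArray.map(_.toSeq).toSeq == model.labelsArray.map(_.toSeq).toSeq)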