Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
@Since("1.6.0")
def getHandleInvalid: String = $(handleInvalid)

/**
* Param for how to order labels of string column. The first label after ordering is assigned
* an index of 0.
* Options are:
* - 'frequencyDesc': descending order by label frequency (most frequent label assigned 0)
* - 'frequencyAsc': ascending order by label frequency (least frequent label assigned 0)
* - 'alphabetDesc': descending alphabetical order
* - 'alphabetAsc': ascending alphabetical order
* Default is 'frequencyDesc'.
*
* @group param
*/
@Since("2.3.0")
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
"how to order labels of string column. " +
"The first label after ordering is assigned an index of 0. " +
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
ParamValidators.inArray(StringIndexer.supportedStringOrderType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we are going to case sensitive then?

Copy link
Contributor Author

@actuaryzhang actuaryzhang May 9, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felixcheung Right. It does not quite make sense to be case insensitive now given that we now use camel case.


/** @group getParam */
@Since("2.3.0")
def getStringOrderType: String = $(stringOrderType)

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
Expand All @@ -79,8 +102,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/**
* A label indexer that maps a string column of labels to an ML column of label indices.
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
* The indices are in [0, numLabels). By default, this is ordered by label frequencies
* so the most frequent label gets index 0. The ordering behavior is controlled by
* setting `stringOrderType`.
*
* @see `IndexToString` for the inverse transformation
*/
Expand All @@ -96,6 +120,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/** @group setParam */
@Since("2.3.0")
def setStringOrderType(value: String): this.type = set(stringOrderType, value)
setDefault(stringOrderType, StringIndexer.frequencyDesc)

/** @group setParam */
@Since("1.4.0")
def setInputCol(value: String): this.type = set(inputCol, value)
Expand All @@ -107,11 +136,17 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val values = dataset.na.drop(Array($(inputCol)))
.select(col($(inputCol)).cast(StringType))
.rdd.map(_.getString(0))
val labels = $(stringOrderType) match {
case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2)
.map(_._1).toArray
case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2)
.map(_._1).toArray
case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _)
case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
}
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}

Expand All @@ -131,6 +166,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
private[feature] val frequencyDesc: String = "frequencyDesc"
private[feature] val frequencyAsc: String = "frequencyAsc"
private[feature] val alphabetDesc: String = "alphabetDesc"
private[feature] val alphabetAsc: String = "alphabetAsc"
private[feature] val supportedStringOrderType: Array[String] =
Array(frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc)

@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,27 @@ class StringIndexerSuite
NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
}

test("StringIndexer order types") {
val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b"))
val df = data.toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")

val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))

var idx = 0
for (orderType <- StringIndexer.supportedStringOrderType) {
val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
val output = transformed.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
assert(output === expected(idx))
idx += 1
}
}
}