Skip to content

Commit ffd0cfc

Browse files
author
Wayne Zhang
committed
StringIndexer supports multiple ways of label ordering
1 parent ba76662 commit ffd0cfc

File tree

2 files changed

+64
-7
lines changed

2 files changed

+64
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,28 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
5959
@Since("1.6.0")
6060
def getHandleInvalid: String = $(handleInvalid)
6161

62+
/**
63+
* Param for how to order labels of string column. The first label after ordering is assigned
64+
* an index of 0.
65+
* Options are:
66+
* - 'freq_desc': descending order by label frequency (most frequent label assigned 0)
67+
* - 'freq_asc': ascending order by label frequency (least frequent label assigned 0)
68+
* - 'alphabet_desc': descending alphabetical order
69+
* - 'alphabet_asc': ascending alphabetical order
70+
* Default is 'freq_desc'.
71+
*
72+
* @group param
73+
*/
74+
@Since("2.2.0")
75+
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
76+
"The method used to order values of input column. " +
77+
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
78+
(value: String) => StringIndexer.supportedStringOrderType.contains(value.toLowerCase))
79+
80+
/** @group getParam */
81+
@Since("2.2.0")
82+
def getStringOrderType: String = $(stringOrderType)
83+
6284
/** Validates and transforms the input schema. */
6385
protected def validateAndTransformSchema(schema: StructType): StructType = {
6486
val inputColName = $(inputCol)
@@ -79,8 +101,9 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
79101
/**
80102
* A label indexer that maps a string column of labels to an ML column of label indices.
81103
* If the input column is numeric, we cast it to string and index the string values.
82-
* The indices are in [0, numLabels), ordered by label frequencies.
83-
* So the most frequent label gets index 0.
104+
* The indices are in [0, numLabels). By default, labels are ordered by descending frequency,
105+
* so the most frequent label gets index 0. The ordering behavior is controlled by
106+
* setting `stringOrderType`.
84107
*
85108
* @see `IndexToString` for the inverse transformation
86109
*/
@@ -96,6 +119,11 @@ class StringIndexer @Since("1.4.0") (
96119
@Since("1.6.0")
97120
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
98121

122+
/** @group setParam */
123+
@Since("2.2.0")
124+
def setStringOrderType(value: String): this.type = set(stringOrderType, value)
125+
setDefault(stringOrderType, "freq_desc")
126+
99127
/** @group setParam */
100128
@Since("1.4.0")
101129
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -107,11 +135,15 @@ class StringIndexer @Since("1.4.0") (
107135
@Since("2.0.0")
108136
override def fit(dataset: Dataset[_]): StringIndexerModel = {
109137
transformSchema(dataset.schema, logging = true)
110-
val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
111-
.rdd
112-
.map(_.getString(0))
113-
.countByValue()
114-
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
138+
val values = dataset.na.drop(Array($(inputCol)))
139+
.select(col($(inputCol)).cast(StringType))
140+
.rdd.map(_.getString(0))
141+
// The param validator lower-cases its input, so mixed-case values such as "FREQ_DESC" are
// accepted; normalize here as well so they do not fall through the match with a MatchError.
val labels = $(stringOrderType).toLowerCase match {
142+
case "freq_desc" => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray
143+
case "freq_asc" => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144+
case "alphabet_desc" => values.distinct.collect.sortWith(_ > _)
145+
case "alphabet_asc" => values.distinct.collect.sortWith(_ < _)
146+
}
115147
copyValues(new StringIndexerModel(uid, labels).setParent(this))
116148
}
117149

@@ -131,6 +163,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
131163
private[feature] val KEEP_INVALID: String = "keep"
132164
private[feature] val supportedHandleInvalids: Array[String] =
133165
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
166+
private[feature] val supportedStringOrderType: Array[String] =
167+
Array("freq_desc", "freq_asc", "alphabet_desc", "alphabet_asc")
134168

135169
@Since("1.6.0")
136170
override def load(path: String): StringIndexer = super.load(path)

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,4 +291,27 @@ class StringIndexerSuite
291291
NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
292292
assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
293293
}
294+
295+
// Fits a StringIndexer under each supported order type and checks the resulting
// (id, labelIndex) pairs against hand-computed expectations.
test("StringIndexer order types") {
296+
val data = Seq((0, "b"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "b"))
297+
val df = data.toDF("id", "label")
298+
val indexer = new StringIndexer()
299+
.setInputCol("label")
300+
.setOutputCol("labelIndex")
301+
302+
// One expected Set per order type, listed in the same order as
// StringIndexer.supportedStringOrderType, because the loop below indexes it positionally.
val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 2.0), (3, 1.0), (4, 1.0), (5, 0.0)),
303+
Set((0, 2.0), (1, 2.0), (2, 0.0), (3, 1.0), (4, 1.0), (5, 2.0)),
304+
Set((0, 1.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 1.0)),
305+
Set((0, 1.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 1.0)))
306+
307+
// Position in `expected` that corresponds to the order type of the current iteration.
var idx = 0
308+
for (orderType <- StringIndexer.supportedStringOrderType) {
309+
val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
310+
val output = transformed.select("id", "labelIndex").rdd.map { r =>
311+
(r.getInt(0), r.getDouble(1))
312+
}.collect().toSet
313+
assert(output === expected(idx))
314+
idx += 1
315+
}
316+
}
294317
}

0 commit comments

Comments
 (0)