Skip to content

Commit 97e020f

Browse files
author
Wayne Zhang
committed
address review comments and fix style
1 parent ffd0cfc commit 97e020f

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ class StringIndexer @Since("1.4.0") (
122122
/** @group setParam */
123123
@Since("2.2.0")
124124
def setStringOrderType(value: String): this.type = set(stringOrderType, value)
125-
setDefault(stringOrderType, "freq_desc")
125+
setDefault(stringOrderType, StringIndexer.FREQ_DESC)
126126

127127
/** @group setParam */
128128
@Since("1.4.0")
@@ -138,11 +138,11 @@ class StringIndexer @Since("1.4.0") (
138138
val values = dataset.na.drop(Array($(inputCol)))
139139
.select(col($(inputCol)).cast(StringType))
140140
.rdd.map(_.getString(0))
141-
val labels = $(stringOrderType) match {
142-
case "freq_desc" => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray
143-
case "freq_asc" => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144-
case "alphabet_desc" => values.distinct.collect.sortWith(_ > _)
145-
case "alphabet_asc" => values.distinct.collect.sortWith(_ < _)
141+
val labels = $(stringOrderType).toLowerCase match {
142+
case StringIndexer.FREQ_DESC => values.countByValue().toSeq.sortBy(-_._2).map(_._1).toArray
143+
case StringIndexer.FREQ_ASC => values.countByValue().toSeq.sortBy(_._2).map(_._1).toArray
144+
case StringIndexer.ALPHABET_DESC => values.distinct.collect.sortWith(_ > _)
145+
case StringIndexer.ALPHABET_ASC => values.distinct.collect.sortWith(_ < _)
146146
}
147147
copyValues(new StringIndexerModel(uid, labels).setParent(this))
148148
}
@@ -163,8 +163,12 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] {
163163
private[feature] val KEEP_INVALID: String = "keep"
164164
private[feature] val supportedHandleInvalids: Array[String] =
165165
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
166+
private[feature] val FREQ_DESC: String = "freq_desc"
167+
private[feature] val FREQ_ASC: String = "freq_asc"
168+
private[feature] val ALPHABET_DESC: String = "alphabet_desc"
169+
private[feature] val ALPHABET_ASC: String = "alphabet_asc"
166170
private[feature] val supportedStringOrderType: Array[String] =
167-
Array("freq_desc", "freq_asc", "alphabet_desc", "alphabet_asc")
171+
Array(FREQ_DESC, FREQ_ASC, ALPHABET_DESC, ALPHABET_ASC)
168172

169173
@Since("1.6.0")
170174
override def load(path: String): StringIndexer = super.load(path)

0 commit comments

Comments
 (0)